PlanPatternGenerator.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.pattern.generator;
import org.apache.doris.nereids.pattern.generator.javaast.ClassDeclaration;
import org.apache.doris.nereids.pattern.generator.javaast.EnumConstant;
import org.apache.doris.nereids.pattern.generator.javaast.EnumDeclaration;
import org.apache.doris.nereids.pattern.generator.javaast.FieldDeclaration;
import org.apache.doris.nereids.pattern.generator.javaast.MethodDeclaration;
import org.apache.doris.nereids.pattern.generator.javaast.VariableDeclarator;
import com.google.common.base.Joiner;
import org.apache.commons.lang3.StringUtils;
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
 * used to generate pattern by plan.
 *
 * <p>Each instance produces the source code of the pattern-descriptor methods for one plan class
 * ({@link #opType}); subclasses (chosen by {@link #create}) supply the generic parameter text, the imports
 * and the number of children of that plan shape.
 */
public abstract class PlanPatternGenerator {
    // matches one UpperCamelCase word, e.g. "Join" and "Type" in "JoinType"; compiled once, not per call
    private static final Pattern CAMEL_CASE_WORD = Pattern.compile("([A-Z]+[^A-Z]*)");

    protected final JavaAstAnalyzer analyzer;
    // the plan class whose pattern methods are generated
    protected final ClassDeclaration opType;
    // names of opType's parent classes; create() looks up fully qualified names such as
    // "org.apache.doris.nereids.trees.plans.logical.LogicalLeaf" in it
    protected final Set<String> parentClass;
    // one entry per (enum typed field that has a getter) x (constant of that enum)
    protected final List<EnumFieldPatternInfo> enumFieldPatternInfos;
    // every method source produced by generateTypePattern, in generation order; never cleared
    protected final List<String> generatePatterns = new ArrayList<>();
    // selects GroupPlan (true) or Plan (false) in childType()
    protected final boolean isMemoPattern;

    /**
     * constructor.
     *
     * @param analyzer supplies the java ast analyzer used to resolve field types
     * @param opType the plan class to generate pattern methods for
     * @param parentClass names of opType's parent classes
     * @param isMemoPattern whether children are typed as GroupPlan rather than Plan
     */
    public PlanPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType,
            Set<String> parentClass, boolean isMemoPattern) {
        this.analyzer = analyzer.getAnalyzer();
        this.opType = opType;
        this.parentClass = parentClass;
        // assign all other fields before invoking the overridable getEnumFieldPatternInfos(),
        // so an override (or childType()) never observes a default-initialized isMemoPattern
        this.isMemoPattern = isMemoPattern;
        this.enumFieldPatternInfos = getEnumFieldPatternInfos();
    }

    /** text appended to "PatternDescriptor" in methods that do not specify children patterns. */
    public abstract String genericType();

    /** text appended to "PatternDescriptor" in methods that take one pattern per child. */
    public abstract String genericTypeWithChildren();

    /** imports that the generated interface needs for this generator's methods. */
    public abstract Set<String> getImports();

    public abstract boolean isLogical();

    /** number of children of the plan shape; children-specified methods are only generated when positive. */
    public abstract int childrenNum();

    /** the plan class name with its first character lower-cased, e.g. LogicalJoin -> logicalJoin. */
    public String getPatternMethodName() {
        return opType.name.substring(0, 1).toLowerCase(Locale.ENGLISH) + opType.name.substring(1);
    }

    /**
     * generate code by generators and analyzer.
     *
     * @param className name of the generated interface
     * @param parentClassName name of the interface it extends
     * @param generators generators whose pattern methods form the interface body
     * @param analyzer not read by this method
     * @param isMemoPattern forwarded to each generator's {@link #generate(boolean)}
     * @return the complete source of the generated interface, including license header and imports
     */
    public static String generateCode(String className, String parentClassName, List<PlanPatternGenerator> generators,
            PlanPatternGeneratorAnalyzer analyzer, boolean isMemoPattern) {
        String generateCode
                = "// Licensed to the Apache Software Foundation (ASF) under one\n"
                + "// or more contributor license agreements. See the NOTICE file\n"
                + "// distributed with this work for additional information\n"
                + "// regarding copyright ownership. The ASF licenses this file\n"
                + "// to you under the Apache License, Version 2.0 (the\n"
                + "// \"License\"); you may not use this file except in compliance\n"
                + "// with the License. You may obtain a copy of the License at\n"
                + "//\n"
                + "// http://www.apache.org/licenses/LICENSE-2.0\n"
                + "//\n"
                + "// Unless required by applicable law or agreed to in writing,\n"
                + "// software distributed under the License is distributed on an\n"
                + "// \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n"
                + "// KIND, either express or implied. See the License for the\n"
                + "// specific language governing permissions and limitations\n"
                + "// under the License.\n"
                + "\n"
                + "package org.apache.doris.nereids.pattern;\n"
                + "\n"
                + generateImports(generators)
                + "\n";
        generateCode += "public interface " + className + " extends " + parentClassName + " {\n";
        generateCode += generators.stream()
                .map(generator -> {
                    String patternMethods = generator.generate(isMemoPattern);
                    // add indent
                    return Arrays.stream(patternMethods.split("\n"))
                            .map(line -> " " + line + "\n")
                            .collect(Collectors.joining(""));
                }).collect(Collectors.joining("\n"));
        return generateCode + "}\n";
    }

    /**
     * Builds one {@link EnumFieldPatternInfo} for every constant of every enum typed field of
     * {@link #opType} that has a matching zero-argument getter.
     */
    protected List<EnumFieldPatternInfo> getEnumFieldPatternInfos() {
        List<EnumFieldPatternInfo> enumFieldInfos = new ArrayList<>();
        for (Entry<FieldDeclaration, EnumDeclaration> pair : findEnumFieldType()) {
            FieldDeclaration fieldDecl = pair.getKey();
            EnumDeclaration enumDecl = pair.getValue();
            // lower-cased words of the enum class name, e.g. JoinType -> {join, type}
            Set<String> enumClassNameParts = splitCase(enumDecl.name)
                    .stream()
                    .map(part -> part.toLowerCase(Locale.ENGLISH))
                    .collect(Collectors.toSet());
            for (VariableDeclarator varDecl : fieldDecl.variableDeclarators.variableDeclarators) {
                String enumFieldName = varDecl.variableDeclaratorId.identifier;
                Optional<String> getter = findGetter(enumDecl.name, enumFieldName);
                if (!getter.isPresent()) {
                    continue;
                }
                for (EnumConstant constant : enumDecl.constants) {
                    String enumInstance = constant.identifier;
                    String enumPatternName = getEnumPatternName(enumInstance, enumClassNameParts) + opType.name;
                    enumFieldInfos.add(new EnumFieldPatternInfo(enumPatternName,
                            enumDecl.getFullQualifiedName(), enumDecl.name, enumInstance, getter.get()));
                }
            }
        }
        return enumFieldInfos;
    }

    /**
     * Looks for a zero-argument method named {@code get<Name>} in {@link #opType} whose return type
     * prints as {@code type}.
     *
     * @return the getter name if such a method exists, otherwise empty
     */
    protected Optional<String> findGetter(String type, String name) {
        String getterName = "get" + name.substring(0, 1).toUpperCase(Locale.ENGLISH) + name.substring(1);
        for (MethodDeclaration methodDecl : opType.methodDeclarations) {
            boolean returnsType = !methodDecl.typeTypeOrVoid.isVoid
                    && methodDecl.typeTypeOrVoid.typeType.isPresent()
                    && methodDecl.typeTypeOrVoid.typeType.get().toString().equals(type);
            if (returnsType && methodDecl.identifier.equals(getterName) && methodDecl.paramNum == 0) {
                return Optional.of(getterName);
            }
        }
        return Optional.empty();
    }

    /**
     * Converts an UPPER_SNAKE_CASE enum constant to lowerCamelCase, dropping every word that also
     * appears in the enum class name, e.g. INNER_JOIN of JoinType -> "inner".
     *
     * @param enumInstance the enum constant identifier
     * @param enumClassNameParts lower-cased words of the enum class name
     * @return the camel-case name; empty if every word was dropped
     */
    protected String getEnumPatternName(String enumInstance, Set<String> enumClassNameParts) {
        List<String> newParts = new ArrayList<>();
        for (String rawPart : enumInstance.split("_+")) {
            String part = rawPart.toLowerCase(Locale.ENGLISH);
            // skip instanceNameParts, e.g. INNER_JOIN has two part: [inner and Join].
            // because 'Join' is the part of the 'JoinType' enum className, so skip 'Join' and return 'inner'
            if (part.isEmpty() || enumClassNameParts.contains(part)) {
                continue;
            }
            // the first kept word stays lower case, later words are capitalized
            newParts.add(newParts.isEmpty()
                    ? part
                    : part.substring(0, 1).toUpperCase(Locale.ENGLISH) + part.substring(1));
        }
        return Joiner.on("").join(newParts);
    }

    /** Pairs each field of {@link #opType} whose type the analyzer resolves to an enum with that enum. */
    protected List<Map.Entry<FieldDeclaration, EnumDeclaration>> findEnumFieldType() {
        return opType.fieldDeclarations
                .stream()
                .map(f -> new SimpleEntry<>(f, analyzer.getType(opType, f.type)))
                .filter(pair -> pair.getValue().isPresent() && pair.getValue().get() instanceof EnumDeclaration)
                .map(pair -> new SimpleEntry<>(pair.getKey(), (EnumDeclaration) (pair.getValue().get())))
                .collect(Collectors.toList());
    }

    // e.g. split PhysicalBroadcastHashJoin to [Physical, Broadcast, Hash, Join]
    // e.g. split JoinType to [Join, Type]
    // characters before the first upper-case letter are not part of any word and are dropped
    protected List<String> splitCase(String name) {
        Matcher matcher = CAMEL_CASE_WORD.matcher(name);
        List<String> parts = new ArrayList<>();
        while (matcher.find()) {
            parts.add(matcher.group(0));
        }
        return parts;
    }

    /** the java type used for children: GroupPlan for memo patterns, otherwise Plan. */
    protected String childType() {
        return isMemoPattern ? "GroupPlan" : "Plan";
    }

    /**
     * create generator by plan's type.
     *
     * @return the generator matching the first known leaf/unary/binary parent class, or empty if none matches
     */
    public static Optional<PlanPatternGenerator> create(PlanPatternGeneratorAnalyzer analyzer,
            ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
        if (parentClass.contains("org.apache.doris.nereids.trees.plans.logical.LogicalLeaf")) {
            return Optional.of(new LogicalLeafPatternGenerator(analyzer, opType, parentClass, isMemoPattern));
        } else if (parentClass.contains("org.apache.doris.nereids.trees.plans.logical.LogicalUnary")) {
            return Optional.of(new LogicalUnaryPatternGenerator(analyzer, opType, parentClass, isMemoPattern));
        } else if (parentClass.contains("org.apache.doris.nereids.trees.plans.logical.LogicalBinary")) {
            return Optional.of(new LogicalBinaryPatternGenerator(analyzer, opType, parentClass, isMemoPattern));
        } else if (parentClass.contains("org.apache.doris.nereids.trees.plans.physical.PhysicalLeaf")) {
            return Optional.of(new PhysicalLeafPatternGenerator(analyzer, opType, parentClass, isMemoPattern));
        } else if (parentClass.contains("org.apache.doris.nereids.trees.plans.physical.PhysicalUnary")) {
            return Optional.of(new PhysicalUnaryPatternGenerator(analyzer, opType, parentClass, isMemoPattern));
        } else if (parentClass.contains("org.apache.doris.nereids.trees.plans.physical.PhysicalBinary")) {
            return Optional.of(new PhysicalBinaryPatternGenerator(analyzer, opType, parentClass, isMemoPattern));
        } else {
            return Optional.empty();
        }
    }

    // union of all generators' imports, de-duplicated and sorted, one "import x;" line each
    private static String generateImports(List<PlanPatternGenerator> generators) {
        Set<String> imports = new HashSet<>();
        for (PlanPatternGenerator generator : generators) {
            imports.addAll(generator.getImports());
        }
        return imports.stream()
                .sorted(Comparator.naturalOrder())
                .map(it -> "import " + it + ";\n")
                .collect(Collectors.joining(""));
    }

    /**
     * generate some pattern method code.
     *
     * <p>Emits the plain type pattern, then one pattern per enum constant (filtered with a {@code .when}
     * predicate on the getter); each in a children-specified variant too when the plan has children.
     * The methods are appended to {@link #generatePatterns}, so calling this twice duplicates them.
     *
     * @param isMemoPattern true to match unspecified children with Pattern.GROUP, false with Pattern.ANY
     * @return all methods generated by this instance so far, separated by blank lines
     */
    public String generate(boolean isMemoPattern) {
        String opClassName = opType.name;
        String methodName = getPatternMethodName();
        generateTypePattern(methodName, opClassName, genericType(), "", false, isMemoPattern);
        if (childrenNum() > 0) {
            generateTypePattern(methodName, opClassName, genericTypeWithChildren(),
                    "", true, isMemoPattern);
        }
        for (EnumFieldPatternInfo info : enumFieldPatternInfos) {
            String predicate = ".when(p -> p." + info.enumInstanceGetter + "() == "
                    + info.enumType + "." + info.enumInstance + ")";
            generateTypePattern(info.patternName, opClassName, genericType(),
                    predicate, false, isMemoPattern);
            if (childrenNum() > 0) {
                generateTypePattern(info.patternName, opClassName, genericTypeWithChildren(),
                        predicate, true, isMemoPattern);
            }
        }
        return generatePatterns();
    }

    /**
     * generate a pattern method code, and append it to {@link #generatePatterns}.
     *
     * @param patterName name of the generated method
     * @param className simple name of the plan class matched by the TypePattern
     * @param genericParam text appended to "PatternDescriptor" for the return and constructed type
     * @param predicate code appended after the descriptor construction, e.g. a ".when(...)" filter, or ""
     * @param specifyChildren true: the method takes one PatternDescriptor per child;
     *                        false: the method takes no arguments and matches every child with GROUP/ANY
     * @param isMemoPattern true to use Pattern.GROUP, false to use Pattern.ANY (only when !specifyChildren)
     * @return the generated method source
     * @throws IllegalArgumentException if specifyChildren is true but childrenNum() is not positive
     */
    public String generateTypePattern(String patterName, String className,
            String genericParam, String predicate, boolean specifyChildren, boolean isMemoPattern) {
        int childrenNum = childrenNum();
        String pattern;
        if (specifyChildren) {
            // a plan without children has no child pattern parameters to declare
            if (childrenNum <= 0) {
                throw new IllegalArgumentException("Can not generate a children-specified pattern for "
                        + className + " which has " + childrenNum + " children");
            }
            StringBuilder methodGenericBuilder = new StringBuilder("<");
            StringBuilder methodParamBuilder = new StringBuilder();
            StringBuilder childrenPatternBuilder = new StringBuilder();
            for (int i = 1; i <= childrenNum; i++) {
                if (i > 1) {
                    methodGenericBuilder.append(", ");
                    methodParamBuilder.append(", ");
                    childrenPatternBuilder.append(", ");
                }
                methodGenericBuilder.append("C").append(i).append(" extends Plan");
                methodParamBuilder.append("PatternDescriptor<C").append(i).append("> child").append(i);
                childrenPatternBuilder.append("child").append(i).append(".pattern");
            }
            methodGenericBuilder.append(">");
            // children patterns follow the class literal as extra TypePattern arguments
            childrenPatternBuilder.insert(0, ", ");
            pattern = "default " + methodGenericBuilder + "\n"
                    + "PatternDescriptor" + genericParam + "\n"
                    + " " + patterName + "(" + methodParamBuilder + ") {\n"
                    + " return new PatternDescriptor" + genericParam + "(\n"
                    + " new TypePattern(" + className + ".class" + childrenPatternBuilder + "),\n"
                    + " defaultPromise()\n"
                    + " )" + predicate + ";\n"
                    + "}\n";
        } else {
            String childrenPattern = StringUtils.repeat(
                    isMemoPattern ? "Pattern.GROUP" : "Pattern.ANY", ", ", childrenNum);
            if (childrenNum > 0) {
                childrenPattern = ", " + childrenPattern;
            }
            pattern = "default PatternDescriptor" + genericParam + " " + patterName + "() {\n"
                    + " return new PatternDescriptor" + genericParam + "(\n"
                    + " new TypePattern(" + className + ".class" + childrenPattern + "),\n"
                    + " defaultPromise()\n"
                    + " )" + predicate + ";\n"
                    + "}\n";
        }
        generatePatterns.add(pattern);
        return pattern;
    }

    /** all methods generated so far, joined with a blank line between them. */
    public String generatePatterns() {
        return String.join("\n", generatePatterns);
    }

    /** describes one enum-constant-specific pattern method, e.g. innerLogicalJoin for JoinType.INNER_JOIN. */
    static class EnumFieldPatternInfo {
        // name of the generated pattern method
        public final String patternName;
        // fully qualified name of the enum class
        public final String enumFullName;
        // simple name of the enum class, used in the generated predicate
        public final String enumType;
        // the enum constant identifier
        public final String enumInstance;
        // name of the zero-argument getter returning the enum field
        public final String enumInstanceGetter;

        public EnumFieldPatternInfo(String patternName, String enumFullName, String enumType,
                String enumInstance, String enumInstanceGetter) {
            this.patternName = patternName;
            this.enumFullName = enumFullName;
            this.enumType = enumType;
            this.enumInstance = enumInstance;
            this.enumInstanceGetter = enumInstanceGetter;
        }
    }
}