// Pattern.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.pattern;
import org.apache.doris.nereids.trees.AbstractTreeNode;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
/**
* Pattern node used in pattern matching.
*/
public class Pattern<TYPE extends Plan>
extends AbstractTreeNode<Pattern<? extends Plan>> {
public static final Pattern ANY = new Pattern(PatternType.ANY);
public static final Pattern MULTI = new Pattern(PatternType.MULTI);
public static final Pattern GROUP = new Pattern(PatternType.GROUP);
public static final Pattern MULTI_GROUP = new Pattern(PatternType.MULTI_GROUP);
protected final List<Predicate<TYPE>> predicates;
protected final PatternType patternType;
protected final PlanType planType;
public Pattern(PlanType planType, Pattern... children) {
this(PatternType.NORMAL, planType, children);
}
public Pattern(PlanType planType, List<Predicate<TYPE>> predicates, Pattern... children) {
this(PatternType.NORMAL, planType, predicates, children);
}
private Pattern(PatternType patternType, Pattern... children) {
this(patternType, PlanType.UNKNOWN, children);
}
/**
* Constructor for Pattern.
*
* @param patternType pattern type to matching
* @param children sub pattern
*/
private Pattern(PatternType patternType, PlanType planType, Pattern... children) {
super(children);
this.patternType = patternType;
this.planType = planType;
this.predicates = ImmutableList.of();
}
/**
* Constructor for Pattern.
*
* @param patternType pattern type to matching
* @param planType plan type to matching
* @param predicates custom matching predicate
* @param children sub pattern
*/
protected Pattern(PatternType patternType, PlanType planType,
List<Predicate<TYPE>> predicates, Pattern... children) {
super(children);
this.patternType = patternType;
this.planType = planType;
this.predicates = ImmutableList.copyOf(predicates);
for (int i = 0; i + 1 < children.length; ++i) {
if (children[i].isMulti()) {
throw new IllegalStateException("Pattern.MULTI must be last child of current pattern");
} else if (children[i].isMultiGroup()) {
throw new IllegalStateException("Pattern.MULTI_GROUP must be last child of current pattern");
}
}
}
/**
* get current type in Plan.
*
* @return plan type in pattern
*/
public PlanType getPlanType() {
return planType;
}
/**
* get current type in Pattern.
*
* @return pattern type
*/
public PatternType getPatternType() {
return patternType;
}
/**
* get all predicates in Pattern.
*
* @return all predicates
*/
public List<Predicate<TYPE>> getPredicates() {
return predicates;
}
public boolean isGroup() {
return patternType == PatternType.GROUP;
}
public boolean isMultiGroup() {
return patternType == PatternType.MULTI_GROUP;
}
public boolean isAny() {
return patternType == PatternType.ANY;
}
public boolean isMulti() {
return patternType == PatternType.MULTI;
}
/** matchPlan */
public boolean matchPlanTree(Plan plan) {
if (!matchRoot(plan)) {
return false;
}
int childPatternNum = arity();
if (childPatternNum != plan.arity() && childPatternNum > 0 && child(childPatternNum - 1) != MULTI) {
return false;
}
switch (patternType) {
case ANY:
case MULTI:
return matchPredicates((TYPE) plan);
default:
}
if (this instanceof SubTreePattern) {
return matchPredicates((TYPE) plan);
}
return matchChildrenAndSelfPredicates(plan, childPatternNum);
}
private boolean matchChildrenAndSelfPredicates(Plan plan, int childPatternNum) {
List<Plan> childrenPlan = plan.children();
for (int i = 0; i < childrenPlan.size(); i++) {
Plan child = childrenPlan.get(i);
Pattern childPattern = child(Math.min(i, childPatternNum - 1));
if (!childPattern.matchPlanTree(child)) {
return false;
}
}
return matchPredicates((TYPE) plan);
}
/**
* Return ture if current Pattern match Plan in params.
*
* @param plan wait to match
* @return ture if current Pattern match Plan in params
*/
public boolean matchRoot(Plan plan) {
if (plan == null) {
return false;
}
switch (patternType) {
case ANY:
case MULTI:
case GROUP:
case MULTI_GROUP:
return true;
default:
return planType == plan.getType();
}
}
/**
* match all predicates.
* @param root root plan
* @return true if all predicates matched
*/
public boolean matchPredicates(TYPE root) {
// use loop to speed up
for (Predicate<TYPE> predicate : predicates) {
if (!predicate.test(root)) {
return false;
}
}
return true;
}
@Override
public Pattern<? extends Plan> withChildren(
List<Pattern<? extends Plan>> children) {
throw new IllegalStateException("Pattern can not invoke withChildren");
}
public Pattern<TYPE> withPredicates(List<Predicate<TYPE>> predicates) {
return new Pattern(patternType, planType, predicates, children.toArray(new Pattern[0]));
}
public boolean hasMultiChild() {
return !children.isEmpty() && children.get(children.size() - 1).isMulti();
}
public boolean hasMultiGroupChild() {
return !children.isEmpty() && children.get(children.size() - 1).isMultiGroup();
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
Pattern<?> pattern = (Pattern<?>) o;
return predicates.equals(pattern.predicates)
&& patternType == pattern.patternType
&& planType == pattern.planType;
}
@Override
public int hashCode() {
return Objects.hash(predicates, patternType, planType);
}
}