TreeNode.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
/**
* interface for all node in Nereids, include plan node and expression.
*
* @param <NODE_TYPE> either {@link org.apache.doris.nereids.trees.plans.Plan}
* or {@link org.apache.doris.nereids.trees.expressions.Expression}
*/
public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> {
List<NODE_TYPE> children();
NODE_TYPE child(int index);
int arity();
<T> Optional<T> getMutableState(String key);
/** getOrInitMutableState */
default <T> T getOrInitMutableState(String key, Supplier<T> initState) {
Optional<T> mutableState = getMutableState(key);
if (!mutableState.isPresent()) {
T state = initState.get();
setMutableState(key, state);
return state;
}
return mutableState.get();
}
void setMutableState(String key, Object value);
default NODE_TYPE withChildren(NODE_TYPE... children) {
return withChildren(Utils.fastToImmutableList(children));
}
NODE_TYPE withChildren(List<NODE_TYPE> children);
default NODE_TYPE withChildren(Function<NODE_TYPE, NODE_TYPE> rewriter) {
return withChildren((child, index) -> rewriter.apply(child));
}
/**
* rewrite children by a rewriter
* @param rewriter consume the origin child and child index, then return the new child
* @return new tree node if any child has changed
*/
default NODE_TYPE withChildren(BiFunction<NODE_TYPE, Integer, NODE_TYPE> rewriter) {
Builder<NODE_TYPE> newChildren = ImmutableList.builderWithExpectedSize(arity());
boolean changed = false;
for (int i = 0; i < arity(); i++) {
NODE_TYPE child = child(i);
NODE_TYPE newChild = rewriter.apply(child, i);
if (child != newChild) {
changed = true;
}
newChildren.add(newChild);
}
return changed ? withChildren(newChildren.build()) : (NODE_TYPE) this;
}
/**
* top-down rewrite short circuit.
* @param rewriteFunction rewrite function.
* @return rewritten result.
*/
default NODE_TYPE rewriteDownShortCircuit(Function<NODE_TYPE, NODE_TYPE> rewriteFunction) {
NODE_TYPE currentNode = rewriteFunction.apply((NODE_TYPE) this);
if (currentNode == this) {
Builder<NODE_TYPE> newChildren = ImmutableList.builderWithExpectedSize(arity());
boolean changed = false;
for (NODE_TYPE child : children()) {
NODE_TYPE newChild = child.rewriteDownShortCircuit(rewriteFunction);
if (child != newChild) {
changed = true;
}
newChildren.add(newChild);
}
if (changed) {
currentNode = currentNode.withChildren(newChildren.build());
}
}
return currentNode;
}
/**
* similar to rewriteDownShortCircuit, except that only subtrees, whose root satisfies
* border predicate are rewritten.
*/
default NODE_TYPE rewriteDownShortCircuitDown(Function<NODE_TYPE, NODE_TYPE> rewriteFunction,
Predicate border, boolean aboveBorder) {
NODE_TYPE currentNode = (NODE_TYPE) this;
if (border.test(this)) {
aboveBorder = false;
}
if (!aboveBorder) {
currentNode = rewriteFunction.apply((NODE_TYPE) this);
}
if (currentNode == this) {
Builder<NODE_TYPE> newChildren = ImmutableList.builderWithExpectedSize(arity());
boolean changed = false;
for (NODE_TYPE child : children()) {
NODE_TYPE newChild = child.rewriteDownShortCircuitDown(rewriteFunction, border, aboveBorder);
if (child != newChild) {
changed = true;
}
newChildren.add(newChild);
}
if (changed) {
currentNode = currentNode.withChildren(newChildren.build());
}
}
return currentNode;
}
/**
* bottom-up rewrite.
* @param rewriteFunction rewrite function.
* @return rewritten result.
*/
default NODE_TYPE rewriteUp(Function<NODE_TYPE, NODE_TYPE> rewriteFunction) {
Builder<NODE_TYPE> newChildren = ImmutableList.builderWithExpectedSize(arity());
boolean changed = false;
for (NODE_TYPE child : children()) {
NODE_TYPE newChild = child.rewriteUp(rewriteFunction);
changed |= child != newChild;
newChildren.add(newChild);
}
NODE_TYPE rewrittenChildren = changed ? withChildren(newChildren.build()) : (NODE_TYPE) this;
return rewriteFunction.apply(rewrittenChildren);
}
/**
* Foreach treeNode. Top-down traverse implicitly, stop traverse if satisfy test.
* @param func foreach function
*/
default void foreach(Predicate<TreeNode<NODE_TYPE>> func) {
boolean valid = func.test(this);
if (!valid) {
for (NODE_TYPE child : children()) {
child.foreach(func);
}
}
}
/**
* Foreach treeNode. Top-down traverse implicitly.
* @param func foreach function
*/
default void foreach(Consumer<TreeNode<NODE_TYPE>> func) {
func.accept(this);
for (NODE_TYPE child : children()) {
child.foreach(func);
}
}
/** foreachBreath */
default void foreachBreath(Predicate<TreeNode<NODE_TYPE>> func) {
LinkedList<TreeNode<NODE_TYPE>> queue = new LinkedList<>();
queue.add(this);
while (!queue.isEmpty()) {
TreeNode<NODE_TYPE> current = queue.pollFirst();
if (!func.test(current)) {
queue.addAll(current.children());
}
}
}
default void foreachUp(Consumer<TreeNode<NODE_TYPE>> func) {
for (NODE_TYPE child : children()) {
child.foreach(func);
}
func.accept(this);
}
/**
* iterate top down and test predicate if any matched. Top-down traverse implicitly.
* @param predicate predicate
* @return true if any predicate return true
*/
default boolean anyMatch(Predicate<TreeNode<NODE_TYPE>> predicate) {
if (predicate.test(this)) {
return true;
}
for (NODE_TYPE child : children()) {
if (child.anyMatch(predicate)) {
return true;
}
}
return false;
}
/**
* iterate top down and test predicate if all matched. Top-down traverse implicitly.
* @param predicate predicate
* @return true if all predicate return true
*/
default boolean allMatch(Predicate<TreeNode<NODE_TYPE>> predicate) {
if (!predicate.test(this)) {
return false;
}
for (NODE_TYPE child : children()) {
if (!child.allMatch(predicate)) {
return false;
}
}
return true;
}
/**
* Collect the nodes that satisfied the predicate.
*/
default <T> Set<T> collect(Predicate<TreeNode<NODE_TYPE>> predicate) {
ImmutableSet.Builder<TreeNode<NODE_TYPE>> result = ImmutableSet.builder();
foreach(node -> {
if (predicate.test(node)) {
result.add(node);
}
});
return (Set<T>) result.build();
}
/**
* Collect the nodes that satisfied the predicate to list.
*/
default <T> List<T> collectToList(Predicate<TreeNode<NODE_TYPE>> predicate) {
ImmutableList.Builder<TreeNode<NODE_TYPE>> result = ImmutableList.builder();
foreach(node -> {
if (predicate.test(node)) {
result.add(node);
}
});
return (List<T>) result.build();
}
/**
* Collect the nodes that satisfied the predicate to set.
*/
default <T> Set<T> collectToSet(Predicate<TreeNode<NODE_TYPE>> predicate) {
ImmutableSet.Builder<TreeNode<NODE_TYPE>> result = ImmutableSet.builder();
foreach(node -> {
if (predicate.test(node)) {
result.add(node);
}
});
return (Set<T>) result.build();
}
/**
* Collect the nodes that satisfied the predicate firstly.
*/
default <T> Optional<T> collectFirst(Predicate<TreeNode<NODE_TYPE>> predicate) {
List<TreeNode<NODE_TYPE>> result = new ArrayList<>();
foreach(node -> {
if (result.isEmpty() && predicate.test(node)) {
result.add(node);
}
return !result.isEmpty();
});
return result.isEmpty() ? Optional.empty() : Optional.of((T) result.get(0));
}
/**
* iterate top down and test predicate if contains any instance of the classes
* @param types classes array
* @return true if it has any instance of the types
*/
default boolean containsType(Class... types) {
return anyMatch(node -> {
for (Class type : types) {
if (type.isInstance(node)) {
return true;
}
}
return false;
});
}
/**
* equals by the full tree nodes
* @param that other tree node
* @return true if all the tree is equals
*/
default boolean deepEquals(TreeNode<?> that) {
Deque<TreeNode<?>> thisDeque = new ArrayDeque<>();
Deque<TreeNode<?>> thatDeque = new ArrayDeque<>();
thisDeque.push(this);
thatDeque.push(that);
while (!thisDeque.isEmpty()) {
if (thatDeque.isEmpty()) {
// The "that" tree has been fully traversed, but the "this" tree has not; hence they are not equal.
return false;
}
TreeNode<?> currentNodeThis = thisDeque.pop();
TreeNode<?> currentNodeThat = thatDeque.pop();
// since TreeNode is immutable, use == to short circuit
if (currentNodeThis == currentNodeThat) {
continue;
}
// If current nodes are not equal or the number of child nodes differ, return false.
if (!currentNodeThis.equals(currentNodeThat)
|| currentNodeThis.arity() != currentNodeThat.arity()) {
return false;
}
// Add child nodes to the deque for further processing.
for (int i = 0; i < currentNodeThis.arity(); i++) {
thisDeque.push(currentNodeThis.child(i));
thatDeque.push(currentNodeThat.child(i));
}
}
// If the "that" tree hasn't been fully traversed, return false.
return thatDeque.isEmpty();
}
}