ExpressionLineageReplacer.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans.visitor;
import org.apache.doris.catalog.TableIf.TableType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer.ExpressionReplaceContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
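/**
* ExpressionLineageReplacer
* Replaces the named expressions used by the target expressions with the expressions they are derived from
* in the plan (their lineage).
* The plan can be taken from the rewritten plan or from the plan struct info; when it is taken from the
* plan struct info, it depends on the nodes of the graph.
*/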
public class ExpressionLineageReplacer extends DefaultPlanVisitor<Expression, ExpressionReplaceContext> {
public static final Logger LOG = LogManager.getLogger(ExpressionLineageReplacer.class);
public static final ExpressionLineageReplacer INSTANCE = new ExpressionLineageReplacer();
/**
 * Inspects the expressions of {@code plan} and hands every named expression whose ExprId is
 * already tracked in the context map to {@link NamedExpressionCollector}, then delegates to the
 * superclass {@code visit}.
 *
 * @param plan the plan node whose expressions are inspected
 * @param context the replace context holding the ExprId to expression map
 * @return whatever the superclass {@code visit} returns for this plan
 */
@Override
public Expression visit(Plan plan, ExpressionReplaceContext context) {
    List<? extends Expression> planExpressions = plan.getExpressions();
    Map<ExprId, Expression> trackedById = context.getExprIdExpressionMap();
    for (Expression candidate : planExpressions) {
        // Only named expressions carry an ExprId that can be matched against the tracked ones.
        if (!(candidate instanceof NamedExpression)) {
            continue;
        }
        ExprId candidateId = ((NamedExpression) candidate).getExprId();
        if (!trackedById.containsKey(candidateId)) {
            continue;
        }
        // The target depends on this expression, so let the collector record it.
        candidate.accept(NamedExpressionCollector.INSTANCE, context);
    }
    return super.visit(plan, context);
}
/**
 * A group plan is not expected here: the occurrence is logged as an error and nothing is replaced.
 *
 * @param groupPlan the unexpected group plan
 * @param context the replace context (not used)
 * @return always {@code null}
 */
@Override
public Expression visitGroupPlan(GroupPlan groupPlan, ExpressionReplaceContext context) {
    final String planDescription = groupPlan.toString();
    LOG.error("ExpressionLineageReplacer should not meet groupPlan, plan is {}", planDescription);
    return null;
}
/**
 * Rewrites an expression by substituting named expressions with the expressions registered
 * for their ExprId in the supplied map (the lineage).
 */
public static class ExpressionReplacer extends DefaultExpressionRewriter<Map<ExprId, Expression>> {
    public static final ExpressionReplacer INSTANCE = new ExpressionReplacer();

    /**
     * Starts the rewrite from the expression registered for the ExprId of {@code namedExpression}
     * when the map has one, otherwise from {@code namedExpression} itself.
     *
     * @param namedExpression the named expression being visited
     * @param exprIdExpressionMap the ExprId to replacement expression map
     * @return the rewritten expression
     */
    @Override
    public Expression visitNamedExpression(NamedExpression namedExpression,
            Map<ExprId, Expression> exprIdExpressionMap) {
        Expression startingPoint = exprIdExpressionMap.containsKey(namedExpression.getExprId())
                ? exprIdExpressionMap.get(namedExpression.getExprId())
                : namedExpression;
        return visit(startingPoint, exprIdExpressionMap);
    }

    /**
     * Rewrites {@code expr}: a childless named expression that has an entry in the map is first
     * swapped for that entry, then every child is rewritten through this visitor.
     *
     * @param expr the expression to rewrite
     * @param exprIdExpressionMap the ExprId to replacement expression map
     * @return the same instance when no child changed, otherwise a copy built with the new children
     */
    @Override
    public Expression visit(Expression expr, Map<ExprId, Expression> exprIdExpressionMap) {
        Expression current = expr;
        // The swap is applied once here; children of the swapped-in expression are handled below.
        if (current instanceof NamedExpression && current.arity() == 0) {
            ExprId leafId = ((NamedExpression) current).getExprId();
            if (exprIdExpressionMap.containsKey(leafId)) {
                current = exprIdExpressionMap.get(leafId);
            }
        }

        List<Expression> rewrittenChildren = new ArrayList<>(current.arity());
        boolean anyChildChanged = false;
        for (Expression originalChild : current.children()) {
            Expression rewrittenChild = originalChild.accept(this, exprIdExpressionMap);
            // Reference comparison: a child counts as changed only if a different instance came back.
            anyChildChanged |= rewrittenChild != originalChild;
            rewrittenChildren.add(rewrittenChild);
        }

        if (!anyChildChanged) {
            return current;
        }
        return current.withChildren(rewrittenChildren);
    }
}
/**
 * Collects named expressions into the ExprId to expression map of the context, so that the
 * target expressions can be replaced from that map later.
 * TODO Collect named expression by targetTypes, tableIdentifiers
 */
public static class NamedExpressionCollector
        extends DefaultExpressionVisitor<Void, ExpressionReplaceContext> {
    public static final NamedExpressionCollector INSTANCE = new NamedExpressionCollector();

    /**
     * Registers the slot under its own ExprId, overwriting any previous entry.
     *
     * @param slot the slot being visited
     * @param context the replace context whose map is updated
     * @return the result of the superclass {@code visit}
     */
    @Override
    public Void visitSlot(Slot slot, ExpressionReplaceContext context) {
        Map<ExprId, Expression> collected = context.getExprIdExpressionMap();
        collected.put(slot.getExprId(), slot);
        return super.visit(slot, context);
    }

    /**
     * Strips the alias: when the ExprId of the alias is already tracked, the entry is pointed at
     * the aliased child expression instead. Untracked aliases leave the map untouched.
     *
     * @param alias the alias being visited
     * @param context the replace context whose map is updated
     * @return the result of the superclass {@code visitAlias}
     */
    @Override
    public Void visitAlias(Alias alias, ExpressionReplaceContext context) {
        Map<ExprId, Expression> collected = context.getExprIdExpressionMap();
        ExprId aliasId = alias.getExprId();
        boolean tracked = collected.containsKey(aliasId);
        if (tracked) {
            collected.put(aliasId, alias.child());
        }
        return super.visitAlias(alias, context);
    }
}
/**
 * The context for replacing the expression with lineage.
 * It holds the target expressions, the ExprId to expression map that is filled in while the plan
 * is visited, and the lazily computed replaced expressions.
 */
public static class ExpressionReplaceContext {
    private final List<Expression> targetExpressions;
    private final Set<TableType> targetTypes;
    private final Set<String> tableIdentifiers;
    // Mutable on purpose: NamedExpressionCollector puts entries into this map during the plan visit.
    private final Map<ExprId, Expression> exprIdExpressionMap;
    private final BitSet tableBitSet;
    // Cache for getReplacedExpressions(), null until the first call.
    private List<Expression> replacedExpressions;

    /**
     * Creates the context and seeds the ExprId to expression map with every named expression
     * found in the target expressions, each one initially mapped to itself.
     *
     * @param targetExpressions the expressions whose named expressions should be replaced
     * @param targetTypes stored as-is and exposed through {@link #getTargetTypes()}
     * @param tableIdentifiers stored as-is and exposed through {@link #getTableIdentifiers()}
     * @param tableBitSet stored as-is and exposed through {@link #getTableBitSet()}
     */
    public ExpressionReplaceContext(List<Expression> targetExpressions,
            Set<TableType> targetTypes,
            Set<String> tableIdentifiers,
            BitSet tableBitSet) {
        this.targetExpressions = targetExpressions;
        this.targetTypes = targetTypes;
        this.tableIdentifiers = tableIdentifiers;
        this.tableBitSet = tableBitSet;
        // collect the named expressions used in target expression and will be replaced later
        this.exprIdExpressionMap = targetExpressions.stream()
                .map(each -> each.collectToList(NamedExpression.class::isInstance))
                .flatMap(Collection::stream)
                .map(NamedExpression.class::cast)
                .distinct()
                .collect(Collectors.toMap(
                        NamedExpression::getExprId,
                        expr -> expr,
                        // distinct() only drops equal elements; two non-equal named expressions may
                        // still share an ExprId, so keep the first one instead of throwing.
                        (first, second) -> first,
                        // The map is mutated after construction, so ask for a HashMap explicitly
                        // rather than relying on the unspecified map type of two-argument toMap.
                        HashMap::new));
    }

    public List<Expression> getTargetExpressions() {
        return targetExpressions;
    }

    public Set<TableType> getTargetTypes() {
        return targetTypes;
    }

    public Set<String> getTableIdentifiers() {
        return tableIdentifiers;
    }

    /**
     * Returns the live ExprId to expression map; callers update it while visiting the plan.
     */
    public Map<ExprId, Expression> getExprIdExpressionMap() {
        return exprIdExpressionMap;
    }

    public BitSet getTableBitSet() {
        return tableBitSet;
    }

    /**
     * Returns the target expressions rewritten by {@link ExpressionReplacer} against the current
     * ExprId to expression map. The result is computed on the first call and cached, so later
     * changes to the map are not reflected in subsequent calls.
     *
     * @return the replaced expressions, in the same order as the target expressions
     */
    public List<Expression> getReplacedExpressions() {
        if (this.replacedExpressions == null) {
            this.replacedExpressions = targetExpressions.stream()
                    .map(original -> original.accept(ExpressionReplacer.INSTANCE, getExprIdExpressionMap()))
                    .collect(Collectors.toList());
        }
        return this.replacedExpressions;
    }
}
}