ConstantPropagation.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.DataTrait;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.expression.ExpressionNormalizationAndOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSink;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.ImmutableEqualSet;
import org.apache.doris.nereids.util.ImmutableEqualSet.Builder;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.nereids.util.PredicateInferUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.hadoop.util.Lists;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
/**
* constant propagation, like: a = 10 and a + b > 30 => a = 10 and 10 + b > 30,
* when processing a plan, it will collect all its children's equal sets and constants uniforms,
* then use them and the plan's expressions to infer more equal sets and constants uniforms,
* finally use the combine uniforms to replace this plan's expression's slot with literals.
*/
public class ConstantPropagation extends DefaultPlanRewriter<ExpressionRewriteContext> implements CustomRewriter {
@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
// logical apply uniform maybe not correct.
if (plan.containsType(LogicalApply.class)) {
return plan;
}
ExpressionRewriteContext context = new ExpressionRewriteContext(jobContext.getCascadesContext());
return plan.accept(this, context);
}
@Override
public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, ExpressionRewriteContext context) {
filter = visitChildren(this, filter, context);
Expression oldPredicate = filter.getPredicate();
Expression newPredicate = replaceConstantsAndRewriteExpr(filter, oldPredicate, true, context);
if (isExprEqualIgnoreOrder(oldPredicate, newPredicate)) {
return filter;
} else {
Set<Expression> newConjuncts = ImmutableSet.copyOf(ExpressionUtils.extractConjunction(newPredicate));
return filter.withConjunctsAndChild(newConjuncts, filter.child());
}
}
@Override
public Plan visitLogicalHaving(LogicalHaving<? extends Plan> having, ExpressionRewriteContext context) {
having = visitChildren(this, having, context);
Expression oldPredicate = having.getPredicate();
Expression newPredicate = replaceConstantsAndRewriteExpr(having, oldPredicate, true, context);
if (isExprEqualIgnoreOrder(oldPredicate, newPredicate)) {
return having;
} else {
Set<Expression> newConjuncts = ImmutableSet.copyOf(ExpressionUtils.extractConjunction(newPredicate));
return having.withConjunctsAndChild(newConjuncts, having.child());
}
}
@Override
public Plan visitLogicalProject(LogicalProject<? extends Plan> project, ExpressionRewriteContext context) {
project = visitChildren(this, project, context);
Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> childEqualTrait =
getChildEqualSetAndConstants(project, context);
ImmutableList.Builder<NamedExpression> newProjectsBuilder
= ImmutableList.builderWithExpectedSize(project.getProjects().size());
for (NamedExpression expr : project.getProjects()) {
newProjectsBuilder.add(
replaceNameExpressionConstants(expr, context, childEqualTrait.first, childEqualTrait.second));
}
List<NamedExpression> newProjects = newProjectsBuilder.build();
return newProjects.equals(project.getProjects()) ? project : project.withProjects(newProjects);
}
@Override
public Plan visitLogicalSort(LogicalSort<? extends Plan> sort, ExpressionRewriteContext context) {
sort = visitChildren(this, sort, context);
Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> childEqualTrait = getChildEqualSetAndConstants(sort, context);
// for be, order key must be a column, not a literal, so `order by 100#xx` is ok,
// but `order by 100` will make be core.
// so after replaced, we need to remove the constant expr.
ImmutableList.Builder<OrderKey> newOrderKeysBuilder
= ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size());
for (OrderKey key : sort.getOrderKeys()) {
Expression newExpr = replaceConstants(key.getExpr(), false, context,
childEqualTrait.first, childEqualTrait.second);
if (!newExpr.isConstant()) {
newOrderKeysBuilder.add(key.withExpression(newExpr));
}
}
List<OrderKey> newOrderKeys = newOrderKeysBuilder.build();
if (newOrderKeys.isEmpty()) {
return sort.child();
} else if (!newOrderKeys.equals(sort.getOrderKeys())) {
return sort.withOrderKeys(newOrderKeys);
} else {
return sort;
}
}
@Override
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, ExpressionRewriteContext context) {
aggregate = visitChildren(this, aggregate, context);
Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> childEqualTrait =
getChildEqualSetAndConstants(aggregate, context);
List<Expression> oldGroupByExprs = aggregate.getGroupByExpressions();
List<Expression> newGroupByExprs = Lists.newArrayListWithExpectedSize(oldGroupByExprs.size());
for (Expression expr : oldGroupByExprs) {
Expression newExpr = replaceConstants(expr, false, context, childEqualTrait.first, childEqualTrait.second);
if (!newExpr.isConstant()) {
newGroupByExprs.add(newExpr);
}
}
// group by with literal and empty group by are different.
// the former can return 0 row, the latter return at least 1 row.
// when met all group by expression are constant,
// 'eliminateGroupByConstant' will put a project(alias constant as slot) below the agg,
// but this rule cann't put a project below the agg, otherwise this rule may cause a dead loop,
// so when all replaced group by expression are constant, just let new group by add an origin group by.
if (newGroupByExprs.isEmpty() && !oldGroupByExprs.isEmpty()) {
newGroupByExprs.add(oldGroupByExprs.iterator().next());
}
Set<Expression> newGroupByExprSet = Sets.newHashSet(newGroupByExprs);
List<NamedExpression> oldOutputExprs = aggregate.getOutputExpressions();
List<NamedExpression> newOutputExprs = Lists.newArrayListWithExpectedSize(oldOutputExprs.size());
ImmutableList.Builder<NamedExpression> projectBuilder
= ImmutableList.builderWithExpectedSize(oldOutputExprs.size());
boolean containsConstantOutput = false;
// after normal agg, group by expressions and output expressions are slots,
// after this rule, they may rewrite to literal, since literal are not slot,
// we need eliminate the rewritten literals.
for (NamedExpression expr : oldOutputExprs) {
// ColumnPruning will also add all group by expression into output expressions
// agg output need contains group by expression
Expression replacedExpr = replaceConstants(expr, false, context,
childEqualTrait.first, childEqualTrait.second);
Expression newOutputExpr = newGroupByExprSet.contains(expr) ? expr : replacedExpr;
if (newOutputExpr instanceof NamedExpression) {
newOutputExprs.add((NamedExpression) newOutputExpr);
}
if (replacedExpr.isConstant()) {
projectBuilder.add(new Alias(expr.getExprId(), replacedExpr, expr.getName()));
containsConstantOutput = true;
} else {
Preconditions.checkArgument(newOutputExpr instanceof NamedExpression, newOutputExpr);
projectBuilder.add(((NamedExpression) newOutputExpr).toSlot());
}
}
if (newGroupByExprs.equals(oldGroupByExprs) && newOutputExprs.equals(oldOutputExprs)) {
return aggregate;
}
aggregate = aggregate.withGroupByAndOutput(newGroupByExprs, newOutputExprs);
if (containsConstantOutput) {
return PlanUtils.projectOrSelf(projectBuilder.build(), aggregate);
} else {
return aggregate;
}
}
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, ExpressionRewriteContext context) {
// Combine all the join conjuncts together, may infer more constant relations.
// Then after rewrite the combine conjuncts, we need split the rewritten expression into hash/other/mark
// join conjuncts. But we can not extract the mark join conjuncts from the rewritten expression.
// So we only combine the hash conjuncts and other conjuncts.
join = visitChildren(this, join, context);
List<Expression> newHashJoinConjuncts = join.getHashJoinConjuncts();
List<Expression> newOtherJoinConjuncts = join.getOtherJoinConjuncts();
List<Expression> hashOtherConjuncts = Lists.newArrayListWithExpectedSize(
join.getHashJoinConjuncts().size() + join.getOtherJoinConjuncts().size());
hashOtherConjuncts.addAll(join.getHashJoinConjuncts());
hashOtherConjuncts.addAll(join.getOtherJoinConjuncts());
if (!hashOtherConjuncts.isEmpty()) {
Expression oldHashOtherPredicate = ExpressionUtils.and(hashOtherConjuncts);
Expression newHashOtherPredicate
= replaceConstantsAndRewriteExpr(join, oldHashOtherPredicate, true, context);
if (!isExprEqualIgnoreOrder(oldHashOtherPredicate, newHashOtherPredicate)) {
// TODO: code from FindHashConditionForJoin
Pair<List<Expression>, List<Expression>> pair
= JoinUtils.extractExpressionForHashTable(join.left().getOutput(), join.right().getOutput(),
ExpressionUtils.extractConjunction(newHashOtherPredicate));
newHashJoinConjuncts = pair.first;
newOtherJoinConjuncts = pair.second;
if (Sets.newHashSet(newHashJoinConjuncts).equals(Sets.newHashSet(join.getHashJoinConjuncts()))) {
newHashJoinConjuncts = join.getHashJoinConjuncts();
}
if (Sets.newHashSet(newOtherJoinConjuncts).equals(Sets.newHashSet(join.getOtherJoinConjuncts()))) {
newOtherJoinConjuncts = join.getOtherJoinConjuncts();
}
}
}
List<Expression> newMarkJoinConjuncts = join.getMarkJoinConjuncts();
if (!join.getMarkJoinConjuncts().isEmpty()) {
// TODO: we may extract more constant relations from hash conjuncts,
// then we may make mark join conjuncts more simplify.
Expression oldMarkPredicate = ExpressionUtils.and(join.getMarkJoinConjuncts());
Expression newMarkPredicate = replaceConstantsAndRewriteExpr(join, oldMarkPredicate, true, context);
newMarkJoinConjuncts = ExpressionUtils.extractConjunction(newMarkPredicate);
if (Sets.newHashSet(newMarkJoinConjuncts).equals(Sets.newHashSet(join.getMarkJoinConjuncts()))) {
newMarkJoinConjuncts = join.getMarkJoinConjuncts();
}
}
if (newHashJoinConjuncts.equals(join.getHashJoinConjuncts())
&& newOtherJoinConjuncts.equals(join.getOtherJoinConjuncts())
&& newMarkJoinConjuncts.equals(join.getMarkJoinConjuncts())) {
return join;
}
JoinType joinType = join.getJoinType();
if (joinType == JoinType.CROSS_JOIN && !newHashJoinConjuncts.isEmpty()) {
joinType = JoinType.INNER_JOIN;
}
return new LogicalJoin<>(joinType,
newHashJoinConjuncts,
newOtherJoinConjuncts,
newMarkJoinConjuncts,
join.getDistributeHint(),
join.getMarkJoinSlotReference(),
join.children(), join.getJoinReorderContext());
}
@Override
public Plan visitLogicalSink(LogicalSink<? extends Plan> sink, ExpressionRewriteContext context) {
sink = visitChildren(this, sink, context);
// // for sql: create table t as select cast('1' as varchar(30))
// // the select will add a parent plan: result sink. the result sink contains a output slot reference, and its
// // data type is varchar(30), but if replace the slot reference with a varchar literal '1',
// // then the data type info varchar(30) will lost, because varchar literal '1' data type is always varchar(1),
// // so t's column will get a error type. so we don't rewrite logical sink then.
// Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> childEqualTrait
// = getChildEqualSetAndConstants(sink, context);
// List<NamedExpression> newOutputExprs = sink.getOutputExprs().stream()
// .map(expr ->
// replaceNameExpressionConstants(expr, context, childEqualTrait.first, childEqualTrait.second))
// .collect(ImmutableList.toImmutableList());
// return newOutputExprs.equals(sink.getOutputExprs()) ? sink : sink.withOutputExprs(newOutputExprs);
return sink;
}
/**
* replace constants and rewrite expression.
*/
@VisibleForTesting
public Expression replaceConstantsAndRewriteExpr(LogicalPlan plan, Expression expression,
boolean useInnerInfer, ExpressionRewriteContext context) {
// for expression `a = 1 and a + b = 2 and b + c = 2 and c + d =2 and ...`:
// propagate constant `a = 1`, then get `1 + b = 2`, after rewrite this expression, will get `b = 1`;
// then propagate constant `b = 1`, then get `1 + c = 2`, after rewrite this expression, will get `c = 1`,
// ...
// so constant propagate and rewrite expression need to do in a loop.
Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> childEqualTrait = getChildEqualSetAndConstants(plan, context);
Expression afterExpression = expression;
for (int i = 0; i < 100; i++) {
Expression beforeExpression = afterExpression;
afterExpression = replaceConstants(beforeExpression, useInnerInfer, context,
childEqualTrait.first, childEqualTrait.second);
if (isExprEqualIgnoreOrder(beforeExpression, afterExpression)) {
break;
}
if (afterExpression.isLiteral()) {
break;
}
beforeExpression = afterExpression;
afterExpression = ExpressionNormalizationAndOptimization.NO_MIN_MAX_RANGE_INSTANCE
.rewrite(beforeExpression, context);
}
return afterExpression;
}
// process NameExpression
private NamedExpression replaceNameExpressionConstants(NamedExpression expr, ExpressionRewriteContext context,
ImmutableEqualSet<Slot> equalSet, Map<Slot, Literal> constants) {
// if a project item is a slot reference, and the slot equals to a constant value, don't rewrite it.
// because rule `EliminateUnnecessaryProject ` can eliminate a project when the project's output slots equal to
// its child's output slots.
// for example, for `sink -> ... -> project(a, b, c) -> filter(a = 10)`
// if rewrite project to project(alias 10 as a, b, c), later other rule may prune `alias 10 as a`, and project
// will become project(b, c), so project and filter's output slot will not equal,
// then the project cannot be eliminated.
// so we don't replace SlotReference.
// for safety reason, we only replace Alias
if (!(expr instanceof Alias)) {
return expr;
}
// PushProjectThroughUnion require projection is a slot reference, or like (cast slot reference as xx);
// TODO: if PushProjectThroughUnion support projection like `literal as xx`, then delete this check.
if (ExpressionUtils.getExpressionCoveredByCast(expr.child(0)) instanceof SlotReference) {
return expr;
}
Expression newExpr = replaceConstants(expr, false, context, equalSet, constants);
if (newExpr instanceof NamedExpression) {
return (NamedExpression) newExpr;
} else {
return new Alias(expr.getExprId(), newExpr, expr.getName());
}
}
private Expression replaceConstants(Expression expression, boolean useInnerInfer, ExpressionRewriteContext context,
ImmutableEqualSet<Slot> parentEqualSet, Map<Slot, Literal> parentConstants) {
if (expression instanceof And) {
return replaceAndConstants((And) expression, useInnerInfer, context, parentEqualSet, parentConstants);
} else if (expression instanceof Or) {
return replaceOrConstants((Or) expression, useInnerInfer, context, parentEqualSet, parentConstants);
} else if (!parentConstants.isEmpty()
&& expression.anyMatch(e -> e instanceof Slot && parentConstants.containsKey(e))) {
Expression newExpr = ExpressionUtils.replaceIf(expression, parentConstants, this::canReplaceExpression);
if (!newExpr.equals(expression)) {
newExpr = FoldConstantRule.evaluate(newExpr, context);
}
return newExpr;
} else {
return expression;
}
}
// process AND expression
private Expression replaceAndConstants(And expression, boolean useInnerInfer, ExpressionRewriteContext context,
ImmutableEqualSet<Slot> parentEqualSet, Map<Slot, Literal> parentConstants) {
List<Expression> conjunctions = ExpressionUtils.extractConjunction(expression);
Optional<Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>>> equalAndConstantOptions =
expandEqualSetAndConstants(conjunctions, useInnerInfer, parentEqualSet, parentConstants);
// infer conflict constants like a = 10 and a = 30, then rewrite this AND to 'FALSE'
// we have considered a may be null, the rewritten to FALSE is safe.
// the explanation can see the annotation from function expandEqualSetAndConstants
if (!equalAndConstantOptions.isPresent()) {
return BooleanLiteral.FALSE;
}
Set<Slot> inputSlots = expression.getInputSlots();
ImmutableEqualSet<Slot> newEqualSet = equalAndConstantOptions.get().first;
Map<Slot, Literal> newConstants = equalAndConstantOptions.get().second;
// myInferConstantSlots : the slots that are inferred by this expression, not inferred by parent
// myInferConstantSlots[slot] = true means expression had contains conjunct `slot = constant`
Map<Slot, Boolean> myInferConstantSlots = Maps.newLinkedHashMapWithExpectedSize(
Math.max(0, newConstants.size() - parentConstants.size()));
for (Slot slot : newConstants.keySet()) {
if (!parentConstants.containsKey(slot)) {
myInferConstantSlots.put(slot, false);
}
}
ImmutableList.Builder<Expression> builder = ImmutableList.builderWithExpectedSize(conjunctions.size());
for (Expression child : conjunctions) {
Expression newChild = child;
// for expression, `a = 10 and a > b` will infer constant relation `a = 10`,
// need to replace a with 10 to this expression,
// for the first conjunction `a = 10`, no need to replace because after replace will get `10 = 10`,
// for the second conjunction `a > b`, need replace and got `10 > b`
if (needReplaceWithConstant(newChild, newConstants, myInferConstantSlots)) {
newChild = replaceConstants(newChild, useInnerInfer, context, newEqualSet, newConstants);
}
if (newChild.equals(BooleanLiteral.FALSE)) {
return BooleanLiteral.FALSE;
}
if (newChild instanceof And) {
builder.addAll(ExpressionUtils.extractConjunction(newChild));
} else {
builder.add(newChild);
}
}
// if the expression infer `slot = constant`, but not contains conjunct `slot = constant`, need to add it
for (Map.Entry<Slot, Boolean> entry : myInferConstantSlots.entrySet()) {
// if this expression don't contain the slot, no add it, to avoid the expression size increase too long
if (!entry.getValue() && inputSlots.contains(entry.getKey())) {
Slot slot = entry.getKey();
EqualTo equal = new EqualTo(slot, newConstants.get(slot), true);
builder.add(TypeCoercionUtils.processComparisonPredicate(equal));
}
}
return expression.withChildren(builder.build());
}
// process OR expression
private Expression replaceOrConstants(Or expression, boolean useInnerInfer, ExpressionRewriteContext context,
ImmutableEqualSet<Slot> parentEqualSet, Map<Slot, Literal> parentConstants) {
List<Expression> disjunctions = ExpressionUtils.extractDisjunction(expression);
ImmutableList.Builder<Expression> builder = ImmutableList.builderWithExpectedSize(disjunctions.size());
for (Expression child : disjunctions) {
Expression newChild = replaceConstants(child, useInnerInfer, context, parentEqualSet, parentConstants);
if (newChild.equals(BooleanLiteral.TRUE)) {
return BooleanLiteral.TRUE;
}
builder.add(newChild);
}
return expression.withChildren(builder.build());
}
private boolean needReplaceWithConstant(Expression expression, Map<Slot, Literal> constants,
Map<Slot, Boolean> myInferConstantSlots) {
if (expression instanceof EqualTo && expression.child(0) instanceof Slot) {
Slot slot = (Slot) expression.child(0);
// my infer constant slots contain this slot and hadn't replaced with slot=constant yet,
// so let it alone and don't replace it.
if (myInferConstantSlots.containsKey(slot)
&& !myInferConstantSlots.get(slot)
&& expression.child(1).equals(constants.get(slot))) {
myInferConstantSlots.put(slot, true);
return false;
}
}
return true;
}
private boolean canReplaceExpression(Expression expression) {
// 'a is not null', EliminateOuterJoin will call TypeUtils.isNotNull
if (ExpressionUtils.isGeneratedNotNull(expression)) {
return false;
}
// "https://doris.apache.org/docs/sql-manual/basic-element/operators/conditional-operators
// /full-text-search-operators", the match function require left is a slot, not a literal.
if (expression instanceof Match) {
return false;
}
// EliminateJoinByFK, join with materialize view need keep `a = b`.
// But for a common join, need eliminate `a = b`, after eliminate hash join equations,
// hash join will change to nested loop join.
// Join with materialize view will no more need `a = b` later.
// TODO: replace constants with `a = b`
if (expression instanceof EqualPredicate
&& expression.child(0) instanceof SlotReference
&& expression.child(1) instanceof SlotReference) {
SlotReference left = (SlotReference) expression.child(0);
SlotReference right = (SlotReference) expression.child(1);
return left.getQualifier().equals(right.getQualifier());
}
return true;
}
private Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>> getChildEqualSetAndConstants(
LogicalPlan plan, ExpressionRewriteContext context) {
if (plan.children().size() == 1) {
DataTrait dataTrait = plan.child(0).getLogicalProperties().getTrait();
return Pair.of(dataTrait.getEqualSet(), getConstantUniforms(dataTrait.getAllUniformValues(), context));
} else {
Map<Slot, Literal> uniformConstants = Maps.newHashMap();
ImmutableEqualSet.Builder<Slot> newEqualSetBuilder = new Builder<>();
for (Plan child : plan.children()) {
uniformConstants.putAll(
getConstantUniforms(child.getLogicalProperties().getTrait().getAllUniformValues(), context));
newEqualSetBuilder.addEqualSet(child.getLogicalProperties().getTrait().getEqualSet());
}
return Pair.of(newEqualSetBuilder.build(), uniformConstants);
}
}
private Map<Slot, Literal> getConstantUniforms(Map<Slot, Optional<Expression>> uniformValues,
ExpressionRewriteContext context) {
Map<Slot, Literal> constantValues = Maps.newHashMap();
for (Map.Entry<Slot, Optional<Expression>> entry : uniformValues.entrySet()) {
Expression expr = entry.getValue().isPresent() ? entry.getValue().get() : null;
if (expr == null || !expr.isConstant()) {
continue;
}
if (!expr.isLiteral()) {
// uniforms values contains like 'cast(11 as smallint)'
expr = FoldConstantRule.evaluate(expr, context);
if (!expr.isLiteral()) {
continue;
}
}
constantValues.put(entry.getKey(), (Literal) expr);
}
return constantValues;
}
/**
*
* Extract equal set and constants from conjunctions, then combine them with parentEqualSet and parentConstants.
* If met conflict constants relation, return optional.empty().
*/
private Optional<Pair<ImmutableEqualSet<Slot>, Map<Slot, Literal>>> expandEqualSetAndConstants(
List<Expression> conjunctions,
boolean useInnerInfer,
ImmutableEqualSet<Slot> parentEqualSet,
Map<Slot, Literal> parentConstants) {
// infer conflict constants like a = 10 and a = 30, then rewrite this AND to 'FALSE'
// we think this is conflict only when:
// 1) a is not nullable;
// 2) a is nullable, and it must satisfy the below:
// i) the expression is FILTER or HAVING or JOIN condition, so for a PROJECT `a = 10 and a = 30`
// will not evaluate to FALSE
// ii) this expression is an expression root, or all its ancestors are AND/OR,
// and then, this expression can never evaluate to 'TRUE', and can safe replace with 'FALSE'.
// for example, for a FILTER expression: '(a = 10 and a = 20) or not(b = 10 and b = 20)',
// 'a = 10 and a = 20' can evaluate to FALSE because its ancestors {OR} are all AND/OR,
// but 'b=10 and b=20' can not evaluate to FALSE because its ancestors {NOT, OR} contains NOT.
//
// so how to achieve this ?
// use arg useInnerInfer means whether extract constant
// from the rewritten expression itself, for example, for a=10:
// 1) if a is not nullable, extract it;
// 2) if a is nullable, then check useInnerInfer, only useInnerInfer=true, can extract it.
// i) only FILTER/HAVING/JOIN, useInnerInfer may be true, and other plan will use useInnerInfer=false
// ii) replace the expression from top to down, and split the expression into AND / OR sub expressions,
// for '(a = 10 and a = 20) or not(b = 10 and b = 20)'
// first split the OR into two sub expression: (a=10 and a=20), not(b=10 and b=20)
// a = 10 and a = 20 is AND, then recurse call replaceAndConstant to handle it,
// not(b = 10 and b = 20) is not AND/OR, then just replace the expression with upper constants,
// and will not extract b = 10 and b = 20
Map<Slot, Literal> newConstants = Maps.newLinkedHashMapWithExpectedSize(parentConstants.size());
newConstants.putAll(parentConstants);
ImmutableEqualSet.Builder<Slot> newEqualSetBuilder = new Builder<>(parentEqualSet);
for (Expression child : conjunctions) {
Optional<Pair<Slot, Expression>> equalItem = findValidEqualItem(child);
if (!equalItem.isPresent()) {
continue;
}
Slot slot = equalItem.get().first;
Expression expr = equalItem.get().second;
// for expression `a = 1 and a * 3 = 10`
// if it's in LogicalFilter, then we can infer constant relation `a = 1`, then have:
// `a = 1 and 1 * 3 = 10` => `a = 1 and FALSE` => `FALSE`
// but if it's in LogicalProject, then we shouldn't propagate `a=1` to this expression,
// because a may be null, then `a=1 and a * 3 = 10` should evaluate to `NULL` in the project.
if (!useInnerInfer && (slot.nullable() || expr.nullable())) {
continue;
}
if (expr instanceof Slot) {
newEqualSetBuilder.addEqualPair(slot, (Slot) expr);
} else if (!addConstant(newConstants, slot, (Literal) expr)) {
return Optional.empty();
}
}
ImmutableEqualSet<Slot> newEqualSet = newEqualSetBuilder.build();
List<Set<Slot>> multiEqualSlots = newEqualSet.calEqualSetList();
for (Set<Slot> slots : multiEqualSlots) {
Slot slot = null;
for (Slot s : slots) {
if (newConstants.containsKey(s)) {
slot = s;
break;
}
}
if (slot == null) {
continue;
}
Literal value = newConstants.get(slot);
for (Slot s : slots) {
if (!addConstant(newConstants, s, value)) {
return Optional.empty();
}
}
}
return Optional.of(Pair.of(newEqualSet, newConstants));
}
// add a unique constant, if a slot have two different constants value, add fail.
// for example: a = 10 and a = 20, when add 'a = 10', return true,
// but later add 'a = 20' will meet a conflict and return false
private boolean addConstant(Map<Slot, Literal> constants, Slot slot, Literal value) {
Literal existValue = constants.get(slot);
if (existValue == null) {
constants.put(slot, value);
return true;
}
// value equals existsValue, or compare them return 0
return value.equals(existValue)
|| (value instanceof ComparableLiteral
&& existValue instanceof ComparableLiteral
&& ((ComparableLiteral) value).compareTo((ComparableLiteral) existValue) == 0);
}
private Optional<Pair<Slot, Expression>> findValidEqualItem(Expression expression) {
if (!(expression instanceof EqualPredicate)) {
return Optional.empty();
}
Expression left = expression.child(0);
Expression right = expression.child(1);
if (!PredicateInferUtils.isSlotOrNotNullLiteral(left) || !PredicateInferUtils.isSlotOrNotNullLiteral(right)) {
return Optional.empty();
}
if (left instanceof Slot) {
return Optional.of(Pair.of((Slot) left, right));
} else if (right instanceof Slot) {
return Optional.of(Pair.of((Slot) right, left));
} else {
return Optional.empty();
}
}
private boolean isExprEqualIgnoreOrder(Expression oldExpr, Expression newExpr) {
if (oldExpr instanceof And) {
return Sets.newHashSet(ExpressionUtils.extractConjunction(oldExpr))
.equals(Sets.newHashSet(ExpressionUtils.extractConjunction(newExpr)));
} else if (oldExpr instanceof Or) {
return Sets.newHashSet(ExpressionUtils.extractDisjunction(oldExpr))
.equals(Sets.newHashSet(ExpressionUtils.extractDisjunction(newExpr)));
} else {
return oldExpr.equals(newExpr);
}
}
}