AdjustNullable.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalSink;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
/**
 * Some rewrite rules can change the nullability of a plan's output slots.
 * This rule therefore runs after those rewrites and adjusts the nullable attribute
 * of every expression that references such slots.
 */
public class AdjustNullable extends DefaultPlanRewriter<Map<ExprId, Slot>> implements CustomRewriter {
/**
 * Starts the rewrite: lets {@code plan} accept this visitor together with a fresh,
 * empty ExprId-to-slot map that the visit methods populate while they run.
 *
 * @param plan the root plan to rewrite
 * @param jobContext the current job context (not read by this method)
 * @return whatever plan the visitor produces for {@code plan}
 */
@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
    Map<ExprId, Slot> replaceMap = Maps.newHashMap();
    return plan.accept(this, replaceMap);
}
/**
 * Generic visit: delegates to the superclass implementation, recomputes the logical
 * properties of the result and records every output slot of the result in
 * {@code replaceMap}, keyed by its ExprId.
 *
 * @param plan the plan node to visit
 * @param replaceMap ExprId-to-slot map; updated in place
 * @return the visited plan with recomputed logical properties
 */
@Override
public Plan visit(Plan plan, Map<ExprId, Slot> replaceMap) {
    LogicalPlan rewritten = ((LogicalPlan) super.visit(plan, replaceMap)).recomputeLogicalProperties();
    rewritten.getOutput().forEach(slot -> replaceMap.put(slot.getExprId(), slot));
    return rewritten;
}
/**
 * Adjusts the nullable attribute of slots referenced by the sink's output expressions.
 *
 * @param logicalSink the sink to visit
 * @param replaceMap ExprId-to-slot map used to look up the expected nullability
 * @return the visited sink itself when no output expression changed, otherwise a copy
 *         carrying the adjusted output expressions
 */
@Override
public Plan visitLogicalSink(LogicalSink<? extends Plan> logicalSink, Map<ExprId, Slot> replaceMap) {
    LogicalSink<? extends Plan> visitedSink = (LogicalSink<? extends Plan>) super.visit(logicalSink, replaceMap);
    Optional<List<NamedExpression>> adjusted = updateExpressions(visitedSink.getOutputExprs(), replaceMap);
    if (adjusted.isPresent()) {
        return visitedSink.withOutputExprs(adjusted.get());
    }
    return visitedSink;
}
/**
 * Adjusts the nullable attribute of slots referenced by the aggregate's output and
 * group-by expressions, and records the (possibly adjusted) output slots in
 * {@code replaceMap}.
 *
 * @param aggregate the aggregate to visit
 * @param replaceMap ExprId-to-slot map; read for lookups and updated with this node's outputs
 * @return the visited aggregate itself when nothing changed, otherwise a copy with the
 *         adjusted group-by and output expressions
 */
@Override
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Map<ExprId, Slot> replaceMap) {
    aggregate = (LogicalAggregate<? extends Plan>) super.visit(aggregate, replaceMap);
    Optional<List<NamedExpression>> newOutputs
            = updateExpressions(aggregate.getOutputExpressions(), replaceMap);
    Optional<List<Expression>> newGroupBy = updateExpressions(aggregate.getGroupByExpressions(), replaceMap);
    // Effective outputs: the adjusted list if anything changed, the original list otherwise.
    List<NamedExpression> outputs = newOutputs.orElse(aggregate.getOutputExpressions());
    // Register the outputs even when unchanged so that ancestors can look them up.
    for (NamedExpression output : outputs) {
        replaceMap.put(output.getExprId(), output.toSlot());
    }
    if (!newOutputs.isPresent() && !newGroupBy.isPresent()) {
        return aggregate;
    }
    return aggregate.withGroupByAndOutput(newGroupBy.orElse(aggregate.getGroupByExpressions()), outputs);
}
/**
 * Adjusts the nullable attribute of slots referenced by the filter's conjuncts.
 *
 * @param filter the filter to visit
 * @param replaceMap ExprId-to-slot map used to look up the expected nullability
 * @return the visited filter itself when no conjunct changed, otherwise a copy with the
 *         adjusted conjuncts over the same child
 */
@Override
public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, Map<ExprId, Slot> replaceMap) {
    LogicalFilter<? extends Plan> visitedFilter = (LogicalFilter<? extends Plan>) super.visit(filter, replaceMap);
    Optional<Set<Expression>> adjusted = updateExpressions(visitedFilter.getConjuncts(), replaceMap);
    if (adjusted.isPresent()) {
        return visitedFilter.withConjunctsAndChild(adjusted.get(), visitedFilter.child());
    }
    return visitedFilter;
}
/**
 * Adjusts the nullable attribute of slots referenced by the generator functions and
 * records the resulting node's output slots in {@code replaceMap}.
 *
 * @param generate the generate node to visit
 * @param replaceMap ExprId-to-slot map; read for lookups and updated with this node's outputs
 * @return the visited node itself when no generator changed, otherwise a copy with the
 *         adjusted generators and recomputed logical properties
 */
@Override
public Plan visitLogicalGenerate(LogicalGenerate<? extends Plan> generate, Map<ExprId, Slot> replaceMap) {
    LogicalGenerate<? extends Plan> visited = (LogicalGenerate<? extends Plan>) super.visit(generate, replaceMap);
    Optional<List<Function>> adjustedGenerators = updateExpressions(visited.getGenerators(), replaceMap);
    Plan result = adjustedGenerators.isPresent()
            ? visited.withGenerators(adjustedGenerators.get()).recomputeLogicalProperties()
            : visited;
    // Outputs are registered whether or not the generators were rewritten.
    result.getOutput().forEach(slot -> replaceMap.put(slot.getExprId(), slot));
    return result;
}
/**
 * Adjusts the nullable attribute of slots referenced by the join's hash, mark and other
 * conjuncts, and records the join's output slots in {@code replaceMap}.
 *
 * <p>Hash conjuncts are always adjusted before the join's output slots are registered.
 * Mark conjuncts are adjusted before that registration only when the join has no hash
 * conjuncts; otherwise they are adjusted after it, like the other conjuncts.
 *
 * @param join the join to visit
 * @param replaceMap ExprId-to-slot map; read for lookups and updated with this node's outputs
 * @return the visited join itself when no conjunct changed, otherwise a copy with the
 *         adjusted conjuncts and recomputed logical properties
 */
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Map<ExprId, Slot> replaceMap) {
    join = (LogicalJoin<? extends Plan, ? extends Plan>) super.visit(join, replaceMap);
    Optional<List<Expression>> hashConjuncts = updateExpressions(join.getHashJoinConjuncts(), replaceMap);
    // An empty Optional from updateExpressions only means "nothing changed", not "no
    // conjuncts", so emptiness has to be tested on the conjunct list itself.
    boolean noHashConjuncts = join.getHashJoinConjuncts().isEmpty();
    Optional<List<Expression>> markConjuncts = Optional.empty();
    if (noHashConjuncts) {
        // if hashConjuncts is empty, mark join conjuncts may be used to build the hash table,
        // so updateExpressions must run on them before nullable is adjusted by output slots
        markConjuncts = updateExpressions(join.getMarkJoinConjuncts(), replaceMap);
    }
    for (Slot slot : join.getOutput()) {
        replaceMap.put(slot.getExprId(), slot);
    }
    if (!noHashConjuncts) {
        // hashConjuncts is not empty, mark join conjuncts are processed like other join conjuncts
        markConjuncts = updateExpressions(join.getMarkJoinConjuncts(), replaceMap);
    }
    Optional<List<Expression>> otherConjuncts = updateExpressions(join.getOtherJoinConjuncts(), replaceMap);
    if (!hashConjuncts.isPresent() && !markConjuncts.isPresent() && !otherConjuncts.isPresent()) {
        return join;
    }
    return join.withJoinConjuncts(
            hashConjuncts.orElse(join.getHashJoinConjuncts()),
            otherConjuncts.orElse(join.getOtherJoinConjuncts()),
            markConjuncts.orElse(join.getMarkJoinConjuncts()),
            join.getJoinReorderContext()
    ).recomputeLogicalProperties();
}
/**
 * Adjusts the nullable attribute of slots referenced by the project list and records
 * the (possibly adjusted) projections' slots in {@code replaceMap}.
 *
 * @param project the project to visit
 * @param replaceMap ExprId-to-slot map; read for lookups and updated with this node's outputs
 * @return the visited project itself when no projection changed, otherwise a copy with
 *         the adjusted projections
 */
@Override
public Plan visitLogicalProject(LogicalProject<? extends Plan> project, Map<ExprId, Slot> replaceMap) {
    LogicalProject<? extends Plan> visited = (LogicalProject<? extends Plan>) super.visit(project, replaceMap);
    Optional<List<NamedExpression>> adjusted = updateExpressions(visited.getProjects(), replaceMap);
    List<NamedExpression> effectiveProjects = adjusted.orElse(visited.getProjects());
    // Registration happens even when nothing changed.
    effectiveProjects.forEach(named -> replaceMap.put(named.getExprId(), named.toSlot()));
    if (adjusted.isPresent()) {
        return visited.withProjects(adjusted.get());
    }
    return visited;
}
/**
 * Adjusts the nullable attribute of slots referenced by the repeat's output expressions.
 * Outputs that are themselves members of the flattened grouping sets are kept untouched.
 * Every resulting output is recorded in {@code replaceMap}.
 *
 * @param repeat the repeat node to visit
 * @param replaceMap ExprId-to-slot map; read for lookups and updated with this node's outputs
 * @return a copy of the visited node with the same grouping sets, the resulting outputs
 *         and recomputed logical properties (built unconditionally)
 */
@Override
public Plan visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat, Map<ExprId, Slot> replaceMap) {
    LogicalRepeat<? extends Plan> visited = (LogicalRepeat<? extends Plan>) super.visit(repeat, replaceMap);
    Set<Expression> groupingSetExprs = ImmutableSet.copyOf(
            ExpressionUtils.flatExpressions(visited.getGroupingSets()));
    List<NamedExpression> adjustedOutputs = Lists.newArrayList();
    for (NamedExpression output : visited.getOutputExpressions()) {
        NamedExpression adjusted = groupingSetExprs.contains(output)
                ? output
                : updateExpression(output, replaceMap).orElse(output);
        adjustedOutputs.add(adjusted);
        replaceMap.put(adjusted.getExprId(), adjusted.toSlot());
    }
    return visited.withGroupSetsAndOutput(visited.getGroupingSets(), adjustedOutputs)
            .recomputeLogicalProperties();
}
/**
 * Recomputes the nullability of a set operation's outputs: output column {@code i} becomes
 * nullable if the matching column of any child output is nullable or, for a union, if the
 * {@code i}-th expression of any constant row is nullable. The per-child output lists are
 * rebuilt from the children's current output slots, and the new outputs are recorded in
 * {@code replaceMap}.
 *
 * @param setOperation the set operation to visit
 * @param replaceMap ExprId-to-slot map; updated with this node's outputs
 * @return the visited node unchanged when it has neither children nor constant rows,
 *         otherwise a copy with new outputs, refreshed children outputs and recomputed
 *         logical properties
 */
@Override
public Plan visitLogicalSetOperation(LogicalSetOperation setOperation, Map<ExprId, Slot> replaceMap) {
    setOperation = (LogicalSetOperation) super.visit(setOperation, replaceMap);
    ImmutableList.Builder<List<SlotReference>> childrenOutputs = ImmutableList.builder();
    List<Boolean> nullableFlags = null;
    if (!setOperation.children().isEmpty()) {
        nullableFlags = newAllFalseList(setOperation.getOutputs().size());
        for (int childIdx = 0; childIdx < setOperation.arity(); childIdx++) {
            List<Slot> actualOutput = setOperation.child(childIdx).getOutput();
            ImmutableList.Builder<SlotReference> refreshed = ImmutableList.builder();
            int column = 0;
            for (SlotReference recorded : setOperation.getRegularChildOutput(childIdx)) {
                // Find the child's current slot with the same ExprId as the recorded one.
                for (Slot candidate : actualOutput) {
                    if (!candidate.getExprId().equals(recorded.getExprId())) {
                        continue;
                    }
                    nullableFlags.set(column, candidate.nullable() || nullableFlags.get(column));
                    refreshed.add((SlotReference) candidate);
                    break;
                }
                column++;
            }
            childrenOutputs.add(refreshed.build());
        }
    }
    if (setOperation instanceof LogicalUnion) {
        LogicalUnion union = (LogicalUnion) setOperation;
        if (setOperation.children().isEmpty() && !union.getConstantExprsList().isEmpty()) {
            // no children: size the flag list from the first constant row instead
            nullableFlags = newAllFalseList(union.getConstantExprsList().get(0).size());
        }
        for (List<NamedExpression> constantRow : union.getConstantExprsList()) {
            for (int column = 0; column < constantRow.size(); column++) {
                nullableFlags.set(column, nullableFlags.get(column) || constantRow.get(column).nullable());
            }
        }
    }
    if (nullableFlags == null) {
        // this is a fail-safe:
        // there are no children and no constant expression rows,
        // so there is nothing to derive the nullable flags from; do nothing
        return setOperation;
    }
    List<NamedExpression> oldOutputs = setOperation.getOutputs();
    List<NamedExpression> adjustedOutputs = Lists.newArrayListWithCapacity(oldOutputs.size());
    for (int i = 0; i < nullableFlags.size(); i++) {
        NamedExpression oldOutput = oldOutputs.get(i);
        boolean nullable = nullableFlags.get(i);
        NamedExpression adjusted;
        if (oldOutput instanceof Alias) {
            // keep the alias, replace only the slot underneath it
            Slot aliased = ((Slot) ((Alias) oldOutput).child()).withNullable(nullable);
            adjusted = (NamedExpression) oldOutput.withChildren(aliased);
        } else {
            adjusted = ((Slot) oldOutput).withNullable(nullable);
        }
        adjustedOutputs.add(adjusted);
        replaceMap.put(adjusted.getExprId(), adjusted.toSlot());
    }
    return setOperation.withNewOutputs(adjustedOutputs)
            .withChildrenAndTheirOutputs(setOperation.children(), childrenOutputs.build())
            .recomputeLogicalProperties();
}

/** Creates a mutable list holding {@code size} entries, all {@code false}. */
private List<Boolean> newAllFalseList(int size) {
    List<Boolean> flags = Lists.newArrayListWithCapacity(size);
    for (int i = 0; i < size; i++) {
        flags.add(false);
    }
    return flags;
}
/**
 * Adjusts the nullable attribute of slots referenced by the sort's order keys.
 *
 * @param sort the sort to visit
 * @param replaceMap ExprId-to-slot map used to look up the expected nullability
 * @return the visited sort itself when no order key changed, otherwise a copy with the
 *         adjusted order keys over the same child
 */
@Override
public Plan visitLogicalSort(LogicalSort<? extends Plan> sort, Map<ExprId, Slot> replaceMap) {
    LogicalSort<? extends Plan> visited = (LogicalSort<? extends Plan>) super.visit(sort, replaceMap);
    ImmutableList.Builder<OrderKey> adjustedKeys = ImmutableList.builder();
    boolean anyAdjusted = false;
    for (OrderKey key : visited.getOrderKeys()) {
        Optional<Expression> adjustedExpr = updateExpression(key.getExpr(), replaceMap);
        if (adjustedExpr.isPresent()) {
            anyAdjusted = true;
            adjustedKeys.add(key.withExpression(adjustedExpr.get()));
        } else {
            adjustedKeys.add(key);
        }
    }
    if (anyAdjusted) {
        return visited.withOrderKeysAndChild(adjustedKeys.build(), visited.child());
    }
    return visited;
}
/**
 * Adjusts the nullable attribute of slots referenced by the top-N's order keys.
 *
 * @param topN the top-N node to visit
 * @param replaceMap ExprId-to-slot map used to look up the expected nullability
 * @return the visited node itself when no order key changed, otherwise a copy with the
 *         adjusted order keys and recomputed logical properties
 */
@Override
public Plan visitLogicalTopN(LogicalTopN<? extends Plan> topN, Map<ExprId, Slot> replaceMap) {
    LogicalTopN<? extends Plan> visited = (LogicalTopN<? extends Plan>) super.visit(topN, replaceMap);
    ImmutableList.Builder<OrderKey> adjustedKeys = ImmutableList.builder();
    boolean anyAdjusted = false;
    for (OrderKey key : visited.getOrderKeys()) {
        Optional<Expression> adjustedExpr = updateExpression(key.getExpr(), replaceMap);
        if (adjustedExpr.isPresent()) {
            anyAdjusted = true;
            adjustedKeys.add(key.withExpression(adjustedExpr.get()));
        } else {
            adjustedKeys.add(key);
        }
    }
    if (!anyAdjusted) {
        return visited;
    }
    return visited.withOrderKeys(adjustedKeys.build()).recomputeLogicalProperties();
}
/**
 * Adjusts the nullable attribute of slots referenced by the window expressions and
 * records the (possibly adjusted) window expressions' slots in {@code replaceMap}.
 *
 * @param window the window node to visit
 * @param replaceMap ExprId-to-slot map; read for lookups and updated with the window outputs
 * @return the visited node itself when no window expression changed, otherwise a copy
 *         with the adjusted expressions over the same child
 */
@Override
public Plan visitLogicalWindow(LogicalWindow<? extends Plan> window, Map<ExprId, Slot> replaceMap) {
    LogicalWindow<? extends Plan> visited = (LogicalWindow<? extends Plan>) super.visit(window, replaceMap);
    Optional<List<NamedExpression>> adjusted = updateExpressions(visited.getWindowExpressions(), replaceMap);
    List<NamedExpression> effective = adjusted.orElse(visited.getWindowExpressions());
    // Registration happens even when nothing changed.
    effective.forEach(expr -> replaceMap.put(expr.getExprId(), expr.toSlot()));
    if (adjusted.isPresent()) {
        return visited.withExpressionsAndChild(adjusted.get(), visited.child());
    }
    return visited;
}
/**
 * Adjusts the nullable attribute of slots referenced by the partition keys and order
 * keys of a partition top-N node.
 *
 * @param partitionTopN the node to visit
 * @param replaceMap ExprId-to-slot map used to look up the expected nullability
 * @return the visited node itself when neither key list changed, otherwise a copy with
 *         the adjusted partition and order keys
 */
@Override
public Plan visitLogicalPartitionTopN(LogicalPartitionTopN<? extends Plan> partitionTopN,
        Map<ExprId, Slot> replaceMap) {
    LogicalPartitionTopN<? extends Plan> visited
            = (LogicalPartitionTopN<? extends Plan>) super.visit(partitionTopN, replaceMap);
    Optional<List<Expression>> adjustedPartitionKeys
            = updateExpressions(visited.getPartitionKeys(), replaceMap);
    Optional<List<OrderExpression>> adjustedOrderKeys
            = updateExpressions(visited.getOrderKeys(), replaceMap);
    boolean anyAdjusted = adjustedPartitionKeys.isPresent() || adjustedOrderKeys.isPresent();
    if (!anyAdjusted) {
        return visited;
    }
    List<Expression> partitionKeys = adjustedPartitionKeys.orElse(visited.getPartitionKeys());
    List<OrderExpression> orderKeys = adjustedOrderKeys.orElse(visited.getOrderKeys());
    return visited.withPartitionKeysAndOrderKeys(partitionKeys, orderKeys);
}
/**
 * Rebuilds the consumer/producer slot maps of a CTE consumer so that every consumer
 * output slot takes the nullability of its (possibly adjusted) producer slot. The new
 * consumer slots are recorded in {@code replaceMap}.
 *
 * @param cteConsumer the CTE consumer to visit
 * @param replaceMap ExprId-to-slot map; read for producer lookups and updated with the
 *        new consumer slots
 * @return a copy of the consumer carrying the two rebuilt maps
 */
@Override
public Plan visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, Map<ExprId, Slot> replaceMap) {
    Map<Slot, Slot> consumerToProducer = new LinkedHashMap<>();
    Multimap<Slot, Slot> producerToConsumer = LinkedHashMultimap.create();
    for (Slot producerSlot : cteConsumer.getConsumerToProducerOutputMap().values()) {
        // Use the adjusted producer slot when its nullability changed, the original otherwise.
        Slot effectiveProducerSlot = updateExpression(producerSlot, replaceMap).orElse(producerSlot);
        for (Slot consumerSlot : cteConsumer.getProducerToConsumerOutputMap().get(producerSlot)) {
            Slot adjustedConsumerSlot = consumerSlot.withNullable(effectiveProducerSlot.nullable());
            producerToConsumer.put(effectiveProducerSlot, adjustedConsumerSlot);
            consumerToProducer.put(adjustedConsumerSlot, effectiveProducerSlot);
            replaceMap.put(adjustedConsumerSlot.getExprId(), adjustedConsumerSlot);
        }
    }
    return cteConsumer.withTwoMaps(consumerToProducer, producerToConsumer);
}
/**
 * Rewrites {@code input} so that every {@link SlotReference} whose ExprId is present in
 * {@code replaceMap} carries the nullability of the mapped slot. When the mapped slot has
 * an agg-state data type, the data type is replaced together with the nullability.
 *
 * @param input the expression to rewrite
 * @param replaceMap ExprId-to-slot map providing the expected nullability (and data type)
 * @param <T> the static type of the expression
 * @return the rewritten expression if at least one slot reference was replaced,
 *         otherwise an empty Optional
 */
private <T extends Expression> Optional<T> updateExpression(T input, Map<ExprId, Slot> replaceMap) {
    // The lambda cannot assign a local variable, so the flag lives in an AtomicBoolean.
    AtomicBoolean anySlotAdjusted = new AtomicBoolean(false);
    Expression rewritten = input.rewriteDownShortCircuit(node -> {
        if (!(node instanceof SlotReference)) {
            return node;
        }
        SlotReference reference = (SlotReference) node;
        Slot recorded = replaceMap.get(reference.getExprId());
        if (recorded == null) {
            return reference;
        }
        if (recorded.getDataType().isAggStateType()) {
            boolean differs = reference.nullable() != recorded.nullable()
                    || !reference.getDataType().equals(recorded.getDataType());
            if (!differs) {
                return reference;
            }
            // we must replace data type, because nested type and agg state contains nullable
            // of their children.
            // TODO: remove if statement after we ensure be constant folding do not change
            //  expr type at all.
            anySlotAdjusted.set(true);
            return reference.withNullableAndDataType(recorded.nullable(), recorded.getDataType());
        }
        if (reference.nullable() == recorded.nullable()) {
            return reference;
        }
        anySlotAdjusted.set(true);
        return reference.withNullable(recorded.nullable());
    });
    if (!anySlotAdjusted.get()) {
        return Optional.empty();
    }
    return Optional.of((T) rewritten);
}
/**
 * List variant: applies {@code updateExpression} to each element, preserving order.
 *
 * @param inputs the expressions to rewrite
 * @param replaceMap ExprId-to-slot map providing the expected nullability
 * @param <T> the element type
 * @return an immutable list of the same size holding rewritten or original elements if
 *         at least one element changed, otherwise an empty Optional
 */
private <T extends Expression> Optional<List<T>> updateExpressions(List<T> inputs, Map<ExprId, Slot> replaceMap) {
    ImmutableList.Builder<T> adjusted = ImmutableList.builderWithExpectedSize(inputs.size());
    boolean anyAdjusted = false;
    for (T original : inputs) {
        Optional<T> replacement = updateExpression(original, replaceMap);
        if (replacement.isPresent()) {
            anyAdjusted = true;
            adjusted.add(replacement.get());
        } else {
            adjusted.add(original);
        }
    }
    if (!anyAdjusted) {
        return Optional.empty();
    }
    return Optional.of(adjusted.build());
}
/**
 * Set variant: applies {@code updateExpression} to each element of {@code inputs}.
 *
 * @param inputs the expressions to rewrite
 * @param replaceMap ExprId-to-slot map providing the expected nullability
 * @param <T> the element type
 * @return an immutable set holding rewritten or original elements if at least one
 *         element changed, otherwise an empty Optional
 */
private <T extends Expression> Optional<Set<T>> updateExpressions(Set<T> inputs, Map<ExprId, Slot> replaceMap) {
    ImmutableSet.Builder<T> adjusted = ImmutableSet.builderWithExpectedSize(inputs.size());
    boolean anyAdjusted = false;
    for (T original : inputs) {
        Optional<T> replacement = updateExpression(original, replaceMap);
        if (replacement.isPresent()) {
            anyAdjusted = true;
            adjusted.add(replacement.get());
        } else {
            adjusted.add(original);
        }
    }
    if (!anyAdjusted) {
        return Optional.empty();
    }
    return Optional.of(adjusted.build());
}
}