// AddProjectForUniqueFunction.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.scalar.UniqueFunction;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
/**
 * Extracts unique function expressions that appear multiple times and adds them to a new child project.
 * For example:
 * before rewrite: filter(random() >= 5 and random() <= 10), supposing the two random() calls
 * have the same unique expr id.
 * after rewrite: filter(k >= 5 and k <= 10) -> project(random() as k)
 */
public class AddProjectForUniqueFunction implements RewriteRuleFactory {
/**
 * Builds one rewrite rule for each plan node type handled by this factory.
 *
 * @return the rules for generate, one-row relation, project, filter, having, aggregate and join,
 *         in that order
 */
@Override
public List<Rule> buildRules() {
    ImmutableList.Builder<Rule> rules = ImmutableList.builder();
    rules.add(new GenerateRewrite().build());
    rules.add(new OneRowRelationRewrite().build());
    rules.add(new ProjectRewrite().build());
    rules.add(new FilterRewrite().build());
    rules.add(new HavingRewrite().build());
    rules.add(new AggregateRewrite().build());
    rules.add(new JoinRewrite().build());
    return rules.build();
}
/** Rewrites the generators of a LogicalGenerate through {@code rewriteExpressions}. */
private class GenerateRewrite extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalGenerate().thenApply(ctx -> {
            LogicalGenerate<Plan> generate = ctx.root;
            Optional<Pair<List<Function>, LogicalProject<Plan>>> rewritten
                    = rewriteExpressions(generate, generate.getGenerators());
            if (!rewritten.isPresent()) {
                // nothing to alias, keep the plan untouched
                return generate;
            }
            Pair<List<Function>, LogicalProject<Plan>> result = rewritten.get();
            // first: rewritten generators, second: new child project holding the aliases
            return generate.withGenerators(result.first)
                    .withChildren(result.second);
        }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
    }
}
/**
 * Rewrites a LogicalOneRowRelation: the aliased unique functions become the projects of the
 * one-row relation, and a new LogicalProject on top evaluates the original projects against
 * the alias slots.
 */
private class OneRowRelationRewrite extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalOneRowRelation().thenApply(ctx -> {
            LogicalOneRowRelation oneRowRelation = ctx.root;
            List<NamedExpression> originProjects = oneRowRelation.getProjects();
            List<NamedExpression> aliases = tryGenUniqueFunctionAlias(originProjects);
            if (aliases.isEmpty()) {
                return oneRowRelation;
            }

            // map each aliased unique function to the slot produced by its alias
            Map<Expression, Slot> slotByFunction = Maps.newHashMap();
            for (NamedExpression alias : aliases) {
                slotByFunction.put(alias.child(0), alias.toSlot());
            }

            ImmutableList.Builder<NamedExpression> upperProjects
                    = ImmutableList.builderWithExpectedSize(originProjects.size());
            for (NamedExpression originProject : originProjects) {
                Expression replaced = ExpressionUtils.replace(originProject, slotByFunction);
                upperProjects.add((NamedExpression) replaced);
            }

            return new LogicalProject<>(
                    upperProjects.build(),
                    oneRowRelation.withProjects(aliases));
        }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
    }
}
/** Rewrites the project list of a LogicalProject through {@code rewriteExpressions}. */
private class ProjectRewrite extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalProject().thenApply(ctx -> {
            LogicalProject<Plan> project = ctx.root;
            Optional<Pair<List<NamedExpression>, LogicalProject<Plan>>> rewritten
                    = rewriteExpressions(project, project.getProjects());
            if (!rewritten.isPresent()) {
                // nothing to alias, keep the plan untouched
                return project;
            }
            Pair<List<NamedExpression>, LogicalProject<Plan>> result = rewritten.get();
            // first: rewritten projects, second: new child project holding the aliases
            return project.withProjectsAndChild(result.first, result.second);
        }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
    }
}
/** Rewrites the conjuncts of a LogicalFilter through {@code rewriteExpressions}. */
private class FilterRewrite extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalFilter().thenApply(ctx -> {
            LogicalFilter<Plan> filter = ctx.root;
            Optional<Pair<List<Expression>, LogicalProject<Plan>>> rewritten
                    = rewriteExpressions(filter, filter.getConjuncts());
            if (!rewritten.isPresent()) {
                // nothing to alias, keep the plan untouched
                return filter;
            }
            Pair<List<Expression>, LogicalProject<Plan>> result = rewritten.get();
            // first: rewritten conjuncts, second: new child project holding the aliases
            return filter.withConjunctsAndChild(
                    ImmutableSet.copyOf(result.first),
                    result.second);
        }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
    }
}
/** Rewrites the conjuncts of a LogicalHaving through {@code rewriteExpressions}. */
private class HavingRewrite extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalHaving().thenApply(ctx -> {
            LogicalHaving<Plan> having = ctx.root;
            Optional<Pair<List<Expression>, LogicalProject<Plan>>> rewritten
                    = rewriteExpressions(having, having.getConjuncts());
            if (!rewritten.isPresent()) {
                // nothing to alias, keep the plan untouched
                return having;
            }
            Pair<List<Expression>, LogicalProject<Plan>> result = rewritten.get();
            // first: rewritten conjuncts, second: new child project holding the aliases
            return having.withConjuncts(ImmutableSet.copyOf(result.first))
                    .withChildren(result.second);
        }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
    }
}
/**
 * Rewrites the group-by and output expressions of a LogicalAggregate together, so a unique
 * function counts as repeated when its occurrences are spread over both lists.
 */
private class AggregateRewrite extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalAggregate().thenApply(ctx -> {
            LogicalAggregate<Plan> aggregate = ctx.root;
            int groupByCount = aggregate.getGroupByExpressions().size();
            int outputCount = aggregate.getOutputExpressions().size();

            // targets are laid out as [group by..., output...] and split again by index below
            List<Expression> combined = Lists.newArrayListWithExpectedSize(groupByCount + outputCount);
            combined.addAll(aggregate.getGroupByExpressions());
            combined.addAll(aggregate.getOutputExpressions());

            Optional<Pair<List<Expression>, LogicalProject<Plan>>> rewritten
                    = rewriteExpressions(aggregate, combined);
            if (!rewritten.isPresent()) {
                return aggregate;
            }

            List<Expression> rewrittenTargets = rewritten.get().first;
            LogicalProject<Plan> aliasProject = rewritten.get().second;

            ImmutableList<Expression> rewrittenGroupBy = ImmutableList.copyOf(
                    rewrittenTargets.subList(0, groupByCount));
            ImmutableList.Builder<NamedExpression> rewrittenOutputs
                    = ImmutableList.builderWithExpectedSize(outputCount);
            for (Expression output : rewrittenTargets.subList(groupByCount, rewrittenTargets.size())) {
                rewrittenOutputs.add((NamedExpression) output);
            }

            return aggregate.withChildGroupByAndOutput(rewrittenGroupBy, rewrittenOutputs.build(), aliasProject);
        }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
    }
}
/**
 * Rewrites the hash, other and mark conjuncts of a LogicalJoin together. The project holding
 * the aliases becomes the new left child, and the hash/other conjuncts are re-split against
 * the new left output afterwards.
 */
private class JoinRewrite extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalJoin().thenApply(ctx -> {
            LogicalJoin<Plan, Plan> join = ctx.root;

            // conjuncts are laid out as [hash..., other..., mark...] and split again by index below
            int markStart = join.getHashJoinConjuncts().size() + join.getOtherJoinConjuncts().size();
            int conjunctCount = markStart + join.getMarkJoinConjuncts().size();
            List<Expression> conjuncts = Lists.newArrayListWithExpectedSize(conjunctCount);
            conjuncts.addAll(join.getHashJoinConjuncts());
            conjuncts.addAll(join.getOtherJoinConjuncts());
            conjuncts.addAll(join.getMarkJoinConjuncts());

            Optional<Pair<List<Expression>, LogicalProject<Plan>>> rewritten
                    = rewriteExpressions(join, conjuncts);
            if (!rewritten.isPresent()) {
                return join;
            }

            List<Expression> rewrittenConjuncts = rewritten.get().first;
            LogicalProject<Plan> leftProject = rewritten.get().second;
            List<Expression> hashAndOtherConjuncts = rewrittenConjuncts.subList(0, markStart);
            List<Expression> markConjuncts = ImmutableList.copyOf(
                    rewrittenConjuncts.subList(markStart, conjunctCount));

            // TODO: code from FindHashConditionForJoin
            Pair<List<Expression>, List<Expression>> hashAndOther = JoinUtils.extractExpressionForHashTable(
                    leftProject.getOutput(), join.right().getOutput(), hashAndOtherConjuncts);
            List<Expression> hashConjuncts = hashAndOther.first;
            List<Expression> otherConjuncts = hashAndOther.second;

            // switch CROSS_JOIN to INNER_JOIN once hash conjuncts exist
            JoinType originType = join.getJoinType();
            JoinType joinType = (originType == JoinType.CROSS_JOIN && !hashConjuncts.isEmpty())
                    ? JoinType.INNER_JOIN
                    : originType;

            return new LogicalJoin<>(joinType,
                    hashConjuncts,
                    otherConjuncts,
                    markConjuncts,
                    join.getDistributeHint(),
                    join.getMarkJoinSlotReference(),
                    ImmutableList.of(leftProject, join.right()),
                    join.getJoinReorderContext());
        }).toRule(RuleType.ADD_PROJECT_FOR_UNIQUE_FUNCTION);
    }
}
/**
 * Finds the unique functions that occur more than once in {@code targets}, aliases each of them
 * in a new project built over {@code plan.child(0)}, and replaces their occurrences in the
 * targets with the slots of those aliases.
 *
 * @param plan the plan owning the targets; the new project is built over its first child
 * @param targets the expressions to scan and rewrite
 * @return the rewritten targets paired with the new child project, or empty when no unique
 *         function occurs more than once
 */
@VisibleForTesting
public <T extends Expression> Optional<Pair<List<T>, LogicalProject<Plan>>> rewriteExpressions(
        LogicalPlan plan, Collection<T> targets) {
    List<NamedExpression> aliases = tryGenUniqueFunctionAlias(targets);
    if (aliases.isEmpty()) {
        return Optional.empty();
    }

    Plan child = plan.child(0);

    // the new project passes through every output of the child and appends the aliases
    ImmutableList.Builder<NamedExpression> projects = ImmutableList.builder();
    projects.addAll(child.getOutputSet());
    projects.addAll(aliases);

    // map each aliased unique function to the slot produced by its alias
    Map<Expression, Slot> slotByFunction = Maps.newHashMap();
    for (NamedExpression alias : aliases) {
        slotByFunction.put(alias.child(0), alias.toSlot());
    }

    ImmutableList.Builder<T> rewrittenTargets = ImmutableList.builderWithExpectedSize(targets.size());
    for (T target : targets) {
        rewrittenTargets.add((T) ExpressionUtils.replace(target, slotByFunction));
    }

    LogicalProject<Plan> aliasProject = new LogicalProject<>(projects.build(), child);
    return Optional.of(Pair.of(rewrittenTargets.build(), aliasProject));
}
/**
 * Creates an alias for every unique function that occurs more than once in the targets.
 * Occurrences are counted over all sub-expressions of all targets, keyed by the function itself.
 *
 * @param targets the expressions to scan
 * @return one alias per repeated unique function, in first-seen order; empty if none is repeated
 */
@VisibleForTesting
public List<NamedExpression> tryGenUniqueFunctionAlias(Collection<? extends Expression> targets) {
    // a linked map keeps the first-seen order of the functions
    Map<UniqueFunction, Integer> occurrences = Maps.newLinkedHashMap();
    for (Expression target : targets) {
        target.foreach(node -> {
            Expression visited = (Expression) node;
            if (visited instanceof UniqueFunction) {
                occurrences.merge((UniqueFunction) visited, 1, Integer::sum);
            }
        });
    }

    ImmutableList.Builder<NamedExpression> aliases
            = ImmutableList.builderWithExpectedSize(occurrences.size());
    for (Entry<UniqueFunction, Integer> occurrence : occurrences.entrySet()) {
        if (occurrence.getValue() <= 1) {
            // a function seen only once needs no alias
            continue;
        }
        UniqueFunction function = occurrence.getKey();
        ExprId exprId = StatementScopeIdGenerator.newExprId();
        String aliasName = "$_" + function.getName() + "_" + exprId.asInt() + "_$";
        aliases.add(new Alias(exprId, function, aliasName));
    }
    return aliases.build();
}
}