NormalizeToSlot.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import javax.annotation.Nullable;
/**
* NormalizeToSlot
*/
public interface NormalizeToSlot {
/**
* NormalizeSlotContext
*/
class NormalizeToSlotContext {
private final Map<Expression, NormalizeToSlotTriplet> normalizeToSlotMap;
public NormalizeToSlotContext(Map<Expression, NormalizeToSlotTriplet> normalizeToSlotMap) {
this.normalizeToSlotMap = normalizeToSlotMap;
}
public Map<Expression, NormalizeToSlotTriplet> getNormalizeToSlotMap() {
return normalizeToSlotMap;
}
public NormalizeToSlotContext mergeContext(NormalizeToSlotContext context) {
Map<Expression, NormalizeToSlotTriplet> newMap = Maps.newHashMap();
newMap.putAll(this.normalizeToSlotMap);
newMap.putAll(context.getNormalizeToSlotMap());
return new NormalizeToSlotContext(newMap);
}
/**
* build normalization context by follow step.
* 1. collect all exists alias by input parameters existsAliases build a reverted map: expr -> alias
* 2. for all input source expressions, use existsAliasMap to construct triple:
* origin expr, pushed expr and alias to replace origin expr,
* see more detail in {@link NormalizeToSlotTriplet}
* 3. construct a map: original expr -> triple constructed by step 2
*/
public static NormalizeToSlotContext buildContext(
Set<Alias> existsAliases, Collection<? extends Expression> sourceExpressions) {
Map<Expression, NormalizeToSlotTriplet> normalizeToSlotMap = Maps.newLinkedHashMap();
Map<Expression, Alias> existsAliasMap = Maps.newLinkedHashMap();
for (Alias existsAlias : existsAliases) {
existsAliasMap.put(existsAlias.child(), existsAlias);
}
for (Expression expression : sourceExpressions) {
if (normalizeToSlotMap.containsKey(expression)) {
continue;
}
Alias alias = null;
// consider projects: c1, c1 as a1. we should push down both of them,
// so we could not replace c1 with c1 as a1.
// use null as alias for SlotReference to avoid replace it by another alias of it.
if (!(expression instanceof SlotReference)) {
alias = existsAliasMap.get(expression);
}
NormalizeToSlotTriplet normalizeToSlotTriplet = NormalizeToSlotTriplet.toTriplet(expression, alias);
normalizeToSlotMap.put(expression, normalizeToSlotTriplet);
}
return new NormalizeToSlotContext(normalizeToSlotMap);
}
public <E extends Expression> E normalizeToUseSlotRef(E expression) {
return normalizeToUseSlotRef(ImmutableList.of(expression)).get(0);
}
/**
* normalizeToUseSlotRef, no custom normalize.
* This function use a lambda that always return original expression as customNormalize
* So always use normalizeToSlotMap to process normalization when we call this function
*/
public <E extends Expression> List<E> normalizeToUseSlotRef(Collection<E> expressions) {
return normalizeToUseSlotRef(expressions, (context, expr) -> expr);
}
/**
* normalizeToUseSlotRef.
* try to use customNormalize do normalization first. if customNormalize cannot handle current expression,
* use normalizeToSlotMap to get the default replaced expression.
*/
public <E extends Expression> List<E> normalizeToUseSlotRef(Collection<E> expressions,
BiFunction<NormalizeToSlotContext, Expression, Expression> customNormalize) {
ImmutableList.Builder<E> result = ImmutableList.builderWithExpectedSize(expressions.size());
for (E expr : expressions) {
Expression rewriteExpr = expr.rewriteDownShortCircuit(child -> {
Expression newChild = customNormalize.apply(this, child);
if (newChild != null && newChild != child) {
return newChild;
}
NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(child);
return normalizeToSlotTriplet == null ? child : normalizeToSlotTriplet.remainExpr;
});
result.add((E) rewriteExpr);
}
return result.build();
}
public <E extends Expression> List<E> normalizeToUseSlotRefWithoutWindowFunction(
Collection<E> expressions) {
ImmutableList.Builder<E> normalized = ImmutableList.builderWithExpectedSize(expressions.size());
for (E expression : expressions) {
normalized.add((E) expression.accept(NormalizeWithoutWindowFunction.INSTANCE, normalizeToSlotMap));
}
return normalized.build();
}
/**
* generate bottom projections with groupByExpressions.
* eg:
* groupByExpressions: k1#0, k2#1 + 1;
* bottom: k1#0, (k2#1 + 1) AS (k2 + 1)#2;
*/
public Set<NamedExpression> pushDownToNamedExpression(Collection<? extends Expression> needToPushExpressions) {
ImmutableSet.Builder<NamedExpression> result
= ImmutableSet.builderWithExpectedSize(needToPushExpressions.size());
for (Expression expr : needToPushExpressions) {
NormalizeToSlotTriplet normalizeToSlotTriplet = normalizeToSlotMap.get(expr);
result.add(normalizeToSlotTriplet == null
? (NamedExpression) expr
: normalizeToSlotTriplet.pushedExpr);
}
return result.build();
}
}
/**
* replace any expression except window function.
* because the window function could be same with aggregate function and should never be replaced.
*/
class NormalizeWithoutWindowFunction
extends DefaultExpressionRewriter<Map<Expression, NormalizeToSlotTriplet>> {
public static final NormalizeWithoutWindowFunction INSTANCE = new NormalizeWithoutWindowFunction();
private NormalizeWithoutWindowFunction() {
}
@Override
public Expression visit(Expression expr, Map<Expression, NormalizeToSlotTriplet> replaceMap) {
NormalizeToSlotTriplet triplet = replaceMap.get(expr);
if (triplet != null) {
return triplet.remainExpr;
}
return super.visit(expr, replaceMap);
}
@Override
public Expression visitWindow(WindowExpression windowExpression,
Map<Expression, NormalizeToSlotTriplet> replaceMap) {
NormalizeToSlotTriplet triplet = replaceMap.get(windowExpression);
if (triplet != null) {
return triplet.remainExpr;
}
ImmutableList.Builder<Expression> newChildren =
ImmutableList.builderWithExpectedSize(windowExpression.arity());
Expression function = super.visit(windowExpression.getFunction(), replaceMap);
newChildren.add(function);
boolean hasNewChildren = function != windowExpression.getFunction();
for (Expression partitionKey : windowExpression.getPartitionKeys()) {
Expression newChild = partitionKey.accept(this, replaceMap);
if (newChild != partitionKey) {
hasNewChildren = true;
}
newChildren.add(newChild);
}
for (Expression orderKey : windowExpression.getOrderKeys()) {
Expression newChild = orderKey.accept(this, replaceMap);
if (newChild != orderKey) {
hasNewChildren = true;
}
newChildren.add(newChild);
}
if (!hasNewChildren) {
return windowExpression;
}
if (windowExpression.getWindowFrame().isPresent()) {
newChildren.add(windowExpression.getWindowFrame().get());
}
return windowExpression.withChildren(newChildren.build());
}
}
/**
* NormalizeToSlotTriplet
*/
class NormalizeToSlotTriplet {
// which expression need to normalized to slot?
// e.g. `a + 1`
public final Expression originExpr;
// the slot already normalized.
// e.g. new Alias(`a + 1`).toSlot()
public final Slot remainExpr;
// the output expression need to push down to the bottom project.
// e.g. new Alias(`a + 1`)
public final NamedExpression pushedExpr;
public NormalizeToSlotTriplet(Expression originExpr, Slot remainExpr, NamedExpression pushedExpr) {
this.originExpr = originExpr;
this.remainExpr = remainExpr;
this.pushedExpr = pushedExpr;
}
/**
* construct triplet by three conditions.
* 1. already has exists alias: use this alias as pushed expr
* 2. expression is {@link NamedExpression}, use itself as pushed expr
* 3. other expression, construct a new Alias contains current expression as pushed expr
*/
public static NormalizeToSlotTriplet toTriplet(Expression expression, @Nullable Alias existsAlias) {
if (existsAlias != null) {
return new NormalizeToSlotTriplet(expression, existsAlias.toSlot(), existsAlias);
}
if (expression instanceof NamedExpression) {
NamedExpression namedExpression = (NamedExpression) expression;
return new NormalizeToSlotTriplet(expression, namedExpression.toSlot(), namedExpression);
}
Alias alias = new Alias(expression);
return new NormalizeToSlotTriplet(expression, alias.toSlot(), alias);
}
}
}