// ContainDistinctFunctionRollupHandler.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.mv.rollup;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Any;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import com.google.common.collect.ImmutableSet;
import java.util.Map;
import java.util.Set;
/**
 * Tries to roll up an aggregate function that contains distinct when the parameter of the
 * function is a group by dimension of the materialized view.
 * For example, if the materialized view is defined as
 * select empid, deptno, count(salary) from distinctQuery group by empid, deptno;
 * and the query is
 * select deptno, count(distinct empid) from distinctQuery group by deptno;
 * then the rewrite should succeed: count(distinct empid) is computed from the empid
 * group by dimension of the materialized view.
 */
public class ContainDistinctFunctionRollupHandler extends AggFunctionRollUpHandler {

    /** Shared instance of this handler. */
    public static final ContainDistinctFunctionRollupHandler INSTANCE = new ContainDistinctFunctionRollupHandler();

    /**
     * Aggregate function patterns accepted by {@link #canRollup}. Every entry is built with
     * {@link Any#INSTANCE} as its argument and is compared with the query function through
     * {@code Any.equals}. The field is final so the shared set can not be reassigned at runtime.
     */
    public static final Set<AggregateFunction> SUPPORTED_AGGREGATE_FUNCTION_SET = ImmutableSet.of(
            new Max(true, Any.INSTANCE), new Min(true, Any.INSTANCE),
            new Max(false, Any.INSTANCE), new Min(false, Any.INSTANCE),
            new Count(true, Any.INSTANCE), new Sum(true, Any.INSTANCE),
            new Avg(true, Any.INSTANCE));

    /**
     * Checks whether the query aggregate function can be rolled up by this handler.
     *
     * @param queryAggregateFunction the aggregate function from the query, not inspected here
     * @param queryAggregateFunctionShuttled the shuttled form of the query aggregate function, all checks
     *         are made against this expression
     * @param mvExprToMvScanExprQueryBasedPair not inspected here
     * @param mvExprToMvScanExprQueryBased mapping whose keys are the expressions available from the view
     * @return false when the shuttled expression holds more than one aggregate function, when an
     *         aggregate function matches no entry of {@link #SUPPORTED_AGGREGATE_FUNCTION_SET}, when an
     *         aggregate function has more than one argument, or when the shuttled expression uses a slot
     *         that is not a key of {@code mvExprToMvScanExprQueryBased}; true otherwise
     */
    @Override
    public boolean canRollup(AggregateFunction queryAggregateFunction,
            Expression queryAggregateFunctionShuttled,
            Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair,
            Map<Expression, Expression> mvExprToMvScanExprQueryBased) {
        Set<AggregateFunction> queryAggregateFunctions =
                queryAggregateFunctionShuttled.collectToSet(AggregateFunction.class::isInstance);
        if (queryAggregateFunctions.size() > 1) {
            return false;
        }
        for (AggregateFunction aggregateFunction : queryAggregateFunctions) {
            if (SUPPORTED_AGGREGATE_FUNCTION_SET.stream()
                    .noneMatch(supportFunction -> Any.equals(supportFunction, aggregateFunction))) {
                return false;
            }
            if (aggregateFunction.getArguments().size() > 1) {
                return false;
            }
        }
        Set<Expression> mvExpressionsQueryBased = mvExprToMvScanExprQueryBased.keySet();
        Set<Slot> aggregateFunctionParamSlots = queryAggregateFunctionShuttled.collectToSet(Slot.class::isInstance);
        if (aggregateFunctionParamSlots.stream().anyMatch(slot -> !mvExpressionsQueryBased.contains(slot))) {
            // If query use any slot not in view, can not roll up
            return false;
        }
        return true;
    }

    /**
     * Rewrites the first child of the query aggregate function so that each slot is replaced by the
     * expression it maps to in {@code mvExprToMvScanExprQueryBasedMap}.
     *
     * @param queryAggregateFunction the aggregate function whose first child is rewritten
     * @param queryAggregateFunctionShuttled not inspected here
     * @param mvExprToMvScanExprQueryBasedPair not inspected here
     * @param mvExprToMvScanExprQueryBasedMap mapping used to replace slots of the argument
     * @return the query aggregate function with the rewritten argument as its only child, or null when
     *         the function has no child, when a slot has no entry in the mapping, or when the argument
     *         contains an expression other than a literal, a binary arithmetic or a slot
     */
    @Override
    public Function doRollup(AggregateFunction queryAggregateFunction,
            Expression queryAggregateFunctionShuttled, Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair,
            Map<Expression, Expression> mvExprToMvScanExprQueryBasedMap) {
        if (queryAggregateFunction.children().isEmpty()) {
            // canRollup only rejects more than one argument, so guard the zero argument case here
            // instead of failing with an IndexOutOfBoundsException on get(0).
            return null;
        }
        Expression argument = queryAggregateFunction.children().get(0);
        // Starts as true and is flipped to false by the rewriter as soon as the argument can not be rewritten.
        RollupResult<Boolean> rollupResult = RollupResult.of(true);
        Expression rewrittenArgument = argument.accept(new DefaultExpressionRewriter<RollupResult<Boolean>>() {
            @Override
            public Expression visitSlot(Slot slot, RollupResult<Boolean> context) {
                if (!mvExprToMvScanExprQueryBasedMap.containsKey(slot)) {
                    // The slot has no mapped expression, mark the rollup as failed and keep the slot.
                    context.param = false;
                    return slot;
                }
                return mvExprToMvScanExprQueryBasedMap.get(slot);
            }

            @Override
            public Expression visit(Expression expr, RollupResult<Boolean> context) {
                if (!context.param) {
                    // Already failed, leave the rest of the tree untouched.
                    return expr;
                }
                if (expr instanceof Literal || expr instanceof BinaryArithmetic || expr instanceof Slot) {
                    return super.visit(expr, context);
                }
                // Any other kind of expression in the argument is not supported.
                context.param = false;
                return expr;
            }
        }, rollupResult);
        if (!rollupResult.param) {
            return null;
        }
        return (Function) queryAggregateFunction.withChildren(rewrittenArgument);
    }

    /**
     * Mutable holder passed as the rewriter context so the anonymous rewriter in
     * {@link #doRollup} can report whether the rewrite is still valid.
     */
    private static class RollupResult<T> {
        public T param;

        private RollupResult(T param) {
            this.param = param;
        }

        /** Creates a holder that initially carries {@code param}. */
        public static <T> RollupResult<T> of(T param) {
            return new RollupResult<>(param);
        }

        /** Returns the value currently held. */
        public T getParam() {
            return param;
        }
    }
}