EliminateAggCaseWhen.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Filter;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.util.ExpressionUtils;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
/**
 * Rewrites a 'case when' or 'if' argument of an aggregate function into a filter below the aggregate.
 * Example:
 *   select sum(case when t1.c1 = 101 then 1 END) from t1;
 *   ==>
 *   select sum(1) from t1 where t1.c1 = 101;
 * Notes:
 * - Only the If expression needs to be handled, because CaseWhenToIf has already converted
 *   'case when' to 'if'; 'case when' is kept in this description because that is what still
 *   appears in the SQL text.
 * - There must be exactly one output expression, because the filter would otherwise affect
 *   the other projections.
 * - There must be exactly one aggregate function, because the filter would otherwise affect
 *   the other aggregate functions.
 * - The case when/if must have no else branch, because only a single 'then' branch can be chosen.
 * - The case when must have exactly one case, because only a single 'then' branch can be chosen.
 */
public final class EliminateAggCaseWhen extends OneRewriteRuleFactory {

    /**
     * Builds the rule that matches a logical aggregate and, when its single aggregate function
     * takes an {@code If} whose third argument is a NULL literal, moves the condition of that
     * {@code If} into a new filter under the aggregate and feeds the second argument of the
     * {@code If} to the aggregate function directly.
     *
     * @return the rule registered under {@link RuleType#ELIMINATE_AGG_CASE_WHEN}; its rewrite
     *         function returns null whenever the matched aggregate does not have the required shape
     */
    @Override
    public Rule build() {
        return logicalAggregate().then(agg -> {
            Set<AggregateFunction> functions = agg.getAggregateFunctions();

            // Required shape: exactly one aggregate function, exactly one output expression,
            // and no GROUP BY expressions.
            boolean singleGlobalAggregate = functions.size() == 1
                    && agg.getOutputExpressions().size() == 1
                    && agg.getGroupByExpressions().isEmpty();
            if (!singleGlobalAggregate) {
                return null;
            }

            // The set holds exactly one element at this point.
            AggregateFunction onlyFunction = functions.iterator().next();

            // The aggregate function must take a single argument.
            if (onlyFunction.getArguments().size() != 1) {
                return null;
            }

            // Only If is inspected: CaseWhenToIf has already turned 'case when' into 'if'.
            Expression argument = onlyFunction.getArgument(0);
            if (!(argument instanceof If)) {
                return null;
            }

            // The third argument of the If must be a NULL literal, otherwise nothing is rewritten.
            If conditional = (If) argument;
            if (!(conditional.getArgument(2) instanceof NullLiteral)) {
                return null;
            }

            Expression condition = conditional.getArgument(0);
            Expression thenValue = conditional.getArgument(1);

            // The conjuncts of the If condition become a filter on top of the aggregate's child.
            Filter filteredChild = new LogicalFilter<>(
                    ExpressionUtils.extractConjunctionToSet(condition), agg.child());

            // Map the original aggregate function to a copy whose only child is the 'then' value.
            Map<Expression, Expression> replacements = new HashMap<>(functions.size());
            replacements.put(onlyFunction, onlyFunction.withChildren(thenValue));

            return agg.withChildAndOutput((Plan) filteredChild,
                    ExpressionUtils.replaceNamedExpressions(agg.getOutputExpressions(), replacements));
        }).toRule(RuleType.ELIMINATE_AGG_CASE_WHEN);
    }
}