LogicalPlanBuilderForSyncMv.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.parser;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.DorisParser;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.BuiltinFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.commands.CreateMTMVCommand;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.PlanUtils;
import com.google.common.collect.ImmutableList;
import org.antlr.v4.runtime.ParserRuleContext;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
/**
 * A {@link LogicalPlanBuilder} used to parse the defining query of a synchronous materialized view.
 *
 * <p>Besides building the logical plan, this builder:
 * <ul>
 *   <li>records, for every {@link UnboundFunction}, the source offsets of its function name and of the
 *   end of the whole call expression ({@link #visitFunctionCallExpression});</li>
 *   <li>in {@link #visitQuery}, produces a copy of the query SQL in which aggregate functions in the output
 *   of the outermost {@link LogicalAggregate} are rewritten into their agg_state combinator form
 *   {@code name + UNION_SUFFIX + "(" + name + STATE_SUFFIX + "(...))"} (suffixes taken from
 *   {@link AggCombinerFunctionBuilder}). min, max, sum, count, bitmap_union and hll_union are left
 *   as written.</li>
 * </ul>
 * The resulting SQL is exposed through {@link #getQuerySql()}.
 */
public class LogicalPlanBuilderForSyncMv extends LogicalPlanBuilder {
    // SQL produced by the most recent visitQuery call; empty until a query has been visited.
    private Optional<String> querySql = Optional.empty();

    public LogicalPlanBuilderForSyncMv(Map<Integer, ParserRuleContext> selectHintMap) {
        super(selectHintMap);
    }

    /**
     * Builds the function call expression and, when it is an {@link UnboundFunction}, attaches the
     * source offsets needed to rewrite it later.
     *
     * <p>The offsets come from ANTLR token start/stop indexes: the start and stop of the function name,
     * and the stop of the whole call (its closing parenthesis). Stop indexes are inclusive.
     */
    @Override
    public Expression visitFunctionCallExpression(DorisParser.FunctionCallExpressionContext ctx) {
        Expression expression = super.visitFunctionCallExpression(ctx);
        if (!(expression instanceof UnboundFunction)) {
            return expression;
        }
        return ((UnboundFunction) expression)
                .withIndexInSqlString(Optional.of(new UnboundFunction.FunctionIndexInSql(
                        ctx.functionIdentifier().functionNameIdentifier().start.getStartIndex(),
                        ctx.functionIdentifier().functionNameIdentifier().stop.getStopIndex(),
                        ctx.stop.getStopIndex())));
    }

    /**
     * Builds the plan for a query and records its (possibly rewritten) SQL text in {@code querySql}.
     *
     * <p>The outermost aggregate is searched for by following the first-child chain starting at the
     * plan found by {@link PlanUtils.OutermostPlanFinder}. If an aggregate is found, its output
     * expressions are scanned for aggregate functions to rewrite. Otherwise the original SQL is
     * stored unchanged.
     *
     * @param ctx the parsed query context
     * @return the logical plan built by the parent builder
     */
    @Override
    public LogicalPlan visitQuery(DorisParser.QueryContext ctx) {
        LogicalPlan logicalPlan = super.visitQuery(ctx);
        PlanUtils.OutermostPlanFinderContext outermostPlanFinderContext =
                new PlanUtils.OutermostPlanFinderContext();
        logicalPlan.accept(PlanUtils.OutermostPlanFinder.INSTANCE, outermostPlanFinderContext);
        // find outermost logicalAggregate to rewrite agg_state related function.
        // The null check protects against the finder leaving outermostPlan unset.
        Plan outermostAgg = outermostPlanFinderContext.outermostPlan;
        while (outermostAgg != null
                && !(outermostAgg instanceof LogicalAggregate)
                && !outermostAgg.children().isEmpty()) {
            outermostAgg = outermostAgg.child(0);
        }
        String originSql = getOriginSql(ctx);
        if (outermostAgg instanceof LogicalAggregate) {
            List<NamedExpression> outputs = ((LogicalAggregate) outermostAgg).getOutputs();
            // Sorted by (begin, end) so that rewriteSql can splice replacements left to right.
            TreeMap<Pair<Integer, Integer>, String> indexInSqlToString =
                    new TreeMap<>(new Pair.PairComparator<>());
            AggStateFunctionFinder aggStateFunctionFinder =
                    new AggStateFunctionFinder(ctx.start.getStartIndex());
            for (Expression expr : outputs) {
                aggStateFunctionFinder.find(expr, indexInSqlToString);
            }
            querySql = Optional.of(rewriteSql(originSql, indexInSqlToString));
        } else {
            querySql = Optional.of(originSql);
        }
        return logicalPlan;
    }

    /**
     * Visits only the query part of a CREATE MTMV statement so that {@code querySql} gets populated.
     * No command is built; this always returns {@code null}.
     */
    @Override
    public CreateMTMVCommand visitCreateMTMV(DorisParser.CreateMTMVContext ctx) {
        visitQuery(ctx.query());
        return null;
    }

    /**
     * Returns the SQL recorded by the last {@link #visitQuery} call.
     *
     * @return the rewritten query SQL, or {@link Optional#empty()} if no query has been visited yet
     */
    public Optional<String> getQuerySql() {
        return querySql;
    }

    /**
     * Collects SQL text replacements that turn aggregate function calls into agg_state
     * union/state combinator calls.
     *
     * <p>Each replacement is keyed by an inclusive (begin, end) character range relative to the start
     * of the query text. This visitor does not descend into the arguments of an unbound function
     * that carries source offsets.
     */
    private static class AggStateFunctionFinder
            extends DefaultExpressionRewriter<TreeMap<Pair<Integer, Integer>, String>> {
        // Absolute start offset of the query text; used to make function offsets query-relative.
        private final int sqlBeginIndex;
        private final FunctionRegistry functionRegistry;

        public AggStateFunctionFinder(int sqlBeginIndex) {
            this.sqlBeginIndex = sqlBeginIndex;
            this.functionRegistry = Env.getCurrentEnv().getFunctionRegistry();
        }

        /**
         * Scans {@code expression} and adds any required replacements to {@code indexInSqlToNewString}.
         */
        public Expression find(Expression expression,
                TreeMap<Pair<Integer, Integer>, String> indexInSqlToNewString) {
            return expression.accept(this, indexInSqlToNewString);
        }

        @Override
        public Expression visitUnboundFunction(UnboundFunction unboundFunction,
                TreeMap<Pair<Integer, Integer>, String> indexInSqlToNewString) {
            if (!unboundFunction.getFunctionIndexInSql().isPresent()) {
                return unboundFunction;
            }
            // try bind agg function; a DISTINCT call passes the distinct flag as the first argument
            ImmutableList.Builder<Object> argumentsBuilder = ImmutableList.builder();
            if (unboundFunction.isDistinct()) {
                argumentsBuilder.add(unboundFunction.isDistinct());
            }
            List<Object> arguments = argumentsBuilder.addAll(unboundFunction.getArguments()).build();

            String functionName = unboundFunction.getName();
            FunctionBuilder builder = functionRegistry
                    .findFunctionBuilder(unboundFunction.getDbName(), functionName, arguments);
            if (!(builder instanceof BuiltinFunctionBuilder)) {
                return unboundFunction;
            }
            BoundFunction boundFunction =
                    (BoundFunction) builder.build(functionName, arguments).first;
            if (!(boundFunction instanceof AggregateFunction)) {
                return unboundFunction;
            }

            // rewrite to agg_state
            UnboundFunction.FunctionIndexInSql functionIndexInSql = unboundFunction
                    .getFunctionIndexInSql().get().indexInQueryPart(sqlBeginIndex);
            // use the bound (canonical) name rather than the name as written in the SQL
            functionName = boundFunction.getName();
            switch (functionName) {
                case "min":
                case "max":
                case "sum":
                case "count":
                case "bitmap_union":
                case "hll_union":
                    // no need rewrite
                    return unboundFunction;
                default:
                    break;
            }
            // Replace the function name with "<name><union>(<name><state>" ...
            indexInSqlToNewString.put(
                    Pair.of(functionIndexInSql.functionNameBegin,
                            functionIndexInSql.functionNameEnd),
                    String.format("%s%s(%s%s", functionName,
                            AggCombinerFunctionBuilder.UNION_SUFFIX,
                            functionName,
                            AggCombinerFunctionBuilder.STATE_SUFFIX));
            // ... and the call's closing character with "))" to close both wrappers.
            indexInSqlToNewString
                    .put(Pair.of(functionIndexInSql.functionExpressionEnd,
                            functionIndexInSql.functionExpressionEnd), "))");
            return unboundFunction;
        }
    }

    /**
     * Applies text replacements to {@code querySql}.
     *
     * <p>Each key is an inclusive (begin, end) character range in {@code querySql}, and the characters
     * in that range are replaced by the mapped value. Entries must be iterated in ascending,
     * non-overlapping order; the caller passes a TreeMap sorted by (begin, end).
     *
     * @param querySql the original query text
     * @param indexStringSqlMap replacements keyed by inclusive character range
     * @return the query text with all replacements applied
     */
    private static String rewriteSql(String querySql,
            Map<Pair<Integer, Integer>, String> indexStringSqlMap) {
        StringBuilder builder = new StringBuilder();
        int beg = 0;
        for (Map.Entry<Pair<Integer, Integer>, String> entry : indexStringSqlMap.entrySet()) {
            Pair<Integer, Integer> index = entry.getKey();
            builder.append(querySql, beg, index.first);
            builder.append(entry.getValue());
            // range end is inclusive, so resume copying just past it
            beg = index.second + 1;
        }
        builder.append(querySql, beg, querySql.length());
        return builder.toString();
    }
}