// SplitMultiDistinct.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.rules.rewrite.SplitMultiDistinct.DistinctSplitContext;
import org.apache.doris.nereids.trees.copier.DeepCopierContext;
import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.SupportMultiDistinct;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

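/**
 * Splits an aggregate that computes more than one distinct aggregate function into one aggregate
 * per distinct function. All of the new aggregates read the same input through a shared CTE, and
 * their results are joined back together and projected onto the original output columns.
 * <pre>
 * LogicalAggregate(output:count(distinct a) as c1, count(distinct b) as c2)
 *   +--Plan
 * ->
 * LogicalCTEAnchor
 *   +--LogicalCTEProducer
 *     +--Plan
 *   +--LogicalProject(c1, c2)
 *     +--LogicalJoin
 *       +--LogicalAggregate(output:count(distinct a))
 *         +--LogicalCTEConsumer
 *       +--LogicalAggregate(output:count(distinct b))
 *         +--LogicalCTEConsumer
 * </pre>
 */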
public class SplitMultiDistinct extends DefaultPlanRewriter<DistinctSplitContext> implements CustomRewriter {
    /** Shared rewriter instance; declared final so the singleton reference cannot be reassigned. */
    public static final SplitMultiDistinct INSTANCE = new SplitMultiDistinct();

    /**
     * DistinctSplitContext: state carried through one traversal of the rewriter.
     * It holds the CTE producers created so far (still waiting to be wrapped in an anchor) together with
     * the statement context and cascades context the rewriter reads ids from and registers consumers in.
     */
    public static class DistinctSplitContext {
        // The references are assigned once in the constructor; only the producer list's content changes.
        final List<LogicalCTEProducer<? extends Plan>> cteProducerList;
        final StatementContext statementContext;
        final CascadesContext cascadesContext;

        /**
         * Creates a context with an empty producer list.
         *
         * @param statementContext statement context stored for later id generation
         * @param cascadesContext cascades context stored for later consumer registration
         */
        public DistinctSplitContext(StatementContext statementContext, CascadesContext cascadesContext) {
            this.statementContext = statementContext;
            this.cteProducerList = new ArrayList<>();
            this.cascadesContext = cascadesContext;
        }
    }

    /**
     * Entry point of the custom rewriter: rewrites the whole plan, then wraps the result in one
     * {@link LogicalCTEAnchor} per CTE producer that was registered in the root context.
     *
     * @param plan the plan to rewrite
     * @param jobContext supplies the cascades context (and, through it, the statement context)
     * @return the rewritten plan, wrapped in anchors when producers were created
     */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
        CascadesContext cascadesContext = jobContext.getCascadesContext();
        DistinctSplitContext rootContext = new DistinctSplitContext(
                cascadesContext.getStatementContext(), cascadesContext);
        Plan rewritten = plan.accept(this, rootContext);
        // Walk the producers from last to first, so the first registered producer ends up outermost.
        int remaining = rootContext.cteProducerList.size();
        while (remaining > 0) {
            remaining--;
            LogicalCTEProducer<? extends Plan> cteProducer = rootContext.cteProducerList.get(remaining);
            rewritten = new LogicalCTEAnchor<>(cteProducer.getCteId(), cteProducer, rewritten);
        }
        return rewritten;
    }

    /**
     * Rewrites both children of an existing CTE anchor. The first child is rewritten with the caller's
     * context. The second child gets a fresh context, and every producer registered in that fresh context
     * is anchored directly above the rewritten second child instead of being handed to the caller.
     */
    @Override
    public Plan visitLogicalCTEAnchor(
            LogicalCTEAnchor<? extends Plan, ? extends Plan> anchor, DistinctSplitContext ctx) {
        Plan firstChild = anchor.child(0).accept(this, ctx);

        DistinctSplitContext consumerContext =
                new DistinctSplitContext(ctx.statementContext, ctx.cascadesContext);
        Plan secondChild = anchor.child(1).accept(this, consumerContext);

        // Wrap from the last registered producer to the first, so the first one becomes the outermost anchor.
        List<LogicalCTEProducer<? extends Plan>> pendingProducers = consumerContext.cteProducerList;
        for (int idx = pendingProducers.size(); idx-- > 0; ) {
            LogicalCTEProducer<? extends Plan> pending = pendingProducers.get(idx);
            secondChild = new LogicalCTEAnchor<>(pending.getCteId(), pending, secondChild);
        }
        return anchor.withChildren(ImmutableList.of(firstChild, secondChild));
    }

    /**
     * Rewrites the aggregate's child first, then, when {@code needTransform} accepts the aggregate,
     * replaces it with a project over joins of several smaller aggregates:
     * one aggregate per collected distinct function, plus one more aggregate holding all remaining
     * aggregate functions (when there are any). Every new aggregate reads from its own consumer of a
     * single CTE producer built on a deep copy of the aggregate's child. The producer is registered in
     * {@code ctx} so that a caller can anchor it.
     *
     * @param agg the aggregate to inspect
     * @param ctx traversal state; receives the newly created CTE producer
     * @return the aggregate with its rewritten child when no split is needed, otherwise the new project
     */
    @Override
    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, DistinctSplitContext ctx) {
        Plan newChild = agg.child().accept(this, ctx);
        agg = agg.withChildren(ImmutableList.of(newChild));
        List<Alias> distinctFuncWithAlias = new ArrayList<>();
        List<Alias> otherAggFuncs = new ArrayList<>();
        if (!needTransform((LogicalAggregate<Plan>) agg, distinctFuncWithAlias, otherAggFuncs)) {
            return agg;
        }

        // The CTE producer is built on a deep copy of the aggregate's child.
        LogicalAggregate<Plan> cloneAgg = (LogicalAggregate<Plan>) LogicalPlanDeepCopier.INSTANCE
                .deepCopy(agg, new DeepCopierContext());
        LogicalCTEProducer<Plan> producer = new LogicalCTEProducer<>(ctx.statementContext.getNextCTEId(),
                cloneAgg.child());
        ctx.cteProducerList.add(producer);

        // Pair the original child's output slots with the copied child's output slots by position.
        List<Slot> originOutput = agg.child().getOutput();
        List<Slot> cloneOutput = cloneAgg.child().getOutput();
        Map<Slot, Slot> originToProducerSlot = new HashMap<>();
        for (int i = 0; i < originOutput.size(); ++i) {
            originToProducerSlot.put(originOutput.get(i), cloneOutput.get(i));
        }
        // Rewrite each alias on its own with a checked cast instead of a raw-typed list conversion.
        distinctFuncWithAlias = replaceSlotsInAliases(distinctFuncWithAlias, originToProducerSlot);
        otherAggFuncs = replaceSlotsInAliases(otherAggFuncs, originToProducerSlot);

        // construct cte consumer and aggregate
        List<LogicalAggregate<Plan>> newAggs = new ArrayList<>();
        // new alias (output of a split aggregate) -> alias of the original aggregate output.
        // Insertion-ordered, so the final project lists the functions in the order they are built here
        // rather than in hash order.
        Map<Alias, Alias> newToOriginDistinctFuncAlias = new LinkedHashMap<>();
        List<Expression> outputJoinGroupBys = new ArrayList<>();
        for (int i = 0; i < distinctFuncWithAlias.size(); ++i) {
            Expression distinctAggFunc = distinctFuncWithAlias.get(i).child(0);
            Map<Slot, Slot> producerToConsumerSlotMap = new HashMap<>();
            List<NamedExpression> outputExpressions = new ArrayList<>();
            List<Expression> replacedGroupBy = new ArrayList<>();
            LogicalCTEConsumer consumer = constructConsumerAndReplaceGroupBy(ctx, producer, cloneAgg, outputExpressions,
                    producerToConsumerSlotMap, replacedGroupBy);
            Expression newDistinctAggFunc = ExpressionUtils.replace(distinctAggFunc, producerToConsumerSlotMap);
            Alias alias = new Alias(newDistinctAggFunc);
            outputExpressions.add(alias);
            if (i == 0) {
                // The group-by keys of the first aggregate are the ones exposed by the final project.
                outputJoinGroupBys.addAll(replacedGroupBy);
            }
            LogicalAggregate<Plan> newAgg = new LogicalAggregate<>(replacedGroupBy, outputExpressions, consumer);
            newAggs.add(newAgg);
            newToOriginDistinctFuncAlias.put(alias, distinctFuncWithAlias.get(i));
        }
        // All otherAggFuncs share one additional aggregate, appended after the distinct ones.
        buildOtherAggFuncAggregate(otherAggFuncs, producer, ctx, cloneAgg, newToOriginDistinctFuncAlias, newAggs);
        List<Expression> groupBy = agg.getGroupByExpressions();
        LogicalJoin<Plan, Plan> join = constructJoin(newAggs, groupBy);
        return constructProject(groupBy, newToOriginDistinctFuncAlias, outputJoinGroupBys, join);
    }

    /** Returns a new list holding each alias with its slots replaced according to {@code slotMap}. */
    private static List<Alias> replaceSlotsInAliases(List<Alias> aliases, Map<Slot, Slot> slotMap) {
        List<Alias> replaced = new ArrayList<>(aliases.size());
        for (Alias origin : aliases) {
            replaced.add((Alias) ExpressionUtils.replace(origin, slotMap));
        }
        return replaced;
    }

    /**
     * Builds one aggregate that computes every function in {@code otherAggFuncs} over a new consumer of
     * {@code producer} and appends it to {@code newAggs}. Does nothing when {@code otherAggFuncs} is empty.
     * Each new output alias is recorded in {@code newToOriginDistinctFuncAlias}, keyed by the new alias.
     */
    private static void buildOtherAggFuncAggregate(List<Alias> otherAggFuncs, LogicalCTEProducer<Plan> producer,
            DistinctSplitContext ctx, LogicalAggregate<Plan> cloneAgg, Map<Alias, Alias> newToOriginDistinctFuncAlias,
            List<LogicalAggregate<Plan>> newAggs) {
        if (otherAggFuncs.isEmpty()) {
            return;
        }
        Map<Slot, Slot> slotMapping = new HashMap<>();
        List<NamedExpression> aggOutputs = new ArrayList<>();
        List<Expression> groupKeys = new ArrayList<>();
        LogicalCTEConsumer cteConsumer = constructConsumerAndReplaceGroupBy(ctx, producer, cloneAgg, aggOutputs,
                slotMapping, groupKeys);

        // First pass: point every function at the consumer's slots.
        List<Expression> rewrittenFuncs = new ArrayList<>(otherAggFuncs.size());
        for (Alias originFunc : otherAggFuncs) {
            rewrittenFuncs.add(ExpressionUtils.replace(originFunc, slotMapping));
        }
        // Second pass: give each function a new alias and remember which rewritten alias it stands for.
        for (Expression rewrittenFunc : rewrittenFuncs) {
            // every rewritten element is still an Alias
            Alias freshAlias = new Alias(rewrittenFunc.child(0));
            aggOutputs.add(freshAlias);
            newToOriginDistinctFuncAlias.put(freshAlias, (Alias) rewrittenFunc);
        }
        newAggs.add(new LogicalAggregate<>(groupKeys, aggOutputs, cteConsumer));
    }

    /**
     * Creates a new consumer of {@code producer}, registers it in the cascades context, and fills the
     * three output collections supplied by the caller:
     * {@code producerToConsumerSlotMap} receives the producer-slot to consumer-slot mapping,
     * {@code replacedGroupBy} receives the group-by expressions of {@code cloneAgg} rewritten with that
     * mapping, and {@code outputExpressions} receives the rewritten group-by expressions cast to slots.
     *
     * @return the newly created consumer
     */
    private static LogicalCTEConsumer constructConsumerAndReplaceGroupBy(DistinctSplitContext ctx,
            LogicalCTEProducer<Plan> producer, LogicalAggregate<Plan> cloneAgg, List<NamedExpression> outputExpressions,
            Map<Slot, Slot> producerToConsumerSlotMap, List<Expression> replacedGroupBy) {
        LogicalCTEConsumer cteConsumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
                producer.getCteId(), "", producer);
        ctx.cascadesContext.putCTEIdToConsumer(cteConsumer);

        // The consumer exposes consumer -> producer; invert it into the caller's map.
        cteConsumer.getConsumerToProducerOutputMap().forEach(
                (consumerSlot, producerSlot) -> producerToConsumerSlotMap.put(producerSlot, consumerSlot));

        List<Expression> rewrittenKeys =
                ExpressionUtils.replace(cloneAgg.getGroupByExpressions(), producerToConsumerSlotMap);
        replacedGroupBy.addAll(rewrittenKeys);

        // Cast everything before touching outputExpressions, so a failed cast leaves it unchanged.
        List<Slot> keySlots = new ArrayList<>();
        for (Expression key : replacedGroupBy) {
            keySlots.add((Slot) key);
        }
        outputExpressions.addAll(keySlots);
        return cteConsumer;
    }

    /**
     * Tells whether any argument after the first one is neither an {@link OrderExpression} nor free of
     * input slots. Functions with at most one argument always yield {@code false}.
     */
    private static boolean isDistinctMultiColumns(AggregateFunction func) {
        int argCount = func.arity();
        for (int idx = 1; idx < argCount; idx++) {
            Expression arg = func.child(idx);
            // think about group_concat(distinct col_1, ','): the ',' argument has no input slots
            if (arg instanceof OrderExpression || arg.getInputSlots().isEmpty()) {
                continue;
            }
            return true;
        }
        return false;
    }

    /**
     * Decides whether {@code agg} should be split, and sorts its aliased aggregate functions into the two
     * caller-supplied lists while doing so: distinct functions implementing {@link SupportMultiDistinct}
     * go to {@code aliases}, every other aliased aggregate function goes to {@code otherAggFuncs}.
     * Both lists may have been appended to even when the result is {@code false}.
     *
     * @return {@code true} when the aggregate has no source repeat, contains at least two different
     *         collected distinct functions, and either one of them is multi-column or there is no group by
     */
    private static boolean needTransform(LogicalAggregate<Plan> agg, List<Alias> aliases, List<Alias> otherAggFuncs) {
        // TODO with source repeat aggregate need to be supported in future
        if (agg.getSourceRepeat().isPresent()) {
            return false;
        }
        boolean hasMultiColumnDistinct = false;
        Set<Expression> uniqueDistinctFuncs = new HashSet<>();
        for (NamedExpression output : agg.getOutputExpressions()) {
            boolean aliasedAggFunc = output instanceof Alias && output.child(0) instanceof AggregateFunction;
            if (!aliasedAggFunc) {
                continue;
            }
            Alias alias = (Alias) output;
            AggregateFunction function = (AggregateFunction) alias.child(0);
            if (!(function instanceof SupportMultiDistinct) || !function.isDistinct()) {
                otherAggFuncs.add(alias);
                continue;
            }
            aliases.add(alias);
            uniqueDistinctFuncs.add(function);
            if (!hasMultiColumnDistinct) {
                hasMultiColumnDistinct = isDistinctMultiColumns(function);
            }
        }
        if (uniqueDistinctFuncs.size() < 2) {
            return false;
        }
        // when this aggregate is not distinctMultiColumns, and group by expressions is not empty
        // e.g. sql1: select count(distinct a), count(distinct b) from t1 group by c;
        // sql2: select count(distinct a) from t1 group by c;
        // the physical plan of sql1 and sql2 is similar, both are 2-phase aggregate,
        // so there is no need to do this rewrite
        return hasMultiColumnDistinct || agg.getGroupByExpressions().isEmpty();
    }

    /**
     * Builds the project on top of the join. Each entry of {@code joinOutput} becomes an alias that
     * carries the expr id and name of the entry's value and reads the slot of the entry's key.
     * Each group-by slot of the original aggregate becomes an alias with that slot's expr id and name,
     * reading the expression at the same position in {@code outputJoinGroupBys}.
     */
    private static LogicalProject<Plan> constructProject(List<Expression> groupBy, Map<Alias, Alias> joinOutput,
            List<Expression> outputJoinGroupBys, LogicalJoin<Plan, Plan> join) {
        List<NamedExpression> projections = new ArrayList<>();
        joinOutput.forEach((newAlias, originAlias) ->
                projections.add(new Alias(originAlias.getExprId(), newAlias.toSlot(), originAlias.getName())));

        // outputJoinGroupBys.size() == agg.getGroupByExpressions().size()
        int position = 0;
        for (Expression key : groupBy) {
            Slot originKey = (Slot) key;
            Expression joinedKey = outputJoinGroupBys.get(position);
            projections.add(new Alias(originKey.getExprId(), joinedKey, originKey.getName()));
            position++;
        }
        return new LogicalProject<>(projections, join);
    }

    /**
     * Chains all aggregates in {@code newAggs} into a left-deep join tree; at least two are required.
     * Without group-by expressions the aggregates are cross joined. Otherwise they are inner joined on
     * null-safe equality of their first {@code groupBy.size()} output slots.
     */
    private static LogicalJoin<Plan, Plan> constructJoin(List<LogicalAggregate<Plan>> newAggs,
            List<Expression> groupBy) {
        LogicalJoin<Plan, Plan> joined;
        if (!groupBy.isEmpty()) {
            int keyCount = groupBy.size();
            List<Expression> firstConditions = nullSafeKeyConditions(
                    newAggs.get(0).getOutput(), newAggs.get(1).getOutput(), keyCount);
            joined = new LogicalJoin<>(JoinType.INNER_JOIN, firstConditions, newAggs.get(0), newAggs.get(1), null);
            for (int next = 2; next < newAggs.size(); ++next) {
                // Left-side keys are read from the current join's left child.
                List<Expression> nextConditions = nullSafeKeyConditions(
                        joined.left().getOutput(), newAggs.get(next).getOutput(), keyCount);
                joined = new LogicalJoin<>(JoinType.INNER_JOIN, nextConditions, joined, newAggs.get(next), null);
            }
        } else {
            joined = new LogicalJoin<>(JoinType.CROSS_JOIN, newAggs.get(0), newAggs.get(1), null);
            for (int next = 2; next < newAggs.size(); ++next) {
                joined = new LogicalJoin<>(JoinType.CROSS_JOIN, joined, newAggs.get(next), null);
            }
        }
        return joined;
    }

    /** Pairs the first {@code keyCount} slots of both sides into null-safe equality conditions. */
    private static List<Expression> nullSafeKeyConditions(List<Slot> leftSlots, List<Slot> rightSlots,
            int keyCount) {
        List<Expression> conditions = new ArrayList<>();
        for (int k = 0; k < keyCount; ++k) {
            conditions.add(new NullSafeEqual(leftSlots.get(k), rightSlots.get(k)));
        }
        return conditions;
    }
}