ColumnPruning.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.rules.rewrite.ColumnPruning.PruneContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalExcept;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSink;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.logical.OutputPrunable;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.IntStream;

/**
 * ColumnPruning.
 *
 * A plan should implement OutputPrunable so that its own output can be pruned by this rule.
 *
 * functions:
 *
 * 1. prune/shrink the output fields of an OutputPrunable plan, e.g.
 *
 *            project(projects=[k1, sum(v1)])                              project(projects=[k1, sum(v1)])
 *                      |                                 ->                      |
 *    agg(groupBy=[k1], output=[k1, sum(v1), sum(v2)])                 agg(groupBy=[k1], output=[k1, sum(v1)])
 *
 * 2. add a project on top of a child whose output could not be pruned, e.g. a filter does not record
 *    its output, so its output fields can not be pruned/shrunk and a project is added above the filter.
 *
 *          agg(groupBy=[a])                              agg(groupBy=[a])
 *                |                                              |
 *           filter(b > 10)                ->                project(a)
 *                |                                              |
 *              plan                                       filter(b > 10)
 *                                                               |
 *                                                              plan
 */
public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements CustomRewriter {
    private Set<Slot> keys;

    /**
     * Collects the slots used inside the expressions of a plan tree (the "keys"), which should not be pruned.
     * The keys are collected for two purposes:
     * 1. for count(*), '*' is replaced by the smallest (data type byte size) column
     * 2. for StatsDerive, the no-stats algorithm is only used when col-stats of the keys are not available
     */
    public static class KeyColumnCollector
            extends DefaultPlanRewriter<JobContext> implements CustomRewriter {
        public Set<Slot> keys = Sets.newHashSet();

        @Override
        public Plan rewriteRoot(Plan plan, JobContext jobContext) {
            return plan.accept(this, jobContext);
        }

        /**
         * Visits all children first, then records the input slots of every expression of the plan
         * that is not itself a bare SlotReference. The plan is returned unchanged.
         */
        @Override
        public Plan visit(Plan plan, JobContext jobContext) {
            plan.children().forEach(child -> child.accept(this, jobContext));
            plan.getExpressions().stream()
                    .filter(expression -> !(expression instanceof SlotReference))
                    .forEach(expression -> keys.addAll(expression.getInputSlots()));
            return plan;
        }

        /**
         * Same as {@link #visit}, except that bare SlotReference expressions of the aggregate
         * are recorded as keys too. The aggregate is returned unchanged.
         */
        @Override
        public LogicalAggregate<? extends Plan> visitLogicalAggregate(LogicalAggregate<? extends Plan> agg,
                JobContext jobContext) {
            agg.child().accept(this, jobContext);
            agg.getExpressions().forEach(expression -> {
                if (expression instanceof SlotReference) {
                    keys.add((Slot) expression);
                } else {
                    keys.addAll(expression.getInputSlots());
                }
            });
            return agg;
        }
    }

    /**
     * Entry point of the rule: collects the key slots of the whole plan, registers the SlotReference keys
     * in the current StatementContext (when one is available), then prunes the plan top-down requiring
     * the full output of the root.
     */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
        KeyColumnCollector collector = new KeyColumnCollector();
        plan.accept(collector, jobContext);
        keys = collector.keys;

        ConnectContext connectContext = ConnectContext.get();
        // in ut, stmtContext is null
        StatementContext stmtContext = connectContext == null ? null : connectContext.getStatementContext();
        if (stmtContext != null) {
            keys.stream()
                    .filter(SlotReference.class::isInstance)
                    .map(SlotReference.class::cast)
                    .forEach(slot -> stmtContext.addKeySlot(slot));
        }

        // the root has no parent, so everything it outputs is required
        return plan.accept(this, new PruneContext(plan.getOutputSet(), null));
    }

    /**
     * Default pruning for plans without a dedicated visit method.
     * An OutputPrunable plan gets its own output pruned first and then requires only the slots it uses
     * from its children; any other plan forwards the parent's required slots plus its own used slots.
     */
    @Override
    public Plan visit(Plan plan, PruneContext context) {
        if (!(plan instanceof OutputPrunable)) {
            // e.g.
            //
            //       project(a)
            //           |
            //           |  require: [a]
            //           v
            //       filter(b > 1)    <-  process currently
            //           |
            //           |  require: [a, b]
            //           v
            //       child plan
            //
            // the filter is not OutputPrunable, so both the slots required by the parent
            // (slot a, held in context.requiredSlots) and the slots used by the filter itself (slot b)
            // are passed through to the child plan.
            return pruneChildren(plan, context.requiredSlots);
        }

        // the case 1 in the class comment
        // two steps: prune current output and prune children
        OutputPrunable outputPrunable = (OutputPrunable) plan;
        Plan prunedPlan = pruneOutput(plan, outputPrunable.getOutputs(), outputPrunable::pruneOutputs, context);
        return pruneChildren(prunedPlan);
    }

    // union can not prune its children by the common logic, because the output of the union and the
    // outputs of its children are matched by position; so the visit method is overridden with special code.
    // A DISTINCT union is not pruned at all: its output and the whole output of each child are kept.
    @Override
    public Plan visitLogicalUnion(LogicalUnion union, PruneContext context) {
        if (union.getQualifier() == Qualifier.DISTINCT) {
            return skipPruneThisAndFirstLevelChildren(union);
        }
        LogicalUnion outputPruned = pruneUnionOutput(union, context);

        // positions, in the original output, of the columns which survived the output pruning
        List<Slot> originalSlots = union.getOutput();
        Set<Slot> keptSlots = outputPruned.getOutputSet();
        ImmutableList.Builder<Integer> keptIndexesBuilder = ImmutableList.builder();
        for (int index = 0; index < originalSlots.size(); index++) {
            if (keptSlots.contains(originalSlots.get(index))) {
                keptIndexesBuilder.add(index);
            }
        }
        List<Integer> keptIndexes = keptIndexesBuilder.build();

        // prune every child down to the columns at the kept positions
        ImmutableList.Builder<Plan> newChildren = ImmutableList.builder();
        ImmutableList.Builder<List<SlotReference>> newChildrenOutputs = ImmutableList.builder();
        for (int i = 0; i < outputPruned.arity(); i++) {
            List<SlotReference> childOutputs = outputPruned.getRegularChildOutput(i);
            ImmutableList.Builder<SlotReference> keptChildOutputBuilder
                    = ImmutableList.builderWithExpectedSize(keptIndexes.size());
            for (int index : keptIndexes) {
                keptChildOutputBuilder.add(childOutputs.get(index));
            }
            List<SlotReference> keptChildOutput = keptChildOutputBuilder.build();
            Set<Slot> requiredOfChild = ImmutableSet.copyOf(keptChildOutput);
            Plan newChild = doPruneChild(outputPruned, outputPruned.child(i), requiredOfChild);
            newChildrenOutputs.add(keptChildOutput);
            newChildren.add(newChild);
        }
        return outputPruned.withChildrenAndTheirOutputs(newChildren.build(), newChildrenOutputs.build());
    }

    // The output of a LogicalSetOperation and of all its children must be kept:
    // the except node itself is not output-pruned and every output slot of each direct child stays
    // required; pruning continues inside the children.
    @Override
    public Plan visitLogicalExcept(LogicalExcept except, PruneContext context) {
        return skipPruneThisAndFirstLevelChildren(except);
    }

    // Same handling as except: the intersect node is not output-pruned and every output slot of each
    // direct child stays required; pruning continues inside the children.
    @Override
    public Plan visitLogicalIntersect(LogicalIntersect intersect, PruneContext context) {
        return skipPruneThisAndFirstLevelChildren(intersect);
    }

    // A sink requires its own output set (plus the slots it uses) from its child;
    // the slots required by the parent (context.requiredSlots) are not consulted here.
    @Override
    public Plan visitLogicalSink(LogicalSink<? extends Plan> logicalSink, PruneContext context) {
        return pruneChildren(logicalSink, logicalSink.getOutputSet());
    }

    // the backend does not support filter(project(agg)), so the key set of the agg can not be pruned;
    // only the aggregate functions are pruned here (see pruneAggregate)
    @Override
    public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, PruneContext context) {
        return pruneAggregate(aggregate, context);
    }

    // same as aggregate: a repeat goes through the shared pruneAggregate logic
    @Override
    public Plan visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat, PruneContext context) {
        return pruneAggregate(repeat, context);
    }

    // A CTE producer is not output-pruned: every output slot of its direct child stays required,
    // and pruning continues inside that child.
    @Override
    public Plan visitLogicalCTEProducer(LogicalCTEProducer<? extends Plan> cteProducer, PruneContext context) {
        return skipPruneThisAndFirstLevelChildren(cteProducer);
    }

    // A CTE consumer gets no pruning-specific handling in this class:
    // the call is forwarded unchanged to the superclass implementation.
    @Override
    public Plan visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, PruneContext context) {
        return super.visitLogicalCTEConsumer(cteConsumer, context);
    }

    /**
     * Prunes the window expressions whose output slot is not required by the parent.
     * <ul>
     *   <li>all window expressions are required: keep the window and only prune its child</li>
     *   <li>no window expression is required: drop the window and visit its child with the same context</li>
     *   <li>otherwise: rebuild the window with the required expressions and prune its child</li>
     * </ul>
     */
    @Override
    public Plan visitLogicalWindow(LogicalWindow<? extends Plan> window, PruneContext context) {
        int prunedCount = 0;
        ImmutableList.Builder<NamedExpression> keptBuilder = ImmutableList.builder();
        for (NamedExpression windowExpression : window.getWindowExpressions()) {
            if (context.requiredSlots.contains(windowExpression.toSlot())) {
                keptBuilder.add(windowExpression);
            } else {
                prunedCount++;
            }
        }
        ImmutableList<NamedExpression> keptWindowExpressions = keptBuilder.build();

        if (prunedCount == 0) {
            // nothing to prune in the window itself
            return pruneChildren(window, context.requiredSlots);
        }
        if (keptWindowExpressions.isEmpty()) {
            // the whole window is useless for the parent
            return window.child().accept(this, context);
        }
        LogicalWindow<? extends Plan> prunedWindow
                = window.withExpressionsAndChild(keptWindowExpressions, window.child());
        return pruneChildren(prunedWindow, context.requiredSlots);
    }

    /**
     * Shared pruning for aggregate-like plans (called for LogicalAggregate and LogicalRepeat):
     * prune the output, fill the group by keys back into the output, then prune the children.
     */
    private Plan pruneAggregate(Aggregate<?> agg, PruneContext context) {
        // step 1: try to prune group by and aggregate functions which the parent does not require
        Aggregate<? extends Plan> outputPruned = pruneOutput(agg, agg.getOutputs(), agg::pruneOutputs, context);
        // step 2: restore the group by keys (a no-op unless the plan is a LogicalAggregate)
        Aggregate<?> completed = fillUpGroupByAndOutput(outputPruned);
        // step 3: the children only need to provide what the aggregate itself uses
        return pruneChildren(completed);
    }

    /**
     * Keeps the plan's own output and the complete output of each direct child:
     * all output slots of all children are declared as required, then the children are pruned with
     * that requirement, so pruning only takes effect below the first level of children.
     */
    private Plan skipPruneThisAndFirstLevelChildren(Plan plan) {
        ImmutableSet.Builder<Slot> allChildrenOutputs = ImmutableSet.builder();
        plan.children().forEach(child -> allChildrenOutputs.addAll(child.getOutput()));
        return pruneChildren(plan, allChildrenOutputs.build());
    }

    /**
     * Rebuilds group by and output of a pruned LogicalAggregate so that every group by key is output.
     * The new output is all group by expressions followed by the pruned output expressions which are not
     * group by expressions; the new group by list is every new output expression except an Alias whose
     * child is one of the aggregate functions. Any other Aggregate implementation is returned as is.
     */
    private static Aggregate<? extends Plan> fillUpGroupByAndOutput(Aggregate<? extends Plan> prunedOutputAgg) {
        if (!(prunedOutputAgg instanceof LogicalAggregate)) {
            return prunedOutputAgg;
        }
        List<Expression> groupByKeys = prunedOutputAgg.getGroupByExpressions();
        List<NamedExpression> prunedOutput = prunedOutputAgg.getOutputExpressions();

        // group by keys go first, then the remaining output expressions in their original order
        ImmutableList.Builder<NamedExpression> outputBuilder
                = ImmutableList.builderWithExpectedSize(prunedOutput.size());
        outputBuilder.addAll((List) groupByKeys);
        for (NamedExpression candidate : prunedOutput) {
            if (groupByKeys.contains(candidate)) {
                continue;
            }
            outputBuilder.add(candidate);
        }
        List<NamedExpression> completedOutput = outputBuilder.build();

        // everything in the output which is not an aliased aggregate function becomes a group by key
        Set<AggregateFunction> aggFunctions = prunedOutputAgg.getAggregateFunctions();
        ImmutableList.Builder<Expression> groupByBuilder
                = ImmutableList.builderWithExpectedSize(completedOutput.size());
        for (NamedExpression outputExpr : completedOutput) {
            boolean isAggFunctionAlias = outputExpr instanceof Alias
                    && aggFunctions.contains(outputExpr.child(0));
            if (isAggFunctionAlias) {
                continue;
            }
            groupByBuilder.add(outputExpr);
        }

        LogicalAggregate<? extends Plan> logicalAgg = (LogicalAggregate<? extends Plan>) prunedOutputAgg;
        return logicalAgg.withGroupByAndOutput(groupByBuilder.build(), completedOutput);
    }

    /**
     * Prune the output of a plan down to the expressions whose slot is required by the parent.
     *
     * <p>When none of the outputs is required, exactly one column is still kept: the candidates are the
     * outputs which are also collected key slots (or all outputs if there is no such output), and
     * {@link ExpressionUtils#selectMinimumColumn} chooses the one to keep.
     *
     * @param plan the plan whose output should be pruned
     * @param originOutput the current output expressions of the plan
     * @param withPrunedOutput creates the copy of the plan which only produces the given outputs
     * @param context holds the slots required by the parent
     * @param <P> the type of the plan
     * @return the plan itself if its output is empty or nothing can be pruned, otherwise the pruned copy
     */
    public <P extends Plan> P pruneOutput(P plan, List<NamedExpression> originOutput,
            Function<List<NamedExpression>, P> withPrunedOutput, PruneContext context) {
        if (originOutput.isEmpty()) {
            return plan;
        }
        List<NamedExpression> prunedOutputs =
                Utils.filterImmutableList(originOutput, output -> context.requiredSlots.contains(output.toSlot()));

        if (prunedOutputs.isEmpty()) {
            List<NamedExpression> candidates = Lists.newArrayList(originOutput);
            // keys is only assigned by rewriteRoot; if this visitor is entered another way (e.g. through
            // accept() directly) it is still null and retainAll(null) would throw a NullPointerException,
            // so fall back to considering every output as a candidate
            if (keys != null) {
                candidates.retainAll(keys);
            }
            if (candidates.isEmpty()) {
                candidates = originOutput;
            }
            NamedExpression minimumColumn = ExpressionUtils.selectMinimumColumn(candidates);
            prunedOutputs = ImmutableList.of(minimumColumn);
        }

        if (prunedOutputs.equals(originOutput)) {
            return plan;
        } else {
            return withPrunedOutput.apply(prunedOutputs);
        }
    }

    /**
     * Prunes the output columns of a union together with the matching columns of its constant rows.
     * If no output is required at all, the union is reduced to a single TINYINT column (reusing the
     * expr id, name and qualifier of the first original output): each child is wrapped in a project
     * producing the literal 1, and each constant row becomes a single literal 1.
     * The children themselves are not pruned here.
     *
     * @return the same union if its output is empty, or if nothing was pruned and the required slots are
     *         not empty; otherwise a new union with the pruned outputs
     */
    private LogicalUnion pruneUnionOutput(LogicalUnion union, PruneContext context) {
        List<NamedExpression> originOutput = union.getOutputs();
        if (originOutput.isEmpty()) {
            return union;
        }
        List<List<NamedExpression>> constantRows = union.getConstantExprsList();
        List<List<SlotReference>> childrenOutputs = union.getRegularChildrenOutputs();
        List<Plan> newChildren = union.children();

        // collect the required outputs together with their positions in the original output
        List<NamedExpression> keptOutputs = Lists.newArrayList();
        List<Integer> keptIndexes = Lists.newArrayList();
        int position = 0;
        for (NamedExpression output : originOutput) {
            if (context.requiredSlots.contains(output.toSlot())) {
                keptOutputs.add(output);
                keptIndexes.add(position);
            }
            position++;
        }

        ImmutableList.Builder<List<NamedExpression>> newConstantRows
                = ImmutableList.builderWithExpectedSize(constantRows.size());
        if (!keptOutputs.isEmpty()) {
            // keep only the columns at the required positions in each constant row
            for (List<NamedExpression> row : constantRows) {
                ImmutableList.Builder<NamedExpression> newRow
                        = ImmutableList.builderWithExpectedSize(keptIndexes.size());
                keptIndexes.forEach(index -> newRow.add(row.get(index)));
                newConstantRows.add(newRow.build());
            }
        } else {
            // process prune all columns
            NamedExpression firstOutput = originOutput.get(0);
            keptOutputs = ImmutableList.of(new SlotReference(firstOutput.getExprId(), firstOutput.getName(),
                    TinyIntType.INSTANCE, false, firstOutput.getQualifier()));
            childrenOutputs = Lists.newArrayListWithCapacity(childrenOutputs.size());
            newChildren = Lists.newArrayListWithCapacity(newChildren.size());
            for (int i = 0; i < union.getArity(); i++) {
                LogicalProject<?> literalProject = new LogicalProject<>(
                        ImmutableList.of(new Alias(new TinyIntLiteral((byte) 1))), union.child(i));
                childrenOutputs.add((List) literalProject.getOutput());
                newChildren.add(literalProject);
            }
            for (List<NamedExpression> ignored : constantRows) {
                newConstantRows.add(ImmutableList.of(new Alias(new TinyIntLiteral((byte) 1))));
            }
        }

        boolean nothingPruned = keptOutputs.equals(originOutput) && !context.requiredSlots.isEmpty();
        if (nothingPruned) {
            return union;
        }
        return union.withNewOutputsChildrenAndConstExprsList(keptOutputs, newChildren,
                childrenOutputs, newConstantRows.build());
    }

    // Prunes the children with an empty set of parent-required slots,
    // so the children only have to provide the slots which the plan itself uses.
    private <P extends Plan> P pruneChildren(P plan) {
        return pruneChildren(plan, ImmutableSet.of());
    }

    /**
     * Prunes every child of the plan.
     * The slots required from the children are the slots required by the parent plus the input slots
     * of the plan itself; each child is asked for the part of them which it actually outputs.
     *
     * @param plan the plan whose children should be pruned
     * @param parentRequiredSlots the slots the parent requires from the plan, may be empty
     * @return the plan itself if it is a leaf or no child changed, otherwise a copy with the new children
     */
    private <P extends Plan> P pruneChildren(P plan, Set<Slot> parentRequiredSlots) {
        if (plan.arity() == 0) {
            // leaf
            return plan;
        }

        Set<Slot> usedByPlan = plan.getInputSlots();
        Set<Slot> requiredFromChildren;
        if (parentRequiredSlots.isEmpty()) {
            requiredFromChildren = usedByPlan;
        } else {
            requiredFromChildren = ImmutableSet
                    .<Slot>builderWithExpectedSize(parentRequiredSlots.size() + usedByPlan.size())
                    .addAll(parentRequiredSlots)
                    .addAll(usedByPlan)
                    .build();
        }

        boolean anyChildChanged = false;
        ImmutableList.Builder<Plan> rewrittenChildren = ImmutableList.builderWithExpectedSize(plan.arity());
        for (Plan child : plan.children()) {
            // a child can only be asked for the slots it outputs; keep the child's output order
            List<Slot> outputsOfChild = child.getOutput();
            ImmutableSet.Builder<Slot> requiredOfChild
                    = ImmutableSet.builderWithExpectedSize(outputsOfChild.size());
            outputsOfChild.stream()
                    .filter(slot -> requiredFromChildren.contains(slot))
                    .forEach(slot -> requiredOfChild.add(slot));

            Plan rewrittenChild = doPruneChild(plan, child, requiredOfChild.build());
            anyChildChanged |= rewrittenChild != child;
            rewrittenChildren.add(rewrittenChild);
        }
        if (!anyChildChanged) {
            return plan;
        }
        return (P) plan.withChildren(rewrittenChildren.build());
    }

    /**
     * Prunes one child of the plan with the given required slots.
     * A LogicalCTEProducer child is returned untouched. If the pruned child still outputs slots which
     * are not required and the plan is not a LogicalProject, a project on the required slots is added
     * on top of the child.
     */
    private Plan doPruneChild(Plan plan, Plan child, Set<Slot> childRequiredSlots) {
        if (child instanceof LogicalCTEProducer) {
            return child;
        }
        Plan rewrittenChild = child.accept(this, new PruneContext(childRequiredSlots, plan));
        if (plan instanceof LogicalProject) {
            // a project already restricts the output, no extra project is needed below it
            return rewrittenChild;
        }

        // the case 2 in the class comment, prune child's output failed
        boolean outputsUnrequiredSlots
                = !Sets.difference(rewrittenChild.getOutputSet(), childRequiredSlots).isEmpty();
        if (outputsUnrequiredSlots) {
            return new LogicalProject<>(Utils.fastToImmutableList(childRequiredSlots), rewrittenChild);
        }
        return rewrittenChild;
    }

    /**
     * PruneContext: the context passed down while pruning, holding the slots which the parent requires
     * from the visited plan and, optionally, the parent plan itself.
     */
    public static class PruneContext {
        // the slots which the parent requires from the plan being visited
        public Set<Slot> requiredSlots;
        // the plan which issued the requirement; empty when the constructor received null
        // (rewriteRoot passes null for the root plan)
        public Optional<Plan> parent;

        public PruneContext(Set<Slot> requiredSlots, Plan parent) {
            this.requiredSlots = requiredSlots;
            this.parent = Optional.ofNullable(parent);
        }
    }
}