RewriteCteChildren.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.annotation.DependsRules;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.executor.Rewriter;
import org.apache.doris.nereids.jobs.rewrite.RewriteJob;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.HashSet;
import java.util.List;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * rewrite CteAnchor consumer side and producer side recursively, all CteAnchor must at top of the plan
 */
@DependsRules({PullUpCteAnchor.class, CTEInline.class})
public class RewriteCteChildren extends DefaultPlanRewriter<CascadesContext> implements CustomRewriter {

    private final List<RewriteJob> jobs;

    public RewriteCteChildren(List<RewriteJob> jobs) {
        this.jobs = jobs;
    }

    @Override
    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
        return plan.accept(this, jobContext.getCascadesContext());
    }

    @Override
    public Plan visit(Plan plan, CascadesContext context) {
        Rewriter.getCteChildrenRewriter(context, jobs).execute();
        return context.getRewritePlan();
    }

    @Override
    public Plan visitLogicalCTEAnchor(LogicalCTEAnchor<? extends Plan, ? extends Plan> cteAnchor,
            CascadesContext cascadesContext) {
        LogicalPlan outer;
        if (cascadesContext.getStatementContext().getRewrittenCteConsumer().containsKey(cteAnchor.getCteId())) {
            outer = cascadesContext.getStatementContext().getRewrittenCteConsumer().get(cteAnchor.getCteId());
        } else {
            CascadesContext outerCascadesCtx = CascadesContext.newSubtreeContext(
                    Optional.empty(), cascadesContext, cteAnchor.child(1),
                    cascadesContext.getCurrentJobContext().getRequiredProperties());
            outer = (LogicalPlan) cteAnchor.child(1).accept(this, outerCascadesCtx);
            cascadesContext.getStatementContext().getRewrittenCteConsumer().put(cteAnchor.getCteId(), outer);
        }
        Set<LogicalCTEConsumer> cteConsumers = Sets.newHashSet();
        outer.foreach(p -> {
            if (p instanceof LogicalCTEConsumer) {
                LogicalCTEConsumer logicalCTEConsumer = (LogicalCTEConsumer) p;
                if (logicalCTEConsumer.getCteId().equals(cteAnchor.getCteId())) {
                    cteConsumers.add(logicalCTEConsumer);
                }
            }
            return false;
        });
        cascadesContext.getCteIdToConsumers().put(cteAnchor.getCteId(), cteConsumers);
        if (cteConsumers.isEmpty()) {
            return outer;
        }
        Plan producer = cteAnchor.child(0).accept(this, cascadesContext);
        return cteAnchor.withChildren(producer, outer);
    }

    @Override
    public Plan visitLogicalCTEProducer(LogicalCTEProducer<? extends Plan> cteProducer,
            CascadesContext cascadesContext) {
        LogicalPlan child;
        if (cascadesContext.getStatementContext().getRewrittenCteProducer().containsKey(cteProducer.getCteId())) {
            child = cascadesContext.getStatementContext().getRewrittenCteProducer().get(cteProducer.getCteId());
        } else {
            child = (LogicalPlan) cteProducer.child();
            child = tryToConstructFilter(cascadesContext, cteProducer.getCteId(), child);
            Set<Slot> producerOutputs = cascadesContext.getStatementContext()
                    .getCteIdToOutputIds().get(cteProducer.getCteId());
            if (producerOutputs.size() < child.getOutput().size()) {
                ImmutableList.Builder<NamedExpression> projectsBuilder
                        = ImmutableList.builderWithExpectedSize(producerOutputs.size());
                for (Slot slot : child.getOutput()) {
                    if (producerOutputs.contains(slot)) {
                        projectsBuilder.add(slot);
                    }
                }
                child = new LogicalProject<>(projectsBuilder.build(), child);
                child = pushPlanUnderAnchor(child);
            }
            CascadesContext rewrittenCtx = CascadesContext.newSubtreeContext(
                    Optional.of(cteProducer.getCteId()), cascadesContext, child, PhysicalProperties.ANY);
            child = (LogicalPlan) child.accept(this, rewrittenCtx);
            cascadesContext.getStatementContext().getRewrittenCteProducer().put(cteProducer.getCteId(), child);
        }
        return cteProducer.withChildren(child);
    }

    private LogicalPlan pushPlanUnderAnchor(LogicalPlan plan) {
        if (plan.child(0) instanceof LogicalCTEAnchor) {
            LogicalPlan child = (LogicalPlan) plan.withChildren(plan.child(0).child(1));
            return (LogicalPlan) plan.child(0).withChildren(
                    plan.child(0).child(0), pushPlanUnderAnchor(child));
        }
        return plan;
    }

    /*
     * An expression can only be pushed down if it has filter expressions on all consumers that reference the slot.
     * For example, let's assume a producer has two consumers, consumer1 and consumer2:
     *
     * filter(a > 5 and b < 1) -> consumer1
     * filter(a < 8) -> consumer2
     *
     * In this case, the only expression that can be pushed down to the producer is filter(a > 5 or a < 8).
     */
    private LogicalPlan tryToConstructFilter(CascadesContext cascadesContext, CTEId cteId, LogicalPlan child) {
        Set<RelationId> consumerIds = cascadesContext.getCteIdToConsumers().get(cteId).stream()
                .map(LogicalCTEConsumer::getRelationId)
                .collect(Collectors.toSet());
        List<Set<Expression>> filtersAboveEachConsumer = cascadesContext.getConsumerIdToFilters().entrySet().stream()
                .filter(kv -> consumerIds.contains(kv.getKey()))
                .map(Entry::getValue)
                .collect(Collectors.toList());
        Set<Expression> someone = filtersAboveEachConsumer.stream().findFirst().orElse(null);
        if (someone == null) {
            return child;
        }
        int filterSize = cascadesContext.getCteIdToConsumers().get(cteId).size();
        Set<Expression> conjuncts = new HashSet<>();
        for (Expression f : someone) {
            int matchCount = 0;
            Set<SlotReference> slots = f.collect(e -> e instanceof SlotReference);
            Set<Expression> mightBeJoined = new HashSet<>();
            for (Set<Expression> another : filtersAboveEachConsumer) {
                if (another.equals(someone)) {
                    matchCount++;
                    continue;
                }
                Set<Expression> matched = new HashSet<>();
                for (Expression e : another) {
                    Set<SlotReference> otherSlots = e.collect(ae -> ae instanceof SlotReference);
                    if (otherSlots.equals(slots)) {
                        matched.add(e);
                    }
                }
                if (!matched.isEmpty()) {
                    matchCount++;
                }
                mightBeJoined.addAll(matched);
            }
            if (matchCount >= filterSize) {
                mightBeJoined.add(f);
                conjuncts.add(ExpressionUtils.or(mightBeJoined));
            }
        }
        if (!conjuncts.isEmpty()) {
            LogicalPlan filter = new LogicalFilter<>(ImmutableSet.of(ExpressionUtils.and(conjuncts)), child);
            return pushPlanUnderAnchor(filter);
        }
        return child;
    }
}