// PushDownMatchPredicateAsVirtualColumn.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Push down MATCH expressions from join/filter predicates as virtual columns on OlapScan.
 *
 * When MATCH appears in a predicate that cannot be pushed below a join (e.g., OR with
 * join-dependent conditions like EXISTS mark or outer join null checks), this rule:
 * 1. Extracts the MATCH expression from the predicate
 * 2. Traces the alias slot back through the Project to find the original column expression
 * 3. Creates a virtual column on the OlapScan with the MATCH on the original expression
 * 4. Replaces the MATCH in the predicate with the virtual column's boolean slot
 *
 * Handles both left-side and right-side Project→OlapScan in joins.
 */
public class PushDownMatchPredicateAsVirtualColumn implements RewriteRuleFactory {

    /**
     * Tells whether MATCH push-down may target the given scan.
     *
     * <p>The decision is delegated entirely to
     * {@code PushDownMatchProjectionAsVirtualColumn.canPushDownMatch(scan)}.
     *
     * @param scan the OlapScan at the bottom of the matched pattern
     * @return whatever the delegate reports for {@code scan}
     */
    private boolean canPushDown(LogicalOlapScan scan) {
        boolean eligible = PushDownMatchProjectionAsVirtualColumn.canPushDownMatch(scan);
        return eligible;
    }

    /**
     * Builds the eight rules of this factory, all registered under
     * {@code RuleType.PUSH_DOWN_MATCH_PREDICATE_AS_VIRTUAL_COLUMN}.
     *
     * <p>The rules are the cross product of:
     * <ul>
     *   <li>where the MATCH lives: the conjuncts of a Filter above the Join (shapes 1 and 2),
     *       or the Join's own other-join-conjuncts (shapes 3 and 4);</li>
     *   <li>which join child is rewritten: left (L) or right (R);</li>
     *   <li>whether a Filter sits between the Project and the OlapScan (shapes 2 and 4) or not
     *       (shapes 1 and 3).</li>
     * </ul>
     *
     * @return the rules, in the order 1L, 1R, 2L, 2R, 3L, 3R, 4L, 4R
     */
    @Override
    public List<Rule> buildRules() {
        ImmutableList.Builder<Rule> rules = ImmutableList.builder();

        // Shape 1L: Filter -> Join -> left(Project -> OlapScan)
        rules.add(logicalFilter(logicalJoin(
                logicalProject(logicalOlapScan().when(this::canPushDown)), any()))
                .when(filter -> hasMatchInSet(filter.getConjuncts()))
                .then(filter -> handleFilterSide(filter, true, false))
                .toRule(RuleType.PUSH_DOWN_MATCH_PREDICATE_AS_VIRTUAL_COLUMN));

        // Shape 1R: Filter -> Join -> right(Project -> OlapScan)
        rules.add(logicalFilter(logicalJoin(
                any(), logicalProject(logicalOlapScan().when(this::canPushDown))))
                .when(filter -> hasMatchInSet(filter.getConjuncts()))
                .then(filter -> handleFilterSide(filter, false, false))
                .toRule(RuleType.PUSH_DOWN_MATCH_PREDICATE_AS_VIRTUAL_COLUMN));

        // Shape 2L: Filter -> Join -> left(Project -> Filter -> OlapScan)
        rules.add(logicalFilter(logicalJoin(
                logicalProject(logicalFilter(logicalOlapScan().when(this::canPushDown))), any()))
                .when(filter -> hasMatchInSet(filter.getConjuncts()))
                .then(filter -> handleFilterSide(filter, true, true))
                .toRule(RuleType.PUSH_DOWN_MATCH_PREDICATE_AS_VIRTUAL_COLUMN));

        // Shape 2R: Filter -> Join -> right(Project -> Filter -> OlapScan)
        rules.add(logicalFilter(logicalJoin(
                any(), logicalProject(logicalFilter(logicalOlapScan().when(this::canPushDown)))))
                .when(filter -> hasMatchInSet(filter.getConjuncts()))
                .then(filter -> handleFilterSide(filter, false, true))
                .toRule(RuleType.PUSH_DOWN_MATCH_PREDICATE_AS_VIRTUAL_COLUMN));

        // Shape 3L: Join(otherPredicates) -> left(Project -> OlapScan)
        rules.add(logicalJoin(
                logicalProject(logicalOlapScan().when(this::canPushDown)), any())
                .when(join -> hasMatchInList(join.getOtherJoinConjuncts()))
                .then(join -> handleJoinSide(join, true, false))
                .toRule(RuleType.PUSH_DOWN_MATCH_PREDICATE_AS_VIRTUAL_COLUMN));

        // Shape 3R: Join(otherPredicates) -> right(Project -> OlapScan)
        rules.add(logicalJoin(
                any(), logicalProject(logicalOlapScan().when(this::canPushDown)))
                .when(join -> hasMatchInList(join.getOtherJoinConjuncts()))
                .then(join -> handleJoinSide(join, false, false))
                .toRule(RuleType.PUSH_DOWN_MATCH_PREDICATE_AS_VIRTUAL_COLUMN));

        // Shape 4L: Join(otherPredicates) -> left(Project -> Filter -> OlapScan)
        rules.add(logicalJoin(
                logicalProject(logicalFilter(logicalOlapScan().when(this::canPushDown))), any())
                .when(join -> hasMatchInList(join.getOtherJoinConjuncts()))
                .then(join -> handleJoinSide(join, true, true))
                .toRule(RuleType.PUSH_DOWN_MATCH_PREDICATE_AS_VIRTUAL_COLUMN));

        // Shape 4R: Join(otherPredicates) -> right(Project -> Filter -> OlapScan)
        rules.add(logicalJoin(
                any(), logicalProject(logicalFilter(logicalOlapScan().when(this::canPushDown))))
                .when(join -> hasMatchInList(join.getOtherJoinConjuncts()))
                .then(join -> handleJoinSide(join, false, true))
                .toRule(RuleType.PUSH_DOWN_MATCH_PREDICATE_AS_VIRTUAL_COLUMN));

        return rules.build();
    }

    /**
     * Rewrites a {@code Filter -> Join} whose chosen join child is
     * {@code Project -> [Filter ->] OlapScan}, moving MATCH expressions found in the outer
     * filter's conjuncts onto the scan as virtual columns.
     *
     * @param filter the outer filter; its child is cast to a join
     * @param isLeft {@code true} to rewrite the join's left child, {@code false} for the right
     * @param hasInnerFilter {@code true} when a filter sits between the project and the scan
     * @return the rebuilt filter, or {@code null} when no MATCH qualified for push-down
     */
    private Plan handleFilterSide(LogicalFilter<?> filter, boolean isLeft, boolean hasInnerFilter) {
        LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) filter.child();
        Plan targetChild;
        if (isLeft) {
            targetChild = join.left();
        } else {
            targetChild = join.right();
        }
        LogicalProject<?> project = (LogicalProject<?>) targetChild;

        // Locate the scan and remember how to re-wrap a replacement scan in whatever
        // originally sat between it and the project.
        final LogicalOlapScan scan;
        final ScanRebuilder rewrap;
        if (!hasInnerFilter) {
            scan = (LogicalOlapScan) project.child();
            rewrap = replacement -> replacement;
        } else {
            LogicalFilter<?> innerFilter = (LogicalFilter<?>) project.child();
            scan = (LogicalOlapScan) innerFilter.child();
            rewrap = replacement -> innerFilter.withChildren(ImmutableList.of(replacement));
        }

        List<Expression> conjunctList = new ArrayList<>(filter.getConjuncts());
        Set<Slot> projectOutputSlots = ImmutableSet.copyOf(project.getOutput());
        PushDownResult pushed = buildVirtualColumnsFromList(conjunctList, project, scan, projectOutputSlots);
        if (pushed == null) {
            return null;
        }

        Plan rebuiltProjectChild = rewrap.rebuild(pushed.newScan);
        LogicalProject<?> rebuiltProject = (LogicalProject<?>) project.withProjectsAndChild(
                pushed.newProjections, rebuiltProjectChild);

        Plan rebuiltJoin;
        if (isLeft) {
            rebuiltJoin = join.withChildren(rebuiltProject, join.right());
        } else {
            rebuiltJoin = join.withChildren(join.left(), rebuiltProject);
        }
        return filter.withConjunctsAndChild(ImmutableSet.copyOf(pushed.newPredicateList), rebuiltJoin);
    }

    /**
     * Rewrites a Join whose chosen child is {@code Project -> [Filter ->] OlapScan}, moving
     * MATCH expressions found in the join's other-join-conjuncts onto the scan as virtual
     * columns. Hash join conjuncts are passed through unchanged.
     *
     * @param join the join to rewrite
     * @param isLeft {@code true} to rewrite the join's left child, {@code false} for the right
     * @param hasInnerFilter {@code true} when a filter sits between the project and the scan
     * @return the rebuilt join, or {@code null} when no MATCH qualified for push-down
     */
    private Plan handleJoinSide(LogicalJoin<?, ?> join, boolean isLeft, boolean hasInnerFilter) {
        Plan targetChild;
        if (isLeft) {
            targetChild = join.left();
        } else {
            targetChild = join.right();
        }
        LogicalProject<?> project = (LogicalProject<?>) targetChild;

        // Locate the scan and remember how to re-wrap a replacement scan in whatever
        // originally sat between it and the project.
        final LogicalOlapScan scan;
        final ScanRebuilder rewrap;
        if (!hasInnerFilter) {
            scan = (LogicalOlapScan) project.child();
            rewrap = replacement -> replacement;
        } else {
            LogicalFilter<?> innerFilter = (LogicalFilter<?>) project.child();
            scan = (LogicalOlapScan) innerFilter.child();
            rewrap = replacement -> innerFilter.withChildren(ImmutableList.of(replacement));
        }

        Set<Slot> projectOutputSlots = ImmutableSet.copyOf(project.getOutput());
        PushDownResult pushed = buildVirtualColumnsFromList(
                join.getOtherJoinConjuncts(), project, scan, projectOutputSlots);
        if (pushed == null) {
            return null;
        }

        Plan rebuiltProjectChild = rewrap.rebuild(pushed.newScan);
        LogicalProject<?> rebuiltProject = (LogicalProject<?>) project.withProjectsAndChild(
                pushed.newProjections, rebuiltProjectChild);

        Plan newLeft;
        Plan newRight;
        if (isLeft) {
            newLeft = rebuiltProject;
            newRight = join.right();
        } else {
            newLeft = join.left();
            newRight = rebuiltProject;
        }
        return join.withJoinConjuncts(join.getHashJoinConjuncts(),
                pushed.newPredicateList, join.getJoinReorderContext())
                .withChildren(newLeft, newRight);
    }

    /**
     * Callback that turns a replacement scan into the plan that should become the project's
     * child: either the scan itself, or the scan re-wrapped in the filter that originally sat
     * on top of it. Only ever implemented by lambdas in this class.
     */
    @FunctionalInterface
    private interface ScanRebuilder {
        /**
         * @param newScan the scan carrying the appended virtual columns
         * @return the plan to use as the project's new child
         */
        Plan rebuild(LogicalOlapScan newScan);
    }

    /**
     * Reports whether at least one conjunct in the set contains a MATCH expression anywhere
     * in its tree.
     */
    private boolean hasMatchInSet(Set<Expression> conjuncts) {
        for (Expression conjunct : conjuncts) {
            if (containsMatch(conjunct)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Reports whether at least one expression in the list contains a MATCH expression
     * anywhere in its tree.
     */
    private boolean hasMatchInList(List<Expression> exprs) {
        for (Expression candidate : exprs) {
            if (containsMatch(candidate)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Depth-first test for a MATCH node: true when {@code expr} itself is a {@link Match} or
     * any of its descendants is.
     */
    private boolean containsMatch(Expression expr) {
        return expr instanceof Match
                || expr.children().stream().anyMatch(this::containsMatch);
    }

    /**
     * Plans the push-down for a list of predicates: collects every MATCH that qualifies,
     * appends one virtual column per MATCH to the scan, exposes each virtual column through
     * the project, and rewrites the predicates to reference the virtual column slots.
     *
     * @param predicates the predicates to inspect; the list itself is not modified
     * @param project the project sitting above the scan on the side being rewritten
     * @param scan the scan that receives the virtual columns
     * @param projectOutputSlots the output slots of {@code project}
     * @return the new scan, projections and predicates, or {@code null} when no MATCH
     *         qualified and the plan should be left as is
     */
    private PushDownResult buildVirtualColumnsFromList(List<Expression> predicates,
            LogicalProject<?> project, LogicalOlapScan scan, Set<Slot> projectOutputSlots) {
        // Insertion-ordered so that virtual columns and the projections appended below follow
        // the order in which the MATCH expressions were encountered, rather than hash order.
        Map<Match, Alias> matchToVirtualColumn = new LinkedHashMap<>();
        // Lookup-only: consulted by key when rewriting predicates, never iterated.
        Map<Match, Slot> matchToVirtualSlot = new HashMap<>();

        for (Expression predicate : predicates) {
            collectMatchesNeedingPushDown(predicate, project, projectOutputSlots,
                    matchToVirtualColumn, matchToVirtualSlot);
        }

        if (matchToVirtualColumn.isEmpty()) {
            return null;
        }

        LogicalOlapScan newScan = scan.appendVirtualColumns(
                new ArrayList<>(matchToVirtualColumn.values()));

        // Keep every existing projection and append one slot per virtual column so that the
        // rewritten predicates above the project can see them.
        List<NamedExpression> newProjections = new ArrayList<>(project.getProjects());
        for (Alias vcAlias : matchToVirtualColumn.values()) {
            newProjections.add(vcAlias.toSlot());
        }

        List<Expression> newPredicateList = new ArrayList<>(predicates.size());
        for (Expression predicate : predicates) {
            newPredicateList.add(replaceMatch(predicate, matchToVirtualSlot));
        }

        return new PushDownResult(newScan, newProjections, newPredicateList);
    }

    /**
     * Walks {@code expr} and registers a virtual column for every MATCH whose left operand
     * must be traced back through {@code project}.
     *
     * <p>A MATCH is registered only when all of the following hold:
     * <ul>
     *   <li>its left operand references at least one {@link SlotReference};</li>
     *   <li>every such slot is an output of {@code project};</li>
     *   <li>at least one such slot lacks original column/table metadata;</li>
     *   <li>every slot in the left operand can be resolved to an expression below the
     *       project.</li>
     * </ul>
     * The virtual column is the same MATCH with every project-output slot in its left operand
     * replaced by the expression that slot stands for. The whole left operand is rewritten in
     * place, so any expression wrapped around the slot (or any additional slot) is preserved.
     *
     * <p>The search does not descend into the operands of a MATCH.
     *
     * @param expr the expression to inspect
     * @param project the project used to resolve slots to their source expressions
     * @param projectOutputSlots the output slots of {@code project}
     * @param matchToVirtualColumn receives original MATCH to virtual column alias
     * @param matchToVirtualSlot receives original MATCH to the slot of that alias
     */
    private void collectMatchesNeedingPushDown(Expression expr,
            LogicalProject<?> project, Set<Slot> projectOutputSlots,
            Map<Match, Alias> matchToVirtualColumn, Map<Match, Slot> matchToVirtualSlot) {
        if (!(expr instanceof Match)) {
            for (Expression child : expr.children()) {
                collectMatchesNeedingPushDown(child, project, projectOutputSlots,
                        matchToVirtualColumn, matchToVirtualSlot);
            }
            return;
        }

        Match match = (Match) expr;
        if (matchToVirtualColumn.containsKey(match)) {
            // The same MATCH appeared earlier; reuse its virtual column instead of
            // creating a second alias for an identical expression.
            return;
        }

        List<SlotReference> matchSlots = match.left().getInputSlots().stream()
                .filter(SlotReference.class::isInstance)
                .map(SlotReference.class::cast)
                .collect(Collectors.toList());
        if (matchSlots.isEmpty()) {
            return;
        }

        // All slots must come from the project side
        if (!matchSlots.stream().allMatch(projectOutputSlots::contains)) {
            return;
        }

        // If all slots have metadata, no need to push down
        boolean allHaveMetadata = matchSlots.stream()
                .allMatch(s -> s.getOriginalColumn().isPresent() && s.getOriginalTable().isPresent());
        if (allHaveMetadata) {
            return;
        }

        // Rewrite the complete left operand; give up if any slot cannot be traced.
        Expression sourceLeft = substituteProjectSlots(match.left(), project);
        if (sourceLeft == null) {
            return;
        }

        Match newMatch = (Match) match.withChildren(
                ImmutableList.of(sourceLeft, match.right()));
        Alias vcAlias = new Alias(newMatch);
        Slot vcSlot = vcAlias.toSlot();

        matchToVirtualColumn.put(match, vcAlias);
        matchToVirtualSlot.put(match, vcSlot);
    }

    /**
     * Returns {@code expr} with every {@link SlotReference} replaced by the expression that
     * {@code project} binds to it, or {@code null} when some slot cannot be resolved.
     */
    private Expression substituteProjectSlots(Expression expr, LogicalProject<?> project) {
        if (expr instanceof SlotReference) {
            return findSourceExpression((SlotReference) expr, project);
        }
        if (expr instanceof Slot) {
            // A slot kind we cannot trace through the project: refuse to push down.
            return null;
        }

        boolean changed = false;
        List<Expression> newChildren = new ArrayList<>(expr.children().size());
        for (Expression child : expr.children()) {
            Expression newChild = substituteProjectSlots(child, project);
            if (newChild == null) {
                return null;
            }
            if (newChild != child) {
                changed = true;
            }
            newChildren.add(newChild);
        }
        return changed ? expr.withChildren(newChildren) : expr;
    }

    /**
     * Resolves a project output slot to the expression the project computes for it.
     *
     * <p>Projections are scanned in order for one whose ExprId equals the slot's. An
     * {@link Alias} yields its child, a {@link SlotReference} yields itself; a projection of
     * any other kind is skipped and the scan continues.
     *
     * @param slot the slot to resolve
     * @param project the project whose projections are searched
     * @return the source expression, or {@code null} when nothing suitable is found
     */
    private Expression findSourceExpression(SlotReference slot, LogicalProject<?> project) {
        for (NamedExpression projection : project.getProjects()) {
            if (!projection.getExprId().equals(slot.getExprId())) {
                continue;
            }
            if (projection instanceof Alias) {
                return ((Alias) projection).child();
            }
            if (projection instanceof SlotReference) {
                return projection;
            }
        }
        return null;
    }

    /**
     * Returns {@code expr} with every MATCH that is a key of {@code matchToSlot} replaced by
     * the mapped slot. Subtrees without a replacement are returned as the same instance, so
     * callers can detect "nothing changed" by reference comparison.
     *
     * @param expr the expression to rewrite
     * @param matchToSlot original MATCH to replacement slot
     * @return the rewritten expression, or {@code expr} itself when nothing was replaced
     */
    private Expression replaceMatch(Expression expr, Map<Match, Slot> matchToSlot) {
        boolean isMappedMatch = expr instanceof Match && matchToSlot.containsKey(expr);
        if (isMappedMatch) {
            return matchToSlot.get(expr);
        }

        List<Expression> rewrittenChildren = new ArrayList<>();
        boolean anyRewritten = false;
        for (Expression original : expr.children()) {
            Expression rewritten = replaceMatch(original, matchToSlot);
            rewrittenChildren.add(rewritten);
            // Reference comparison on purpose: an untouched subtree comes back as-is.
            anyRewritten = anyRewritten || rewritten != original;
        }

        return anyRewritten ? expr.withChildren(rewrittenChildren) : expr;
    }

    /**
     * Immutable holder for the three pieces a successful push-down produces. The fields are
     * read directly by the handlers in the enclosing class.
     */
    private static final class PushDownResult {
        /** The scan with the virtual columns appended. */
        final LogicalOlapScan newScan;
        /** The project's projections plus one slot per virtual column. */
        final List<NamedExpression> newProjections;
        /** The input predicates with pushed-down MATCH expressions replaced by slots. */
        final List<Expression> newPredicateList;

        PushDownResult(LogicalOlapScan scan, List<NamedExpression> projections,
                List<Expression> predicates) {
            this.newScan = scan;
            this.newProjections = projections;
            this.newPredicateList = predicates;
        }
    }
}