PushDownVirtualColumnsIntoOlapScan.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.catalog.KeysType;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsBigInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsLargeInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsSmallInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.IsIpAddressInRange;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MultiMatch;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MultiMatchAny;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * Extract virtual columns from a filter (and from its parent project, if any) and push them down into the olap scan.
 * This rule extracts:
 * 1. Common sub-expressions repeated across the filter conjuncts and project expressions, to eliminate redundant
 *    computation
 *
 * Example transformation:
 * Before:
 * Project[a, b, c]
 * └── Filter[func(x, y) > 10 AND func(x, y) < 100 AND func(z, w) = func(x, y)]
 *     └── OlapScan[table]
 *
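 * After:
 * Project[a, b, c]
 * └── Filter[v_func_1 > 10 AND v_func_1 < 100 AND func(z, w) = v_func_1]
 *     └── OlapScan[table, virtual_columns=[func(x, y) as v_func_1]]
 *
 * func(z, w) occurs only once, so it is left in place: an expression must occur at least
 * MIN_OCCURRENCE_COUNT times to become a virtual column.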
 *
 * Benefits:
 * - Eliminates redundant computation of repeated expressions
 * - Can leverage vectorization and SIMD optimizations at scan level
 * - Reduces CPU usage in upper operators
 *
 * BLACKLIST STRATEGY:
 * To avoid reverse optimization (preventing more important optimizations), this rule implements
 * a blacklist strategy that skips certain types of expressions:
 *
 * 1. Index Pushdown Functions: Functions like is_ip_address_in_range(), multi_match(), match_*
 *    can be pushed down to storage engine as index operations. Virtual column optimization would
 *    prevent this index pushdown optimization.
 *
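 * 2. ColumnPredicate Expressions: Comparison predicates (>, <, =, IN, IS NULL) can be converted
 *    to ColumnPredicate objects for efficient filtering in BE. Virtual columns would lose this
 *    optimization opportunity. Only the predicate node itself is excluded; its operands (e.g.
 *    func(x, y) in func(x, y) > 10) are still candidates for extraction.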
 *
 * 3. CAST Expressions: CAST operations are lightweight and creating virtual columns for them
 *    may not provide significant benefit while adding complexity.
 *
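 * 4. Lambda-containing Expressions: Expressions with lambda functions have complex evaluation
 *    contexts that make virtual column optimization problematic.
 *
 * 5. Encoding/Decoding Functions: encode_as_*() (smallint/int/bigint/largeint variants) and
 *    decode_as_varchar() are never extracted themselves, although their arguments may be.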
 */
public class PushDownVirtualColumnsIntoOlapScan implements RewriteRuleFactory {

    private static final Logger LOG = LogManager.getLogger(PushDownVirtualColumnsIntoOlapScan.class);

    // Configuration constants for sub-expression extraction
    private static final int MIN_OCCURRENCE_COUNT = 2; // Minimum times an expression must appear to be considered
    private static final int MIN_EXPRESSION_DEPTH = 2; // Minimum depth of expression tree to be beneficial
    private static final int MAX_VIRTUAL_COLUMNS = 5; // Maximum number of virtual columns to avoid explosion

    /**
     * Builds two rules sharing the same rewrite: one for Project -> Filter -> OlapScan and one for a
     * bare Filter -> OlapScan. Both only fire on scans accepted by {@link #isSupportedScan}.
     */
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                logicalProject(logicalFilter(logicalOlapScan()
                        .when(PushDownVirtualColumnsIntoOlapScan::isSupportedScan)))
                        .then(project -> {
                            LogicalFilter<LogicalOlapScan> filter = project.child();
                            LogicalOlapScan scan = filter.child();
                            return pushDown(filter, scan, Optional.of(project));
                        }).toRule(RuleType.PUSH_DOWN_VIRTUAL_COLUMNS_INTO_OLAP_SCAN),
                logicalFilter(logicalOlapScan()
                        .when(PushDownVirtualColumnsIntoOlapScan::isSupportedScan))
                        .then(filter -> {
                            LogicalOlapScan scan = filter.child();
                            return pushDown(filter, scan, Optional.empty());
                        }).toRule(RuleType.PUSH_DOWN_VIRTUAL_COLUMNS_INTO_OLAP_SCAN)
        );
    }

    /**
     * Returns whether virtual columns may be pushed into {@code scan}.
     *
     * <p>The table must be a duplicate-key table or have unique-key merge-on-write enabled. The scan
     * must not already carry virtual columns; since {@link #pushDown} sets them, this also keeps the
     * rule from firing again on a scan it has already rewritten.
     */
    private static boolean isSupportedScan(LogicalOlapScan scan) {
        boolean dupTblOrMOW = scan.getTable().getKeysType() == KeysType.DUP_KEYS
                || scan.getTable().getTableProperty().getEnableUniqueKeyMergeOnWrite();
        return dupTblOrMOW && scan.getVirtualColumns().isEmpty();
    }

    /**
     * Extracts repeated sub-expressions of the filter (and project, if present) into virtual columns on
     * the scan, then rewrites the filter and project to reference the virtual-column slots instead.
     *
     * @param filter the filter directly above the scan
     * @param logicalOlapScan the scan that receives the virtual columns
     * @param optionalProject the project above the filter, if the matched pattern had one
     * @return the rewritten plan, or {@code null} when no sub-expression qualified
     *         (presumably interpreted by the rule framework as "no rewrite" — confirm)
     */
    private Plan pushDown(LogicalFilter<LogicalOlapScan> filter, LogicalOlapScan logicalOlapScan,
            Optional<LogicalProject<?>> optionalProject) {
        Map<Expression, Expression> replaceMap = Maps.newHashMap();
        ImmutableList.Builder<NamedExpression> virtualColumnsBuilder = ImmutableList.builder();

        // 1. Extract repeated sub-expressions (fills replaceMap and virtualColumnsBuilder).
        extractRepeatedSubExpressions(filter, optionalProject, replaceMap, virtualColumnsBuilder);

        if (replaceMap.isEmpty()) {
            return null;
        }

        if (LOG.isDebugEnabled()) {
            LOG.debug("PushDownVirtualColumnsIntoOlapScan: Created {} virtual columns for expressions: {}",
                    replaceMap.size(), replaceMap.keySet());
        }

        // 2. Attach the virtual columns to the scan.
        LogicalOlapScan scanWithVirtualColumns = logicalOlapScan.withVirtualColumns(virtualColumnsBuilder.build());

        // 3. Rewrite filter (and project) to use the virtual-column slots.
        Set<Expression> conjuncts = ExpressionUtils.replace(filter.getConjuncts(), replaceMap);
        Plan plan = filter.withConjunctsAndChild(conjuncts, scanWithVirtualColumns);

        if (optionalProject.isPresent()) {
            LogicalProject<?> project = optionalProject.get();
            List<NamedExpression> projections = ExpressionUtils.replace(
                    (List) project.getProjects(), replaceMap);
            return project.withProjectsAndChild(projections, plan);
        }
        // No project above: re-project the filter's original output, presumably so the new
        // virtual-column slots are not exposed to the parent — confirm.
        return new LogicalProject<>((List) filter.getOutput(), plan);
    }

    /**
     * Extracts repeated sub-expressions from filter conjuncts and project expressions.
     *
     * <p>Every sub-expression occurring at least {@link #MIN_OCCURRENCE_COUNT} times is ranked by
     * {@code (occurrences - 1) * complexity}. At most {@link #MAX_VIRTUAL_COLUMNS} of them (minus
     * entries already in {@code replaceMap}) are turned into aliases.
     *
     * <p>NOTE(review): ties in benefit are broken by HashMap iteration order, so the choice among
     * equally-ranked candidates is not guaranteed to be stable.
     *
     * <p>NOTE(review): an expression and one of its own sub-expressions can both be selected
     * (e.g. {@code (a + b) * c} and {@code a + b}). Whether the inner virtual column is still
     * referenced after {@code ExpressionUtils.replace} depends on that helper's traversal order — confirm.
     *
     * <p>NOTE(review): nothing here excludes nondeterministic functions; if such a function
     * call is repeated, merging its occurrences into one virtual column changes semantics — confirm.
     *
     * @param filter the filter whose conjuncts are scanned
     * @param optionalProject the optional project whose expressions are scanned as well
     * @param replaceMap output: original expression -> slot of its virtual column
     * @param virtualColumnsBuilder output: the aliases to add to the scan as virtual columns
     */
    private void extractRepeatedSubExpressions(LogicalFilter<LogicalOlapScan> filter,
            Optional<LogicalProject<?>> optionalProject,
            Map<Expression, Expression> replaceMap,
            ImmutableList.Builder<NamedExpression> virtualColumnsBuilder) {

        // Collect all root expressions from filter and project.
        Set<Expression> allExpressions = new HashSet<>(filter.getConjuncts());
        optionalProject.ifPresent(project -> allExpressions.addAll(project.getProjects()));

        // Count occurrences of each sub-expression.
        Map<Expression, Integer> expressionCounts = new HashMap<>();
        for (Expression expr : allExpressions) {
            collectSubExpressions(expr, expressionCounts);
        }

        int remainingSlots = Math.max(0, MAX_VIRTUAL_COLUMNS - replaceMap.size());
        expressionCounts.entrySet().stream()
                .filter(entry -> entry.getValue() >= MIN_OCCURRENCE_COUNT)
                .filter(entry -> !replaceMap.containsKey(entry.getKey()))
                .sorted((e1, e2) -> {
                    // Sort by benefit: (occurrence_count - 1) * expression_complexity, descending.
                    int benefit1 = (e1.getValue() - 1) * getExpressionComplexity(e1.getKey());
                    int benefit2 = (e2.getValue() - 1) * getExpressionComplexity(e2.getKey());
                    return Integer.compare(benefit2, benefit1);
                })
                .limit(remainingSlots)
                .forEach(entry -> {
                    Expression expr = entry.getKey();
                    Alias alias = new Alias(expr);
                    replaceMap.put(expr, alias.toSlot());
                    virtualColumnsBuilder.add(alias);
                });

        if (LOG.isDebugEnabled()) {
            LOG.debug("Extracted virtual columns: {}", virtualColumnsBuilder.build());
        }
    }

    /**
     * Recursively collects all sub-expressions of {@code expr} and counts their occurrences,
     * starting outside of any lambda.
     */
    private void collectSubExpressions(Expression expr, Map<Expression, Integer> expressionCounts) {
        collectSubExpressions(expr, expressionCounts, false);
    }

    /**
     * Recursively collects all sub-expressions and counts their occurrences.
     *
     * @param expr the expression to analyze
     * @param expressionCounts map to store expression occurrence counts
     * @param insideLambda whether we are currently inside a lambda function
     */
    private void collectSubExpressions(Expression expr, Map<Expression, Integer> expressionCounts,
                                    boolean insideLambda) {
        SkipResult skipResult = shouldSkipExpression(expr, insideLambda);

        if (skipResult.shouldTerminate()) {
            // Examples: x (slot), 10 (constant), expressions inside lambda functions.
            // Completely skipped: no counting, no recursion.
            return;
        }

        // SKIP_COUNTING (e.g. CAST(x AS VARCHAR)) and SKIP_NOT_BENEFICIAL (encode/decode functions,
        // ColumnPredicate-convertible predicates, index pushdown functions, lambda-containing
        // expressions) are not counted, but their children are still processed below.
        // CONTINUE expressions (x + y, func(a, b), ...) are counted if complex enough.
        boolean countable = !skipResult.shouldSkipCounting() && !skipResult.isNotBeneficial();
        if (countable && expr.getDepth() >= MIN_EXPRESSION_DEPTH && !expr.children().isEmpty()) {
            expressionCounts.merge(expr, 1, Integer::sum);
        }

        // A Lambda node is always classified SKIP_NOT_BENEFICIAL (it contains a lambda), so the
        // "inside lambda" flag must be propagated on every path, not only for CONTINUE. Otherwise the
        // lambda body would be counted and could be extracted into a scan-level virtual column.
        boolean childInsideLambda = insideLambda || expr instanceof Lambda;
        for (Expression child : expr.children()) {
            collectSubExpressions(child, expressionCounts, childInsideLambda);
        }
    }

    /**
     * Determines how to handle an expression during sub-expression collection.
     * All skip logic is consolidated here.
     *
     * @param expr the expression to check
     * @param insideLambda whether we are currently inside a lambda function
     * @return SkipResult indicating how to handle this expression
     */
    private SkipResult shouldSkipExpression(Expression expr, boolean insideLambda) {
        // Simple slots and literals don't benefit from being pushed down.
        if (expr instanceof Slot || expr.isConstant()) {
            return SkipResult.TERMINATE;
        }

        // Expressions inside lambda functions must not be optimized.
        if (insideLambda) {
            return SkipResult.TERMINATE;
        }

        // CAST itself is not extracted, but its children may be.
        if (expr instanceof Cast) {
            return SkipResult.SKIP_COUNTING;
        }

        // Encoding/decoding functions as root are not extracted.
        if (expr instanceof DecodeAsVarchar || expr instanceof EncodeAsBigInt || expr instanceof EncodeAsInt
                || expr instanceof EncodeAsLargeInt || expr instanceof EncodeAsSmallInt) {
            return SkipResult.SKIP_NOT_BENEFICIAL;
        }

        // Expressions that contain lambda functions anywhere in the tree are not extracted.
        if (containsLambdaFunction(expr)) {
            return SkipResult.SKIP_NOT_BENEFICIAL;
        }

        // Key blacklist logic to avoid reverse optimization: keep expressions that can become a
        // ColumnPredicate or use an index in their original form.
        if (canConvertToColumnPredicate(expr) || containsIndexPushdownFunction(expr)) {
            return SkipResult.SKIP_NOT_BENEFICIAL;
        }

        return SkipResult.CONTINUE;
    }

    /**
     * Checks whether {@code expr} or any of its descendants is a lambda function.
     */
    private boolean containsLambdaFunction(Expression expr) {
        return expr.anyMatch(node -> node instanceof Lambda);
    }

    /**
     * Result type for expression skip decisions.
     */
    private enum SkipResult {
        // Process normally (count and recurse)
        // Examples: x + y, func(a, b), (x + y) * z - beneficial arithmetic/function expressions
        CONTINUE,

        // Skip counting but continue processing children (for CAST expressions)
        // Examples: CAST(x AS VARCHAR), CAST(date_col AS STRING)
        // We don't optimize CAST itself but may optimize its children
        SKIP_COUNTING,

        // Skip counting but continue processing children (expressions not beneficial for optimization)
        // Examples:
        //   - encode_as_bigint(x), decode_as_varchar(x) - encoding/decoding functions
        //   - x > 10, x IN (1,2,3) - ColumnPredicate convertible expressions
        //   - is_ip_address_in_range(ip, '192.168.1.0/24') - index pushdown functions
        //   - expressions containing lambda functions
        SKIP_NOT_BENEFICIAL,

        // Stop processing entirely (don't count, don't recurse)
        // Examples: x (slot), 10 (constant), expressions inside lambda functions
        TERMINATE;

        public boolean shouldTerminate() {
            return this == TERMINATE;
        }

        public boolean shouldSkipCounting() {
            return this == SKIP_COUNTING;
        }

        public boolean isNotBeneficial() {
            return this == SKIP_NOT_BENEFICIAL;
        }
    }

    /**
     * Calculates the complexity/cost of an expression for cost-benefit analysis,
     * using depth * width as a simple metric.
     */
    private int getExpressionComplexity(Expression expr) {
        // More sophisticated metrics could consider function call costs, etc.
        return expr.getDepth() * expr.getWidth();
    }

    /**
     * Checks whether an expression is a comparison, IN or IS NULL predicate, i.e. a shape that can
     * be converted to a ColumnPredicate. LIKE/MATCH-style predicates are handled by
     * {@link #containsIndexPushdownFunction}.
     */
    private boolean canConvertToColumnPredicate(Expression expr) {
        return expr instanceof ComparisonPredicate
                || expr instanceof InPredicate
                || expr instanceof IsNull;
    }

    /**
     * Checks whether {@code expr} or any of its descendants can be pushed down to an index.
     */
    private boolean containsIndexPushdownFunction(Expression expr) {
        return expr.anyMatch(node -> isIndexPushdownFunction((Expression) node));
    }

    /**
     * Checks whether a single expression node is an index pushdown function:
     * the IP-range function, multi-match functions, or a MATCH predicate.
     */
    private boolean isIndexPushdownFunction(Expression expr) {
        return expr instanceof IsIpAddressInRange
                || expr instanceof MultiMatch
                || expr instanceof MultiMatchAny
                || expr instanceof Match;
    }
}