PushDownScoreTopNIntoOlapScan.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Match;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Score;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.ScoreRangeInfo;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.thrift.TExprOpcode;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * Push down score function into olap scan node.
 * It will push down score function as a virtual column, and also push down the
 * topN info.
 *
 * Pattern:
 * logicalTopN(logicalProject(logicalFilter(logicalOlapScan)))
 *
 * Requirements:
 * 1. The TopN node has exactly one ordering expression.
 * 2. The ordering expression in TopN must be a slot reference that refers to a
 * Score function.
 * 3. The Filter node must contain at least one Match function.
 *
 * Additionally, this rule now supports score range predicates in WHERE clause:
 * - score() > X, score() >= X, score() < X, score() <= X
 * These predicates are extracted and pushed down to the scan node.
 *
 * Example:
 * Before:
 * SELECT score() as score FROM table WHERE text_col MATCH 'query' AND score() > 0.5
 * ORDER BY score DESC LIMIT 10
 *
 * After:
 * The Score function is pushed down into the OlapScan node as a virtual column,
 * and the TopN information (order by, limit) is also pushed down to be used by
 * the storage engine. The score range predicate is also extracted.
 */
public class PushDownScoreTopNIntoOlapScan implements RewriteRuleFactory {
    private static final Logger LOG = LogManager.getLogger(PushDownScoreTopNIntoOlapScan.class);

    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                logicalTopN(logicalProject(logicalFilter(logicalOlapScan())))
                        .then(topN -> {
                            LogicalProject<LogicalFilter<LogicalOlapScan>> project = topN.child();
                            LogicalFilter<LogicalOlapScan> filter = project.child();
                            LogicalOlapScan scan = filter.child();
                            return pushDown(topN, project, filter, scan);
                        }).toRule(RuleType.PUSH_DOWN_SCORE_TOPN_INTO_OLAP_SCAN));
    }

    private Plan pushDown(
            LogicalTopN<LogicalProject<LogicalFilter<LogicalOlapScan>>> topN,
            LogicalProject<LogicalFilter<LogicalOlapScan>> project,
            LogicalFilter<LogicalOlapScan> filter,
            LogicalOlapScan scan) {
        // 1. Requirement: Project must contain a score() function.
        boolean hasScoreFunction = project.getProjects().stream()
                .anyMatch(projection -> {
                    if (projection instanceof Alias) {
                        return ((Alias) projection).child() instanceof Score;
                    }
                    return false;
                });
        if (!hasScoreFunction) {
            return null;
        }

        // 2. Requirement: WHERE clause must contain a MATCH function.
        boolean hasMatchPredicate = filter.getConjuncts().stream()
                .anyMatch(conjunct -> !conjunct.collect(e -> e instanceof Match).isEmpty());
        if (!hasMatchPredicate) {
            throw new AnalysisException(
                    "WHERE clause must contain at least one MATCH function"
                            + " for score() push down optimization");
        }

        // 3. Check for score() predicates in WHERE clause and extract score range info
        List<Expression> scorePredicates = filter.getConjuncts().stream()
                .filter(conjunct -> !conjunct.collect(e -> e instanceof Score).isEmpty())
                .collect(ImmutableList.toImmutableList());

        Optional<ScoreRangeInfo> scoreRangeInfo = Optional.empty();
        Expression extractedScorePredicate = null;
        if (!scorePredicates.isEmpty()) {
            if (scorePredicates.size() > 1) {
                throw new AnalysisException(
                        "Only one score() range predicate is supported in WHERE clause. "
                                + "Found " + scorePredicates.size() + " predicates: " + scorePredicates);
            }
            Expression predicate = scorePredicates.get(0);
            scoreRangeInfo = extractScoreRangeInfo(predicate);
            if (!scoreRangeInfo.isPresent()) {
                throw new AnalysisException(
                        "score() predicate in WHERE clause must be in the form of 'score() > literal' "
                                + "or 'score() >= literal' (min_score semantics). "
                                + "Operators <, <=, and = are not supported.");
            }
            extractedScorePredicate = predicate;
        }

        // 4. Requirement: TopN must have exactly one ordering expression.
        if (topN.getOrderKeys().size() != 1) {
            throw new AnalysisException(
                    "TopN must have exactly one ordering expression for score() push down optimization");
        }

        // 5. Requirement: ORDER BY expression must refer to a SELECT score() alias.
        Expression orderKey = topN.getOrderKeys().get(0).getExpr();
        if (!(orderKey instanceof SlotReference)) {
            throw new AnalysisException(
                    "ORDER BY expression must be a slot reference (not a function call or complex expression)"
                            + " for score() push down optimization");
        }

        SlotReference orderSlot = (SlotReference) orderKey;
        Expression scoreExpr = null;
        Alias scoreAlias = null;

        // Find the 'score()' expression in the project list that matches the order by
        // slot.
        for (NamedExpression projection : project.getProjects()) {
            if (projection.toSlot().equals(orderSlot) && projection instanceof Alias) {
                Expression childExpr = ((Alias) projection).child();
                if (childExpr instanceof Score) {
                    scoreExpr = childExpr;
                    scoreAlias = (Alias) projection;
                    break;
                }
            }
        }

        if (scoreAlias == null) {
            throw new AnalysisException(
                    "ORDER BY expression must reference a score() function from SELECT clause"
                            + " for push down optimization");
        }

        // All conditions met, perform the push down.
        // This is the core action: push score() as a virtual column and also push the
        // topN info.
        Plan newScan = scan.withVirtualColumnsAndTopN(ImmutableList.of(scoreAlias),
                ImmutableList.of(), Optional.empty(),
                topN.getOrderKeys(), Optional.of(topN.getLimit() + topN.getOffset()),
                scoreRangeInfo);

        // Rebuild the plan tree above the new scan.
        // We need to replace the original score() function with a reference to the new
        // virtual column slot.
        Map<Expression, Expression> replaceMap = Maps.newHashMap();
        replaceMap.put(scoreExpr, scoreAlias.toSlot());
        replaceMap.put(scoreAlias, scoreAlias.toSlot());

        // If we extracted a score predicate, remove it from the filter
        // as it will be pushed down to the scan node
        Set<Expression> newConjuncts;
        if (extractedScorePredicate != null) {
            final Expression predicateToRemove = extractedScorePredicate;
            newConjuncts = filter.getConjuncts().stream()
                    .filter(c -> !c.equals(predicateToRemove))
                    .collect(ImmutableSet.toImmutableSet());
        } else {
            newConjuncts = filter.getConjuncts();
        }

        // The filter node remains, as the MATCH predicate is still needed.
        Plan newFilter;
        if (newConjuncts.isEmpty()) {
            newFilter = newScan;
        } else {
            newFilter = filter.withConjunctsAndChild(newConjuncts, newScan);
        }

        // Rebuild project list with the replaced expressions.
        List<NamedExpression> newProjections = ExpressionUtils
                .replaceNamedExpressions(project.getProjects(), replaceMap);
        Plan newProject = project.withProjectsAndChild(newProjections, newFilter);

        // Rebuild the TopN node on top of the new project.
        return topN.withChildren(newProject);
    }

    /**
     * Extract score range info from a single score predicate.
     * Only supports min_score semantics (similar to Elasticsearch):
     * - score() > X or score() >= X
     * - Reversed patterns: X < score() or X <= score()
     *
     * Note: < and <= are NOT supported because max_score filtering is rarely needed.
     * Note: EqualTo (=) is NOT supported.
     */
    private Optional<ScoreRangeInfo> extractScoreRangeInfo(Expression predicate) {
        if (!(predicate instanceof ComparisonPredicate)) {
            if (!predicate.collect(e -> e instanceof Score).isEmpty()) {
                throw new AnalysisException(
                        "score() predicate must be a top-level AND condition in WHERE clause. "
                                + "Nesting score() inside OR or other compound expressions is not supported. "
                                + "Invalid expression: " + predicate.toSql());
            }
            return Optional.empty();
        }

        ComparisonPredicate comp = (ComparisonPredicate) predicate;
        Expression left = comp.left();
        Expression right = comp.right();

        if (isScoreExpression(left) && isNumericLiteral(right)) {
            TExprOpcode op = getMinScoreOpcode(comp);
            if (op != null) {
                return Optional.of(new ScoreRangeInfo(op, extractNumericValue(right)));
            }
        }

        if (isScoreExpression(right) && isNumericLiteral(left)) {
            TExprOpcode op = getReversedMinScoreOpcode(comp);
            if (op != null) {
                return Optional.of(new ScoreRangeInfo(op, extractNumericValue(left)));
            }
        }

        return Optional.empty();
    }

    /**
     * Check if the expression is a Score function, possibly wrapped in Cast expressions.
     * The optimizer may wrap score() in Cast for type coercion (e.g., score() >= 4.0 may become
     * CAST(score() AS DECIMAL) >= 4.0).
     */
    private boolean isScoreExpression(Expression expr) {
        if (expr instanceof Score) {
            return true;
        }
        if (expr instanceof Cast) {
            return isScoreExpression(((Cast) expr).child());
        }
        return false;
    }

    private boolean isNumericLiteral(Expression expr) {
        return expr instanceof DoubleLiteral
                || expr instanceof FloatLiteral
                || expr instanceof IntegerLikeLiteral
                || expr instanceof DecimalV3Literal;
    }

    private double extractNumericValue(Expression expr) {
        if (expr instanceof DoubleLiteral) {
            return ((DoubleLiteral) expr).getValue();
        } else if (expr instanceof FloatLiteral) {
            return ((FloatLiteral) expr).getValue();
        } else if (expr instanceof IntegerLikeLiteral) {
            return ((IntegerLikeLiteral) expr).getLongValue();
        } else if (expr instanceof DecimalV3Literal) {
            return ((DecimalV3Literal) expr).getDouble();
        }
        throw new IllegalArgumentException("Not a numeric literal: " + expr);
    }

    /**
     * Get opcode for min_score patterns: score() > X or score() >= X
     * Returns null for unsupported operators (< and <=)
     */
    private TExprOpcode getMinScoreOpcode(ComparisonPredicate comp) {
        if (comp instanceof GreaterThan) {
            return TExprOpcode.GT;
        } else if (comp instanceof GreaterThanEqual) {
            return TExprOpcode.GE;
        }
        return null;
    }

    /**
     * Get the reversed opcode for min_score patterns like "0.5 < score()" (equivalent to "score() > 0.5")
     * Returns null for unsupported operators
     */
    private TExprOpcode getReversedMinScoreOpcode(ComparisonPredicate comp) {
        if (comp instanceof LessThan) {
            return TExprOpcode.GT;
        } else if (comp instanceof LessThanEqual) {
            return TExprOpcode.GE;
        }
        return null;
    }
}