ReplaceExpressionByChildOutput.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;

import java.util.List;
import java.util.Map;

/**
 * replace.
 */
public class ReplaceExpressionByChildOutput implements AnalysisRuleFactory {
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.<Rule>builder()
                .add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
                        logicalSort(logicalProject()).then(sort -> {
                            LogicalProject<Plan> project = sort.child();
                            Map<Expression, Slot> sMap = buildOutputAliasMap(project.getProjects());
                            return replaceSortExpression(sort, sMap);
                        })
                ))
                .add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
                        logicalSort(logicalAggregate()).then(sort -> {
                            LogicalAggregate<Plan> aggregate = sort.child();
                            Map<Expression, Slot> sMap = buildOutputAliasMap(aggregate.getOutputExpressions());
                            return replaceSortExpression(sort, sMap);
                        })
                )).add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
                        logicalSort(logicalHaving(logicalAggregate())).then(sort -> {
                            LogicalAggregate<Plan> aggregate = sort.child().child();
                            Map<Expression, Slot> sMap = buildOutputAliasMap(aggregate.getOutputExpressions());
                            return replaceSortExpression(sort, sMap);
                        })
                ))
                .build();
    }

    private Map<Expression, Slot> buildOutputAliasMap(List<NamedExpression> output) {
        Map<Expression, Slot> sMap = Maps.newHashMapWithExpectedSize(output.size());
        for (NamedExpression expr : output) {
            if (expr instanceof Alias) {
                Alias alias = (Alias) expr;
                sMap.put(alias.child(), alias.toSlot());
            }
        }
        return sMap;
    }

    private LogicalPlan replaceSortExpression(LogicalSort<? extends LogicalPlan> sort, Map<Expression, Slot> sMap) {
        List<OrderKey> orderKeys = sort.getOrderKeys();

        boolean changed = false;
        ImmutableList.Builder<OrderKey> newKeys = ImmutableList.builderWithExpectedSize(orderKeys.size());
        for (OrderKey k : orderKeys) {
            Expression newExpr = ExpressionUtils.replace(k.getExpr(), sMap);
            if (newExpr != k.getExpr()) {
                changed = true;
            }
            newKeys.add(new OrderKey(newExpr, k.isAsc(), k.isNullFirst()));
        }

        return changed ? new LogicalSort<>(newKeys.build(), sort.child()) : sort;
    }
}