PullUpProjectBetweenTopNAndAgg.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.qe.ConnectContext;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

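/**
 * Pull up the project that sits between a topN and an aggregate, so that the topN runs directly
 * on the aggregate output. This tries to reduce the shuffle cost of the topN operator, and is used
 * to optimize the plan after applying Compress_materialize.
 *
 * topn(orderKey=[a])
 *   --> project(a+1 as x, a+2 as y, a)
 *      --> agg(output(a))
 * =>
 * project(a+1 as x, a+2 as y, a)
 *  --> topn(orderKey=[a])
 *    --> agg(output(a))
 *
 */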
public class PullUpProjectBetweenTopNAndAgg extends OneRewriteRuleFactory {
    public static final Logger LOG = LogManager.getLogger(PullUpProjectBetweenTopNAndAgg.class);

    /**
     * Builds the rule for the pattern topN -> project -> aggregate. The rule only fires when a
     * connect context exists and its session variable enableCompressMaterialize is on.
     */
    @Override
    public Rule build() {
        return logicalTopN(logicalProject(logicalAggregate()))
                .when(topN -> isCompressMaterializeEnabled())
                .then(topN -> adjust(topN))
                .toRule(RuleType.ADJUST_TOPN_PROJECT);
    }

    /** Returns true when there is a connect context whose session enables compress materialize. */
    private static boolean isCompressMaterializeEnabled() {
        ConnectContext context = ConnectContext.get();
        if (context == null) {
            return false;
        }
        return context.getSessionVariable().enableCompressMaterialize;
    }

    /**
     * Swaps the topN with the project below it when that is possible.
     *
     * <p>The swap happens only if every order key can be expressed on the project's input (see
     * {@link #mapOrderKey}) and the project has at least as many output expressions as input
     * slots. Otherwise the original topN is returned unchanged.
     *
     * @param topN the matched topN, whose child is a project
     * @return the project placed on top of the re-parented topN, or the original topN
     */
    private Plan adjust(LogicalTopN<? extends Plan> topN) {
        LogicalProject<Plan> project = (LogicalProject<Plan>) topN.child();
        Set<Slot> inputSlots = project.getInputSlots();
        Map<SlotReference, SlotReference> aliasToSource = collectSlotAliases(project);

        List<OrderKey> pushedKeys = new ArrayList<>();
        for (OrderKey key : topN.getOrderKeys()) {
            OrderKey mapped = mapOrderKey(key, inputSlots, aliasToSource);
            if (mapped == null) {
                // one key cannot be evaluated below the project: give up the whole rewrite
                return topN;
            }
            pushedKeys.add(mapped);
        }

        // only pull the project up when it does not output fewer columns than it reads
        if (project.getProjects().size() < inputSlots.size()) {
            return topN;
        }

        LogicalTopN<? extends Plan> loweredTopN =
                topN.withChildren(project.children()).withOrderKeys(pushedKeys);
        return (LogicalProject<Plan>) project.withChildren(loweredTopN);
    }

    /**
     * Collects the projections of the form {@code slot AS alias}, keyed by the alias' output slot
     * and mapped to the aliased input slot.
     */
    private static Map<SlotReference, SlotReference> collectSlotAliases(LogicalProject<Plan> project) {
        Map<SlotReference, SlotReference> aliasToSource = new HashMap<>();
        for (NamedExpression projection : project.getProjects()) {
            if (!(projection instanceof Alias)) {
                continue;
            }
            Alias alias = (Alias) projection;
            Expression aliased = alias.child(0);
            if (aliased instanceof SlotReference) {
                aliasToSource.put((SlotReference) alias.toSlot(), (SlotReference) aliased);
            }
        }
        return aliasToSource;
    }

    /**
     * Expresses one order key in terms of the project's input.
     *
     * @return the key itself when it orders by an input slot of the project, the key rewritten to
     *         the aliased slot when it orders by a plain slot alias, or null when the key is not a
     *         slot reference or matches neither case
     */
    private static OrderKey mapOrderKey(OrderKey orderKey, Set<Slot> inputSlots,
            Map<SlotReference, SlotReference> aliasToSource) {
        Expression expr = orderKey.getExpr();
        if (!(expr instanceof SlotReference)) {
            return null;
        }
        if (inputSlots.contains(expr)) {
            return orderKey;
        }
        if (aliasToSource.containsKey(expr)) {
            return orderKey.withExpression(aliasToSource.get(expr));
        }
        return null;
    }
}