// CompressedMaterialize.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsBigInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsLargeInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.EncodeAsSmallInt;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * convert string to int in order to improve performance for aggregation and sorting.
 *
 * <p>Both rules only fire when a {@link ConnectContext} is available and its session variable
 * {@code enableCompressMaterialize} is set.
 *
 * 1. AGG
 * select A from T group by A
 * =>
 * select DecodeAsVarchar(encode_as_int(A)) from T group by encode_as_int(A)
 *
 * 2. Sort
 * select * from T order by A
 * =>
 * select * from T order by encode_as_int(A)
 */
public class CompressedMaterialize implements AnalysisRuleFactory {
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                RuleType.COMPRESSED_MATERIALIZE_AGG.build(
                        logicalAggregate().when(a -> isCompressMaterializeEnabled())
                                .then(this::compressedMaterializeAggregate)),
                RuleType.COMPRESSED_MATERIALIZE_SORT.build(
                        logicalSort().when(a -> isCompressMaterializeEnabled())
                                .then(this::compressMaterializeSort))
        );
    }

    /**
     * Shared guard for both rules.
     *
     * @return true when a connect context exists and {@code enableCompressMaterialize} is on
     */
    private static boolean isCompressMaterializeEnabled() {
        ConnectContext context = ConnectContext.get();
        return context != null && context.getSessionVariable().enableCompressMaterialize;
    }

    /**
     * Replaces every order key that can be encoded with its encoded form, keeping the
     * original direction and null ordering.
     *
     * @param sort the sort node to rewrite
     * @return a sort with rewritten order keys, or {@code sort} itself when no key was encodable
     */
    private LogicalSort<Plan> compressMaterializeSort(LogicalSort<Plan> sort) {
        List<OrderKey> newOrderKeys = Lists.newArrayList();
        boolean anyEncoded = false;
        for (OrderKey orderKey : sort.getOrderKeys()) {
            Optional<Expression> encode = getEncodeExpression(orderKey.getExpr());
            if (encode.isPresent()) {
                newOrderKeys.add(new OrderKey(encode.get(),
                        orderKey.isAsc(),
                        orderKey.isNullFirst()));
                anyEncoded = true;
            } else {
                newOrderKeys.add(orderKey);
            }
        }
        if (!anyEncoded) {
            return sort;
        }
        return sort.withOrderKeys(newOrderKeys);
    }

    /**
     * Chooses an integer encoding for a non-constant expression of character type, based on
     * the declared length of that type.
     *
     * <pre>
     * len in [1, 2)   => EncodeAsSmallInt
     * len in [2, 4)   => EncodeAsInt
     * len in [4, 7)   => EncodeAsBigInt
     * len in [7, 15)  => EncodeAsLargeInt
     * </pre>
     *
     * @param expression the candidate expression
     * @return the encoding expression, or empty when the expression is constant, is not of
     *         character type, or its length is non-positive or at least 15
     */
    private Optional<Expression> getEncodeExpression(Expression expression) {
        if (expression.isConstant()) {
            return Optional.empty();
        }
        DataType type = expression.getDataType();
        if (!(type instanceof CharacterType)) {
            return Optional.empty();
        }
        int len = ((CharacterType) type).getLen();
        if (len <= 0) {
            // skip column from variant, like 'L.var["L_SHIPMODE"] AS TEXT'
            return Optional.empty();
        }
        Expression encodeExpr = null;
        if (len < 2) {
            encodeExpr = new EncodeAsSmallInt(expression);
        } else if (len < 4) {
            encodeExpr = new EncodeAsInt(expression);
        } else if (len < 7) {
            encodeExpr = new EncodeAsBigInt(expression);
        } else if (len < 15) {
            encodeExpr = new EncodeAsLargeInt(expression);
        }
        return Optional.ofNullable(encodeExpr);
    }

    /**
     * Maps each encodable expression in {@code expressions} to its encoding expression.
     * Expressions that cannot be encoded are left out of the map.
     */
    private Map<Expression, Expression> collectEncodings(List<? extends Expression> expressions) {
        Map<Expression, Expression> encodings = Maps.newHashMap();
        for (Expression expr : expressions) {
            getEncodeExpression(expr).ifPresent(encoded -> encodings.put(expr, encoded));
        }
        return encodings;
    }

    /**
     * Builds the replacement map used on output expressions: every encoded key is mapped to
     * {@code DecodeAsVarchar(encode(key))}.
     */
    private static Map<Expression, Expression> buildDecodeMap(Map<Expression, Expression> encodeMap) {
        Map<Expression, Expression> decodeMap = new HashMap<>();
        for (Map.Entry<Expression, Expression> entry : encodeMap.entrySet()) {
            decodeMap.put(entry.getKey(), new DecodeAsVarchar(entry.getValue()));
        }
        return decodeMap;
    }

    /**
     * Applies {@code decodeMap} to every output expression while keeping each output's
     * identity: a rewritten slot becomes an alias carrying the slot's expr id and name, and a
     * rewritten alias keeps itself and only swaps its children.
     *
     * @param outputs the original output expressions
     * @param decodeMap replacement map from group-by key to its decode expression
     * @param plan the plan being rewritten, used only in the error message
     * @return the new output list, in the same order as {@code outputs}
     * @throws IllegalArgumentException if a rewritten output is neither a slot nor an alias
     */
    private static List<NamedExpression> rewriteOutputs(List<? extends NamedExpression> outputs,
            Map<Expression, Expression> decodeMap, Plan plan) {
        List<NamedExpression> newOutputs = Lists.newArrayList();
        for (NamedExpression out : outputs) {
            Expression replaced = ExpressionUtils.replace(out, decodeMap);
            // reference comparison on purpose: an output counts as rewritten only when
            // replace() handed back a different instance
            if (out == replaced) {
                newOutputs.add(out);
            } else if (out instanceof SlotReference) {
                newOutputs.add(new Alias(out.getExprId(), replaced, out.getName()));
            } else if (out instanceof Alias) {
                newOutputs.add(((Alias) out).withChildren(replaced.children()));
            } else {
                // should not reach here
                Preconditions.checkArgument(false, "output abnormal: " + plan);
            }
        }
        return newOutputs;
    }

    /*
    example:
    [support] select sum(v) from t group by substring(k, 1,2)
    [not support] select substring(k, 1,2), sum(v) from t group by substring(k, 1,2)
    [support] select k, sum(v) from t group by k
    [not support] select substring(k, 1,2), sum(v) from t group by k
    [support]  select A as B from T group by A
    */
    private Map<Expression, Expression> getEncodeGroupByExpressions(LogicalAggregate<Plan> aggregate) {
        return collectEncodings(aggregate.getGroupByExpressions());
    }

    /**
     * Encodes the encodable group-by keys of an aggregate and wraps their occurrences in the
     * output expressions with {@code DecodeAsVarchar}.
     *
     * @param aggregate the aggregate node to rewrite
     * @return the rewritten aggregate, or {@code aggregate} itself when no key was encodable
     */
    private LogicalAggregate<Plan> compressedMaterializeAggregate(LogicalAggregate<Plan> aggregate) {
        Map<Expression, Expression> encodeGroupByExpressions = getEncodeGroupByExpressions(aggregate);
        if (encodeGroupByExpressions.isEmpty()) {
            return aggregate;
        }
        List<Expression> newGroupByExpressions = Lists.newArrayList();
        for (Expression gp : aggregate.getGroupByExpressions()) {
            newGroupByExpressions.add(encodeGroupByExpressions.getOrDefault(gp, gp));
        }
        List<NamedExpression> newOutputs = rewriteOutputs(aggregate.getOutputExpressions(),
                buildDecodeMap(encodeGroupByExpressions), aggregate);
        return aggregate.withGroupByAndOutput(newGroupByExpressions, newOutputs);
    }

    /**
     * Collects encodings for the keys of the first grouping set.
     *
     * <p>NOTE: neither this method nor {@link #compressMaterializeRepeat} is registered by
     * {@link #buildRules()} in this file.
     *
     * @param repeat the repeat node to inspect
     * @return encodable key to encoding expression; empty when there are no grouping sets
     */
    private Map<Expression, Expression> getEncodeGroupingSets(LogicalRepeat<Plan> repeat) {
        if (repeat.getGroupingSets().isEmpty()) {
            // nothing to inspect; avoids get(0) on an empty list
            return Maps.newHashMap();
        }
        // the first grouping set contains all group by keys
        return collectEncodings(repeat.getGroupingSets().get(0));
    }

    /**
     * Encodes the encodable keys in every grouping set of a repeat and wraps their
     * occurrences in the output expressions with {@code DecodeAsVarchar}.
     *
     * @param repeat the repeat node to rewrite
     * @return the rewritten repeat, or {@code repeat} itself when no key was encodable
     */
    private LogicalRepeat<Plan> compressMaterializeRepeat(LogicalRepeat<Plan> repeat) {
        Map<Expression, Expression> encode = getEncodeGroupingSets(repeat);
        if (encode.isEmpty()) {
            return repeat;
        }
        List<List<Expression>> newGroupingSets = Lists.newArrayList();
        for (List<? extends Expression> groupingSet : repeat.getGroupingSets()) {
            List<Expression> grouping = Lists.newArrayList();
            for (Expression groupingExpr : groupingSet) {
                grouping.add(encode.getOrDefault(groupingExpr, groupingExpr));
            }
            newGroupingSets.add(grouping);
        }
        List<NamedExpression> newOutputs = rewriteOutputs(repeat.getOutputExpressions(),
                buildDecodeMap(encode), repeat);
        return repeat.withGroupSetsAndOutput(newGroupingSets, newOutputs);
    }
}