QueryCacheNormalizer.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/apache/impala/blob/branch-2.9.0/fe/src/main/java/org/apache/impala/PlanFragment.java
// and modified by Doris

package org.apache.doris.planner.normalize;

import org.apache.doris.analysis.DescriptorTable;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Tablet;
import org.apache.doris.common.Pair;
import org.apache.doris.planner.AggregationNode;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.planner.PlanNode;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.TNormalizedPlanNode;
import org.apache.doris.thrift.TQueryCacheParam;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import org.apache.thrift.TSerializer;
import org.apache.thrift.protocol.TCompactProtocol;

import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/** QueryCacheNormalizer */
public class QueryCacheNormalizer implements Normalizer {
    private final PlanFragment fragment;
    private final DescriptorTable descriptorTable;
    private final NormalizedIdGenerator normalizedPlanIds =  new NormalizedIdGenerator();
    private final NormalizedIdGenerator normalizedTupleIds =  new NormalizedIdGenerator();
    private final NormalizedIdGenerator normalizedSlotIds =  new NormalizedIdGenerator();

    // result
    private final TQueryCacheParam queryCacheParam = new TQueryCacheParam();

    public QueryCacheNormalizer(PlanFragment fragment, DescriptorTable descriptorTable) {
        this.fragment = Objects.requireNonNull(fragment, "fragment can not be null");
        this.descriptorTable = Objects.requireNonNull(descriptorTable, "descriptorTable can not be null");
    }

    public Optional<TQueryCacheParam> normalize(ConnectContext context) {
        try {
            Optional<CachePoint> cachePoint = computeCachePoint();
            if (!cachePoint.isPresent()) {
                return Optional.empty();
            }
            List<TNormalizedPlanNode> normalizedDigestPlans = normalizePlanTree(context, cachePoint.get());
            byte[] digest = computeDigest(normalizedDigestPlans);
            return setQueryCacheParam(cachePoint.get(), digest, context);
        } catch (Throwable t) {
            return Optional.empty();
        }
    }

    @VisibleForTesting
    public List<TNormalizedPlanNode> normalizePlans(ConnectContext context) {
        Optional<CachePoint> cachePoint = computeCachePoint();
        if (!cachePoint.isPresent()) {
            return ImmutableList.of();
        }
        return normalizePlanTree(context, cachePoint.get());
    }

    private Optional<TQueryCacheParam> setQueryCacheParam(
            CachePoint cachePoint, byte[] digest, ConnectContext context) {
        queryCacheParam.setNodeId(cachePoint.cacheRoot.getId().asInt());
        queryCacheParam.setDigest(digest);
        queryCacheParam.setForceRefreshQueryCache(context.getSessionVariable().isQueryCacheForceRefresh());
        queryCacheParam.setEntryMaxBytes(context.getSessionVariable().getQueryCacheEntryMaxBytes());
        queryCacheParam.setEntryMaxRows(context.getSessionVariable().getQueryCacheEntryMaxRows());

        queryCacheParam.setOutputSlotMapping(
                cachePoint.cacheRoot.getOutputTupleIds()
                    .stream()
                    .flatMap(tupleId -> descriptorTable.getTupleDesc(tupleId).getSlots().stream())
                    .map(slot -> {
                        int slotId = slot.getId().asInt();
                        return Pair.of(slotId, normalizeSlotId(slotId));
                    })
                    .collect(Collectors.toMap(Pair::key, Pair::value))
        );

        return Optional.of(queryCacheParam);
    }

    private Optional<CachePoint> computeCachePoint() {
        if (!fragment.getTargetRuntimeFilterIds().isEmpty()) {
            return Optional.empty();
        }
        PlanNode planRoot = fragment.getPlanRoot();
        return doComputeCachePoint(planRoot);
    }

    private Optional<CachePoint> doComputeCachePoint(PlanNode planRoot) {
        if (planRoot instanceof AggregationNode) {
            PlanNode child = planRoot.getChild(0);
            if (child instanceof OlapScanNode) {
                return Optional.of(new CachePoint(planRoot, planRoot));
            } else if (child instanceof AggregationNode) {
                Optional<CachePoint> childCachePoint = doComputeCachePoint(child);
                if (childCachePoint.isPresent()) {
                    return Optional.of(new CachePoint(planRoot, planRoot));
                }
            }
        }
        return Optional.empty();
    }

    private List<TNormalizedPlanNode> normalizePlanTree(ConnectContext context, CachePoint cachePoint) {
        List<TNormalizedPlanNode> normalizedPlans = new ArrayList<>();
        doNormalizePlanTree(context, cachePoint.digestRoot, normalizedPlans);
        return normalizedPlans;
    }

    private void doNormalizePlanTree(
            ConnectContext context, PlanNode plan, List<TNormalizedPlanNode> normalizedPlans) {
        for (PlanNode child : plan.getChildren()) {
            doNormalizePlanTree(context, child, normalizedPlans);
        }
        normalizedPlans.add(plan.normalize(this));
    }

    public static byte[] computeDigest(List<TNormalizedPlanNode> normalizedDigestPlans) throws Exception {
        TSerializer serializer = new TSerializer(new TCompactProtocol.Factory());
        MessageDigest digest = MessageDigest.getInstance("SHA-256");

        for (TNormalizedPlanNode node : normalizedDigestPlans) {
            digest.update(serializer.serialize(node));
        }
        return digest.digest();
    }

    @Override
    public int normalizeSlotId(int slotId) {
        return normalizedSlotIds.normalize(slotId);
    }

    @Override
    public void setSlotIdToNormalizeId(int slotId, int normalizedId) {
        normalizedSlotIds.set(slotId, normalizedId);
    }

    @Override
    public int normalizeTupleId(int tupleId) {
        return normalizedTupleIds.normalize(tupleId);
    }

    @Override
    public int normalizePlanId(int planId) {
        return normalizedPlanIds.normalize(planId);
    }

    @Override
    public DescriptorTable getDescriptorTable() {
        return descriptorTable;
    }

    @Override
    public void setNormalizedPartitionPredicates(OlapScanNode olapScanNode, NormalizedPartitionPredicates predicates) {
        OlapTable olapTable = olapScanNode.getOlapTable();
        long selectIndexId = olapScanNode.getSelectedIndexId() == -1
                ? olapTable.getBaseIndexId()
                : olapScanNode.getSelectedIndexId();

        Map<Long, String> tabletToRange = Maps.newLinkedHashMap();
        for (Long partitionId : olapScanNode.getSelectedPartitionIds()) {
            Set<Long> tabletIds = olapTable.getPartition(partitionId)
                    .getIndex(selectIndexId)
                    .getTablets()
                    .stream()
                    .map(Tablet::getId)
                    .collect(Collectors.toSet());

            String filterRange = predicates.intersectPartitionRanges.get(partitionId);
            for (Long tabletId : tabletIds) {
                tabletToRange.put(tabletId, filterRange);
            }
        }
        queryCacheParam.setTabletToRange(tabletToRange);
    }

    private static class CachePoint {
        PlanNode digestRoot;
        PlanNode cacheRoot;

        public CachePoint(PlanNode digestRoot, PlanNode cacheRoot) {
            this.digestRoot = digestRoot;
            this.cacheRoot = cacheRoot;
        }
    }

    private static class NormalizedIdGenerator {
        private final AtomicInteger idGenerator = new AtomicInteger(0);
        private final Map<Integer, Integer> originIdToNormalizedId = Maps.newLinkedHashMap();

        public Integer normalize(Integer originId) {
            return originIdToNormalizedId.computeIfAbsent(originId, id -> idGenerator.getAndIncrement());
        }

        public void set(int originId, Integer normalizedId) {
            originIdToNormalizedId.put(originId, normalizedId);
        }
    }
}