SubgraphEnumerator.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.jobs.joinorder.hypergraph;

import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmapSubsetIterator;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.AbstractNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.node.DPhyperNode;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver.AbstractReceiver;
import org.apache.doris.qe.ConnectContext;

import com.google.common.base.Preconditions;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;

/**
 * This class enumerates all subgraphs of a HyperGraph. CSG means connected subgraph
 * and CMP means complement subgraph. Every (csg, cmp) pair that is connected by at
 * least one join edge is handed to the receiver.
 * More details are in the paper: Dynamic Programming Strikes Back and Build Query Optimizer.
 */
public class SubgraphEnumerator {
    public static final Logger LOG = LogManager.getLogger(SubgraphEnumerator.class);

    // The receiver receives the csg and cmp and records them, named DPTable in the paper
    AbstractReceiver receiver;
    // The enumerated hyperGraph
    HyperGraph hyperGraph;
    // Helpers that are rebuilt on every enumerate() call. The edge calculator caches,
    // per subgraph, the edges touching it in order to avoid repetitive computation.
    EdgeCalculator edgeCalculator;
    NeighborhoodCalculator neighborhoodCalculator;

    // trace enumerate
    private final boolean enableTrace = isTraceEnabled();
    private final StringBuilder traceBuilder = new StringBuilder();

    public SubgraphEnumerator(AbstractReceiver receiver, HyperGraph hyperGraph) {
        this.receiver = receiver;
        this.hyperGraph = hyperGraph;
    }

    // Reads the enableDpHypTrace session variable of the current ConnectContext.
    // NOTE(review): the null checks assume ConnectContext.get() / getSessionVariable() may
    // return null when no session is bound to the calling thread - confirm. In that case
    // tracing is simply disabled instead of failing the construction of the enumerator.
    private static boolean isTraceEnabled() {
        ConnectContext context = ConnectContext.get();
        return context != null && context.getSessionVariable() != null
                && context.getSessionVariable().enableDpHypTrace;
    }

    /**
     * Entry function of enumerating hyperGraph
     *
     * @return whether the hyperGraph is enumerated successfully; false as soon as the
     *         receiver refuses an emitted csg-cmp pair
     */
    public boolean enumerate() {
        if (enableTrace) {
            // Start from an empty trace so that a repeated enumerate() call does not log
            // the trace of the previous run again.
            traceBuilder.setLength(0);
            traceBuilder.append("Query Graph Graphviz: ").append(hyperGraph.toDottyHyperGraph()).append("\n");
        }
        boolean success = enumerateAllCsg();
        if (enableTrace) {
            // Log on both outcomes: the trace is needed most when the enumeration was aborted.
            LOG.info(traceBuilder.toString());
        }
        return success;
    }

    // Registers every single node in the receiver, builds the calculators and then uses
    // each node (from the second to last one down to node 0) as the seed of a csg.
    private boolean enumerateAllCsg() {
        receiver.reset();
        List<AbstractNode> nodes = hyperGraph.getNodes();
        // Init all nodes in Receiver
        for (AbstractNode node : nodes) {
            DPhyperNode dPhyperNode = (DPhyperNode) node;
            receiver.addGroup(node.getNodeMap(), dPhyperNode.getGroup());
        }
        int size = nodes.size();

        // Init edgeCalculator
        edgeCalculator = new EdgeCalculator(hyperGraph.getJoinEdges());
        for (AbstractNode node : nodes) {
            edgeCalculator.initSubgraph(node.getNodeMap());
        }

        // Init neighborhoodCalculator
        neighborhoodCalculator = new NeighborhoodCalculator();

        // We skip the last element because it can't generate valid csg-cmp pair
        long forbiddenNodes = LongBitmap.newBitmapBetween(0, size - 1);
        for (int i = size - 2; i >= 0; i--) {
            if (enableTrace) {
                traceBuilder.append("Starting main iteration at node[").append(i).append("]\n");
            }
            long csg = LongBitmap.newBitmap(i);
            // The seed itself must not be forbidden while it is being expanded
            forbiddenNodes = LongBitmap.unset(forbiddenNodes, i);
            if (!emitCsg(csg) || !enumerateCsgRec(csg, LongBitmap.clone(forbiddenNodes))) {
                return false;
            }
        }
        return true;
    }

    // The general purpose of EnumerateCsgRec is to extend a given set csg, which
    // induces a connected subgraph of G to a larger set with the same property.
    // Returns false as soon as an emit is rejected, which aborts the whole enumeration.
    private boolean enumerateCsgRec(long csg, long forbiddenNodes) {
        long neighborhood = neighborhoodCalculator.calcNeighborhood(csg, forbiddenNodes, edgeCalculator);
        LongBitmapSubsetIterator subsetIterator = LongBitmap.getSubsetIterator(neighborhood);
        if (enableTrace) {
            traceBuilder.append("Expanding connected subgraph, subgraph=[").append(LongBitmap.toString(csg))
                    .append("], neighborhood=[").append(LongBitmap.toString(neighborhood)).append("], forbidden=[")
                    .append(LongBitmap.toString(forbiddenNodes)).append("]\n");
        }
        // First pass: emit every enlarged csg that the receiver already contains
        for (long subset : subsetIterator) {
            long newCsg = LongBitmap.newBitmapUnion(csg, subset);
            // Make sure the edge caches of the enlarged subgraph exist before it is used
            edgeCalculator.unionEdges(csg, subset);
            if (receiver.contain(newCsg)) {
                if (!emitCsg(newCsg)) {
                    return false;
                }
            }
        }
        // Second pass: recurse from every enlarged csg with the whole neighborhood forbidden
        forbiddenNodes = LongBitmap.or(forbiddenNodes, neighborhood);
        subsetIterator.reset();
        for (long subset : subsetIterator) {
            long newCsg = LongBitmap.newBitmapUnion(csg, subset);
            if (!enumerateCsgRec(newCsg, LongBitmap.clone(forbiddenNodes))) {
                return false;
            }
        }
        return true;
    }

    // Counterpart of enumerateCsgRec for the complement: extends cmp by subsets of its
    // neighborhood and emits (csg, newCmp) whenever the receiver contains newCmp and at
    // least one edge connects both sides. Returns false as soon as an emit is rejected.
    private boolean enumerateCmpRec(long csg, long cmp, long forbiddenNodes) {
        long neighborhood = neighborhoodCalculator.calcNeighborhood(cmp, forbiddenNodes, edgeCalculator);
        LongBitmapSubsetIterator subsetIterator = new LongBitmapSubsetIterator(neighborhood);
        if (enableTrace) {
            traceBuilder.append("Expanding complement subgraph, subgraph=[").append(LongBitmap.toString(cmp))
                    .append("], neighborhood=[").append(LongBitmap.toString(neighborhood)).append("], forbidden=[")
                    .append(LongBitmap.toString(forbiddenNodes)).append("]\n");
        }
        for (long subset : subsetIterator) {
            long newCmp = LongBitmap.newBitmapUnion(cmp, subset);
            // We need to check whether Cmp is connected and then try to find hyper edge
            edgeCalculator.unionEdges(cmp, subset);
            if (receiver.contain(newCmp)) {
                // We check all edges for finding an edge.
                List<JoinEdge> edges = edgeCalculator.connectCsgCmp(csg, newCmp);
                if (edges.isEmpty()) {
                    continue;
                }
                if (!receiver.emitCsgCmp(csg, newCmp, edges)) {
                    return false;
                }
            }
        }
        forbiddenNodes = LongBitmap.or(forbiddenNodes, neighborhood);
        subsetIterator.reset();
        for (long subset : subsetIterator) {
            long newCmp = LongBitmap.newBitmapUnion(cmp, subset);
            if (!enumerateCmpRec(csg, newCmp, LongBitmap.clone(forbiddenNodes))) {
                return false;
            }
        }
        return true;
    }

    // EmitCsg takes as an argument a non-empty, proper subset csg of HyperGraph, which
    // induces a connected subgraph. It is then responsible to generate the seeds for
    // all cmp such that (csg, cmp) becomes a csg-cmp-pair.
    // Returns false as soon as the receiver rejects a pair.
    private boolean emitCsg(long csg) {
        // Forbid the range below the first set bit of csg and csg itself
        long forbiddenNodes = LongBitmap.newBitmapBetween(0, LongBitmap.nextSetBit(csg, 0));
        forbiddenNodes = LongBitmap.or(forbiddenNodes, csg);
        long neighborhoods = neighborhoodCalculator.calcNeighborhood(csg, LongBitmap.clone(forbiddenNodes),
                edgeCalculator);
        // Only single-node csg are traced here
        if (enableTrace && LongBitmap.getCardinality(csg) == 1) {
            traceBuilder.append("Emitting connected subgraph, subgraph=[").append(LongBitmap.toString(csg))
                    .append("], neighborhood=[").append(LongBitmap.toString(neighborhoods)).append("], forbidden=[")
                    .append(LongBitmap.toString(forbiddenNodes)).append("]\n");
        }
        for (int nodeIndex : LongBitmap.getReverseIterator(neighborhoods)) {
            long cmp = LongBitmap.newBitmap(nodeIndex);
            // whether there is an edge between csg and cmp
            List<JoinEdge> edges = edgeCalculator.connectCsgCmp(csg, cmp);
            if (!edges.isEmpty()) {
                if (!receiver.emitCsgCmp(csg, cmp, edges)) {
                    return false;
                }
            }

            // In order to avoid enumerating a repeated cmp, e.g.,
            //       t1 (csg)
            //      /  \
            //     t2 - t3
            // for csg {t1}, we can get neighborhoods {t2, t3}
            // 1. The cmp is {t3} and expanded from {t3} to {t2, t3}
            // 2. The cmp is {t2} and expanded from {t2} to {t2, t3}
            // We don't want to get {t2, t3} twice. So in the first enumeration, we
            // can exclude {t2}
            long newForbiddenNodes = LongBitmap.newBitmapBetween(0, nodeIndex + 1);
            newForbiddenNodes = LongBitmap.and(newForbiddenNodes, neighborhoods);
            newForbiddenNodes = LongBitmap.or(newForbiddenNodes, forbiddenNodes);
            if (!enumerateCmpRec(csg, cmp, newForbiddenNodes)) {
                return false;
            }
        }
        return true;
    }

    static class NeighborhoodCalculator {
        // This function is used to calculate neighborhoods of given subgraph.
        // Though a direct way is to add all nodes u that satisfies:
        //              <u, v> \in E && v \in subgraph && v \intersect X = empty
        // We don't use it because it can cause some repeated subgraphs when
        // expanding csg and cmp. In fact, we just need a seed node that can be expanded
        // to all subgraphs. That is any one node of hyper nodes. In fact, the neighborhoods
        // is the minimum set that we choose one node from above v.
        public long calcNeighborhood(long subgraph, long forbiddenNodes, EdgeCalculator edgeCalculator) {
            // Simple edges: take all their reference nodes
            long neighborhoods = LongBitmap.newBitmap();
            for (Edge edge : edgeCalculator.foundSimpleEdgesContain(subgraph)) {
                neighborhoods = LongBitmap.or(neighborhoods, edge.getReferenceNodes());
            }
            // Drop the subgraph itself and everything forbidden
            forbiddenNodes = LongBitmap.or(forbiddenNodes, subgraph);
            neighborhoods = LongBitmap.andNot(neighborhoods, forbiddenNodes);
            // Nodes found through simple edges are forbidden for the complex edges below
            forbiddenNodes = LongBitmap.or(forbiddenNodes, neighborhoods);
            for (Edge edge : edgeCalculator.foundComplexEdgesContain(subgraph)) {
                long left = edge.getLeftExtendedNodes();
                long right = edge.getRightExtendedNodes();
                // One side must lie inside the subgraph and the other side must be free of
                // forbidden nodes; only the node at lowestOneIndex of that side is the seed.
                if (LongBitmap.isSubset(left, subgraph) && !LongBitmap.isOverlap(right, forbiddenNodes)) {
                    neighborhoods = LongBitmap.set(neighborhoods, LongBitmap.lowestOneIndex(right));
                } else if (LongBitmap.isSubset(right, subgraph) && !LongBitmap.isOverlap(left, forbiddenNodes)) {
                    neighborhoods = LongBitmap.set(neighborhoods, LongBitmap.lowestOneIndex(left));
                }
            }
            return neighborhoods;
        }
    }

    static class EdgeCalculator {
        final List<JoinEdge> edges;
        // It caches all edges that are contained by this subgraph. Note we always
        // use a bitset to store the edge map because the number of edges can be very large.
        // We split these into simple edges (only one node on each side) and complex edges (others)
        // because we can often quickly discard all simple edges by testing the set of interesting nodes
        // against the "simple_neighborhood" bitmap. These data will be calculated before enumerate.

        HashMap<Long, BitSet> containSimpleEdges = new HashMap<>();
        HashMap<Long, BitSet> containComplexEdges = new HashMap<>();
        // It caches all edges that overlap this subgraph. All these edges must be
        // complex edges
        HashMap<Long, BitSet> overlapEdges = new HashMap<>();

        EdgeCalculator(List<JoinEdge> edges) {
            this.edges = edges;
        }

        // Classifies every edge against the subgraph from scratch and stores the result in
        // the three caches. Entries that already exist for the subgraph are merged in.
        public void initSubgraph(long subgraph) {
            BitSet simpleContains = new BitSet();
            BitSet complexContains = new BitSet();
            BitSet overlaps = new BitSet();
            for (Edge edge : edges) {
                if (isContainEdge(subgraph, edge)) {
                    if (edge.isSimple()) {
                        simpleContains.set(edge.getIndex());
                    } else {
                        complexContains.set(edge.getIndex());
                    }
                } else if (isOverlapEdge(subgraph, edge)) {
                    overlaps.set(edge.getIndex());
                }
            }
            if (containSimpleEdges.containsKey(subgraph)) {
                complexContains.or(containComplexEdges.get(subgraph));
                simpleContains.or(containSimpleEdges.get(subgraph));
            }
            if (overlapEdges.containsKey(subgraph)) {
                overlaps.or(overlapEdges.get(subgraph));
            }
            overlapEdges.put(subgraph, overlaps);
            containSimpleEdges.put(subgraph, simpleContains);
            containComplexEdges.put(subgraph, complexContains);
        }

        // Fills the caches for the union of bitmap1 and bitmap2 from the cached entries of
        // both parts (initializing a part first if it is unknown). Does nothing when the
        // union is already cached.
        public void unionEdges(long bitmap1, long bitmap2) {
            // When union two sub graphs, we only need to check overlap edges.
            // However, if all reference nodes are contained by the subgraph,
            // we should remove it.
            if (!containSimpleEdges.containsKey(bitmap1)) {
                initSubgraph(bitmap1);
            }
            if (!containSimpleEdges.containsKey(bitmap2)) {
                initSubgraph(bitmap2);
            }
            long subgraph = LongBitmap.newBitmapUnion(bitmap1, bitmap2);
            if (containSimpleEdges.containsKey(subgraph)) {
                return;
            }
            BitSet simpleContains = new BitSet();
            simpleContains.or(containSimpleEdges.get(bitmap1));
            simpleContains.or(containSimpleEdges.get(bitmap2));
            BitSet complexContains = new BitSet();
            complexContains.or(containComplexEdges.get(bitmap1));
            complexContains.or(containComplexEdges.get(bitmap2));
            BitSet overlaps = new BitSet();
            overlaps.or(overlapEdges.get(bitmap1));
            overlaps.or(overlapEdges.get(bitmap2));
            // Iterate over a snapshot because the bitset is modified inside the loop
            for (int index : overlaps.stream().toArray()) {
                Edge edge = edges.get(index);
                if (isContainEdge(subgraph, edge)) {
                    // Promote the edge from "overlap" to "contain" for the union
                    overlaps.set(index, false);
                    if (edge.isSimple()) {
                        simpleContains.set(index);
                    } else {
                        complexContains.set(index);
                    }
                }
            }
            simpleContains = removeInvalidEdges(subgraph, simpleContains);
            complexContains = removeInvalidEdges(subgraph, complexContains);
            containSimpleEdges.put(subgraph, simpleContains);
            containComplexEdges.put(subgraph, complexContains);
            overlapEdges.put(subgraph, overlaps);
        }

        // Returns the edges that are cached as "contained" for both csg and cmp, i.e. the
        // candidates for joining the two subgraphs. Both subgraphs must already be cached,
        // otherwise the precondition throws IllegalArgumentException.
        public List<JoinEdge> connectCsgCmp(long csg, long cmp) {
            Preconditions.checkArgument(
                    containSimpleEdges.containsKey(csg) && containSimpleEdges.containsKey(cmp));
            List<JoinEdge> foundEdges = new ArrayList<>();
            // Intersections are built in fresh bitsets so that the caches stay untouched
            BitSet edgeMap = new BitSet();
            edgeMap.or(containSimpleEdges.get(csg));
            edgeMap.and(containSimpleEdges.get(cmp));
            BitSet complexes = new BitSet();
            complexes.or(containComplexEdges.get(csg));
            complexes.and(containComplexEdges.get(cmp));
            edgeMap.or(complexes);
            edgeMap.stream().forEach(index -> foundEdges.add(edges.get(index)));
            return foundEdges;
        }

        // Returns all cached edges (simple and complex) contained by the subgraph.
        // Fails with IllegalStateException when the subgraph has not been cached yet.
        public List<Edge> foundEdgesContain(long subgraph) {
            BitSet simpleEdges = containSimpleEdges.get(subgraph);
            Preconditions.checkState(simpleEdges != null);
            // Build the union in a fresh bitset: or-ing into the cached one would leak the
            // complex edges into containSimpleEdges and corrupt every later lookup.
            BitSet edgeMap = new BitSet();
            edgeMap.or(simpleEdges);
            edgeMap.or(containComplexEdges.get(subgraph));
            return edgeMap.stream().mapToObj(edges::get).collect(Collectors.toList());
        }

        // Returns the cached simple edges of the subgraph, or an empty list if it is unknown
        public List<Edge> foundSimpleEdgesContain(long subgraph) {
            if (!containSimpleEdges.containsKey(subgraph)) {
                return Collections.emptyList();
            }
            BitSet edgeMap = containSimpleEdges.get(subgraph);
            return edgeMap.stream().mapToObj(edges::get).collect(Collectors.toList());
        }

        // Returns the cached complex edges of the subgraph, or an empty list if it is unknown
        public List<Edge> foundComplexEdgesContain(long subgraph) {
            if (!containComplexEdges.containsKey(subgraph)) {
                return Collections.emptyList();
            }
            BitSet edgeMap = containComplexEdges.get(subgraph);
            return edgeMap.stream().mapToObj(edges::get).collect(Collectors.toList());
        }

        // An edge is "contained" when exactly one of its two sides is a subset of the subgraph
        private boolean isContainEdge(long subgraph, Edge edge) {
            boolean leftInside = LongBitmap.isSubset(edge.getLeftExtendedNodes(), subgraph);
            boolean rightInside = LongBitmap.isSubset(edge.getRightExtendedNodes(), subgraph);
            return leftInside != rightInside;
        }

        // An edge "overlaps" when exactly one of its two sides shares nodes with the subgraph
        private boolean isOverlapEdge(long subgraph, Edge edge) {
            boolean leftOverlaps = LongBitmap.isOverlap(edge.getLeftExtendedNodes(), subgraph);
            boolean rightOverlaps = LongBitmap.isOverlap(edge.getRightExtendedNodes(), subgraph);
            return leftOverlaps != rightOverlaps;
        }

        // Clears (in place) every edge that does not overlap the subgraph on exactly one
        // side, e.g. an edge whose both sides now touch the subgraph, and returns the bitset.
        private BitSet removeInvalidEdges(long subgraph, BitSet edgeMap) {
            // Iterate over a snapshot because the bitset is modified inside the loop
            for (int index : edgeMap.stream().toArray()) {
                Edge edge = edges.get(index);
                if (!isOverlapEdge(subgraph, edge)) {
                    edgeMap.set(index, false);
                }
            }
            return edgeMap;
        }
    }
}