FuncDeps.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.properties;

import org.apache.doris.nereids.trees.expressions.Slot;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Function dependence items.
 */
public class FuncDeps {
    /**FuncDepsItem*/
    public static class FuncDepsItem {
        public final Set<Slot> determinants;
        public final Set<Slot> dependencies;

        public FuncDepsItem(Set<Slot> determinants, Set<Slot> dependencies) {
            this.determinants = ImmutableSet.copyOf(determinants);
            this.dependencies = ImmutableSet.copyOf(dependencies);
        }

        @Override
        public String toString() {
            return String.format("%s -> %s", determinants, dependencies);
        }

        @Override
        public boolean equals(Object other) {
            if (!(other instanceof FuncDepsItem)) {
                return false;
            }
            FuncDepsItem item = (FuncDepsItem) other;
            return item.determinants.equals(determinants) && item.dependencies.equals(dependencies);
        }

        @Override
        public int hashCode() {
            return Objects.hash(determinants, dependencies);
        }
    }

    private final Set<FuncDepsItem> items;
    // determinants -> dependencies
    private final Map<Set<Slot>, Set<Set<Slot>>> edges;
    // dependencies -> determinants
    private final Map<Set<Slot>, Set<Set<Slot>>> redges;

    public FuncDeps() {
        items = new HashSet<>();
        edges = new HashMap<>();
        redges = new HashMap<>();
    }

    public void addFuncItems(Set<Slot> determinants, Set<Slot> dependencies) {
        items.add(new FuncDepsItem(determinants, dependencies));
        edges.computeIfAbsent(determinants, k -> new HashSet<>());
        edges.get(determinants).add(dependencies);
        redges.computeIfAbsent(dependencies, k -> new HashSet<>());
        redges.get(dependencies).add(determinants);
    }

    public int size() {
        return items.size();
    }

    public boolean isEmpty() {
        return items.isEmpty();
    }

    private void dfs(Set<Slot> parent, Set<Set<Slot>> visited, Set<FuncDepsItem> circleItem) {
        visited.add(parent);
        if (!edges.containsKey(parent)) {
            return;
        }
        for (Set<Slot> child : edges.get(parent)) {
            if (visited.contains(child)) {
                circleItem.add(new FuncDepsItem(parent, child));
                continue;
            }
            dfs(child, visited, circleItem);
        }
    }

    // Find items that are not part of a circular dependency.
    // To keep the slots in requireOutputs, we need to always keep the edges that start with output slots.
    // Note: We reduce the last edge in a circular dependency,
    // so we need to traverse from parents that contain the required output slots.
    private Set<FuncDepsItem> findValidItems(Set<Slot> requireOutputs) {
        Set<FuncDepsItem> circleItem = new HashSet<>();
        Set<Set<Slot>> visited = new HashSet<>();
        Set<Set<Slot>> parentInOutput = edges.keySet().stream()
                .filter(requireOutputs::containsAll)
                .collect(Collectors.toSet());
        for (Set<Slot> parent : parentInOutput) {
            if (!visited.contains(parent)) {
                dfs(parent, visited, circleItem);
            }
        }
        Set<Set<Slot>> otherParent = edges.keySet().stream()
                .filter(parent -> !parentInOutput.contains(parent))
                .collect(Collectors.toSet());
        for (Set<Slot> parent : otherParent) {
            if (!visited.contains(parent)) {
                dfs(parent, visited, circleItem);
            }
        }
        return Sets.difference(items, circleItem);
    }

    /**
     * Reduces a given set of slot sets by eliminating dependencies based on valid functional dependency items.
     * <p>
     * This method works as follows:
     * 1. Find valid functional dependency items (those not part of circular dependencies).
     * 2. For each valid functional dependency item:
     *    - If both the determinants and dependencies are present in the current set of slots,
     *      mark the dependencies for elimination.
     * 3. Remove all marked dependencies from the set of slots.
     * </p>
     * <p>
     * Example:
     * Given:
     * - Initial slots: {{A}, {B}, {C}, {D}, {E}}
     * - Required outputs: {}
     * - validItems: {A} -> {B}, {B} -> {C}, {C} -> {D}, {D} -> {A}, {A} -> {E}
     *
     * Process:
     * 1. Start with minSlotSet = {{A}, {B}, {C}, {D}, {E}}
     * 2. For {A} -> {B}:
     *    - Both {A} and {B} are in minSlotSet, so mark {B} for elimination
     * 3. For {B} -> {C}:
     *    - Both {B} and {C} are in minSlotSet, so mark {C} for elimination
     * 4. For {C} -> {D}:
     *    - Both {C} and {D} are in minSlotSet, so mark {D} for elimination
     * 4. For {D} -> {E}:
     *    - Both {D} and {E} are in minSlotSet, so mark {E} for elimination
     *
     * Result: {{A}}
     * </p>
     *
     * @param slots the initial set of slot sets to be reduced
     * @param requireOutputs the set of slots that must be preserved in the output
     * @return the minimal set of slot sets after applying all possible reductions
     */
    public Set<Set<Slot>> eliminateDeps(Set<Set<Slot>> slots, Set<Slot> requireOutputs) {
        Set<Set<Slot>> minSlotSet = Sets.newHashSet(slots);
        Set<Set<Slot>> eliminatedSlots = new HashSet<>();
        Set<FuncDepsItem> validItems = findValidItems(requireOutputs);
        for (FuncDepsItem funcDepsItem : validItems) {
            if (minSlotSet.contains(funcDepsItem.dependencies)
                    && minSlotSet.contains(funcDepsItem.determinants)) {
                eliminatedSlots.add(funcDepsItem.dependencies);
            }
        }
        minSlotSet.removeAll(eliminatedSlots);
        return minSlotSet;
    }

    public boolean isFuncDeps(Set<Slot> dominate, Set<Slot> dependency) {
        return items.contains(new FuncDepsItem(dominate, dependency));
    }

    // ������������������������������������
    public boolean isCircleDeps(Set<Slot> dominate, Set<Slot> dependency) {
        return items.contains(new FuncDepsItem(dominate, dependency))
                && items.contains(new FuncDepsItem(dependency, dominate));
    }

    public Set<FuncDeps.FuncDepsItem> getItems() {
        return items;
    }

    public Map<Set<Slot>, Set<Set<Slot>>> getEdges() {
        return edges;
    }

    public Map<Set<Slot>, Set<Set<Slot>>> getREdges() {
        return redges;
    }

    /**
     * Finds all slot sets that have a bijective relationship with the given slot set.
     * Given edges containing:
     *   {A} -> {{B}, {C}}
     *   {B} -> {{A}, {D}}
     *   {C} -> {{A}}
     * When slot = {A}, returns {{B}} because {A} and {B} mutually determine each other.
     * {C} is not returned because {C} does not determine {A} (one-way dependency only).
     */
    public Set<Set<Slot>> findBijectionSlots(Set<Slot> slot) {
        Set<Set<Slot>> bijectionSlots = new HashSet<>();
        if (!edges.containsKey(slot)) {
            return bijectionSlots;
        }
        for (Set<Slot> dep : edges.get(slot)) {
            if (!edges.containsKey(dep)) {
                continue;
            }
            for (Set<Slot> det : edges.get(dep)) {
                if (det.equals(slot)) {
                    bijectionSlots.add(dep);
                }
            }
        }
        return bijectionSlots;
    }

    @Override
    public String toString() {
        return items.toString();
    }
}