ImmutableEqualSet.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.util;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

/**
 * A class representing an immutable set of elements with equivalence relations.
 */
public class ImmutableEqualSet<T> {
    private final Map<T, T> root;

    ImmutableEqualSet(Map<T, T> root) {
        this.root = ImmutableMap.copyOf(root);
    }

    public static <T> ImmutableEqualSet<T> empty() {
        return new ImmutableEqualSet<>(ImmutableMap.of());
    }

    /**
     * Builder for ImmutableEqualSet.
     */
    public static class Builder<T> {
        private Map<T, T> parent;

        Builder(Map<T, T> parent) {
            this.parent = parent;
        }

        public Builder() {
            this(new LinkedHashMap<>());
        }

        public Builder(ImmutableEqualSet<T> equalSet) {
            this(new LinkedHashMap<>(equalSet.root));
        }

        /**
         * replace all key value according replace map
         */
        public void replace(Map<T, T> replaceMap) {
            Map<T, T> newMap = new LinkedHashMap<>();
            for (Entry<T, T> entry : parent.entrySet()) {
                newMap.put(replaceMap.getOrDefault(entry.getKey(), entry.getKey()),
                        replaceMap.getOrDefault(entry.getValue(), entry.getValue()));
            }
            parent = newMap;
        }

        /**
         * Remove all not contain in containSet
         * @param containSet the set to contain
         */
        public void removeNotContain(Set<T> containSet) {
            List<Set<T>> equalSetList = calEqualSetList();
            this.parent.clear();
            for (Set<T> equalSet : equalSetList) {
                Set<T> intersect = Sets.intersection(containSet, equalSet);
                if (intersect.size() <= 1) {
                    continue;
                }
                Iterator<T> iterator = intersect.iterator();
                T first = intersect.iterator().next();
                while (iterator.hasNext()) {
                    T next = iterator.next();
                    this.addEqualPair(first, next);
                }
            }
        }

        /**
         * Add a equal pair
         */
        public void addEqualPair(T a, T b) {
            if (!parent.containsKey(a)) {
                parent.put(a, a);
            }
            if (!parent.containsKey(b)) {
                parent.put(b, b);
            }
            T root1 = findRoot(a);
            T root2 = findRoot(b);
            if (root1 != root2) {
                parent.put(root1, root2);
            }
        }

        /**
         * Calculate all equal set
         */
        public List<Set<T>> calEqualSetList() {
            parent.replaceAll((s, v) -> findRoot(s));
            return parent.values()
                    .stream()
                    .distinct()
                    .map(a -> {
                        T ra = parent.get(a);
                        return parent.keySet().stream()
                                .filter(t -> parent.get(t).equals(ra))
                                .collect(ImmutableSet.toImmutableSet());
                    }).collect(ImmutableList.toImmutableList());
        }

        public void addEqualSet(ImmutableEqualSet<T> equalSet) {
            this.parent.putAll(equalSet.root);
        }

        private T findRoot(T a) {
            if (a.equals(parent.get(a))) {
                return parent.get(a);
            }
            return findRoot(parent.get(a));
        }

        public ImmutableEqualSet<T> build() {
            ImmutableMap.Builder<T, T> foldMapBuilder = new ImmutableMap.Builder<>();
            for (T k : parent.keySet()) {
                foldMapBuilder.put(k, findRoot(k));
            }
            return new ImmutableEqualSet<>(foldMapBuilder.build());
        }
    }

    /**
     * Calculate equal set for a except self
     */
    public Set<T> calEqualSet(T a) {
        T ra = root.get(a);
        return root.keySet().stream()
                .filter(t -> root.get(t).equals(ra) && !t.equals(a))
                .collect(ImmutableSet.toImmutableSet());
    }

    public boolean isEmpty() {
        return root.isEmpty();
    }

    /**
     * Calculate all equal set
     */
    public List<Set<T>> calEqualSetList() {
        return root.values()
                .stream()
                .distinct()
                .map(a -> {
                    T ra = root.get(a);
                    return root.keySet().stream()
                            .filter(t -> root.get(t).equals(ra))
                            .collect(ImmutableSet.toImmutableSet());
                }).collect(ImmutableList.toImmutableList());
    }

    public Set<T> getAllItemSet() {
        return ImmutableSet.copyOf(root.keySet());
    }

    public boolean isEqual(T l, T r) {
        if (!root.containsKey(l) || !root.containsKey(r)) {
            return false;
        }
        return root.get(l) == root.get(r);
    }

    @Override
    public String toString() {
        return root.toString();
    }
}