IndexedPriorityQueue.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/trinodb/trino/blob/master/core/trino-main/src/main/java/io/trino/execution/resourcegroups/IndexedPriorityQueue.java
// and modified by Doris

package org.apache.doris.common;

import com.google.common.base.Preconditions;
import com.google.common.collect.Iterators;

import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;

/**
 * A priority queue with constant time contains(E) and log time remove(E)
 * Ties are broken by insertion order
 */
public final class IndexedPriorityQueue<E>
        implements UpdateablePriorityQueue<E> {
    private final Map<E, Entry<E>> index = new HashMap<>();
    private final Set<Entry<E>> queue;
    private long generation;

    public IndexedPriorityQueue() {
        this(PriorityOrdering.HIGH_TO_LOW);
    }

    public IndexedPriorityQueue(PriorityOrdering priorityOrdering) {
        switch (priorityOrdering) {
            case LOW_TO_HIGH:
                queue = new TreeSet<>(
                        Comparator.comparingLong((Entry<E> entry) -> entry.getPriority())
                                .thenComparingLong(Entry::getGeneration));
                break;
            case HIGH_TO_LOW:
                queue = new TreeSet<>((entry1, entry2) -> {
                    int priorityComparison = Long.compare(entry2.getPriority(), entry1.getPriority());
                    if (priorityComparison != 0) {
                        return priorityComparison;
                    }
                    return Long.compare(entry1.getGeneration(), entry2.getGeneration());
                });
                break;
            default:
                throw new IllegalArgumentException();
        }
    }

    @Override
    public boolean addOrUpdate(E element, long priority) {
        Entry<E> entry = index.get(element);
        if (entry != null) {
            if (entry.getPriority() == priority) {
                return false;
            }
            queue.remove(entry);
            Entry<E> newEntry = new Entry<>(element, priority, entry.getGeneration());
            queue.add(newEntry);
            index.put(element, newEntry);
            return false;
        }
        Entry<E> newEntry = new Entry<>(element, priority, generation);
        generation++;
        queue.add(newEntry);
        index.put(element, newEntry);
        return true;
    }

    @Override
    public boolean contains(E element) {
        return index.containsKey(element);
    }

    @Override
    public boolean remove(E element) {
        Entry<E> entry = index.remove(element);
        if (entry != null) {
            queue.remove(entry);
            return true;
        }
        return false;
    }

    @Override
    public E poll() {
        Entry<E> entry = pollEntry();
        if (entry == null) {
            return null;
        }
        return entry.getValue();
    }

    @Override
    public E peek() {
        Entry<E> entry = peekEntry();
        if (entry == null) {
            return null;
        }
        return entry.getValue();
    }

    @Override
    public int size() {
        return queue.size();
    }

    @Override
    public boolean isEmpty() {
        return queue.isEmpty();
    }

    public Prioritized<E> getPrioritized(E element) {
        Entry<E> entry = index.get(element);
        if (entry == null) {
            return null;
        }

        return new Prioritized<>(entry.getValue(), entry.getPriority());
    }

    public Prioritized<E> pollPrioritized() {
        Entry<E> entry = pollEntry();
        if (entry == null) {
            return null;
        }
        return new Prioritized<>(entry.getValue(), entry.getPriority());
    }

    private Entry<E> pollEntry() {
        Iterator<Entry<E>> iterator = queue.iterator();
        if (!iterator.hasNext()) {
            return null;
        }
        Entry<E> entry = iterator.next();
        iterator.remove();
        Preconditions.checkState(index.remove(entry.getValue()) != null, "Failed to remove entry from index");
        return entry;
    }

    public Prioritized<E> peekPrioritized() {
        Entry<E> entry = peekEntry();
        if (entry == null) {
            return null;
        }
        return new Prioritized<>(entry.getValue(), entry.getPriority());
    }

    public Entry<E> peekEntry() {
        Iterator<Entry<E>> iterator = queue.iterator();
        if (!iterator.hasNext()) {
            return null;
        }
        return iterator.next();
    }

    @Override
    public Iterator<E> iterator() {
        return Iterators.transform(queue.iterator(), Entry::getValue);
    }

    public enum PriorityOrdering {
        LOW_TO_HIGH,
        HIGH_TO_LOW
    }

    private static final class Entry<E> {
        private final E value;
        private final long priority;
        private final long generation;

        private Entry(E value, long priority, long generation) {
            this.value = Objects.requireNonNull(value, "value is null");
            this.priority = priority;
            this.generation = generation;
        }

        public E getValue() {
            return value;
        }

        public long getPriority() {
            return priority;
        }

        public long getGeneration() {
            return generation;
        }
    }

    public static class Prioritized<V> {
        private final V value;
        private final long priority;

        public Prioritized(V value, long priority) {
            this.value = Objects.requireNonNull(value, "value is null");
            this.priority = priority;
        }

        public V getValue() {
            return value;
        }

        public long getPriority() {
            return priority;
        }
    }
}