// SimpleAggCacheMgr.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.stats;

import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.common.Config;
import org.apache.doris.common.ThreadPoolManager;
import org.apache.doris.qe.AutoCloseConnectContext;
import org.apache.doris.qe.StmtExecutor;
import org.apache.doris.statistics.ResultRow;
import org.apache.doris.statistics.util.StatisticsUtil;

import com.github.benmanes.caffeine.cache.AsyncCacheLoader;
import com.github.benmanes.caffeine.cache.AsyncLoadingCache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.google.common.annotations.VisibleForTesting;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.checkerframework.checker.nullness.qual.NonNull;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;

/**
 * Async cache that stores exact MIN/MAX/COUNT values for OlapTable,
 * used by {@code RewriteSimpleAggToConstantRule} to replace simple
 * aggregations with constant values.
 *
 * <p>MIN/MAX values are obtained by executing
 * {@code SELECT min(col), max(col) FROM table}, and COUNT values by
 * {@code SELECT count(*) FROM table}, both as internal SQL queries
 * inside FE. Results are cached with a version stamp taken from
 * {@code OlapTable.getVisibleVersion()}.
 * When a caller provides a version newer than the cached version,
 * the stale entry is evicted so that the next lookup starts a background reload.
 *
 * <p>Only numeric and date/datetime columns are cached for MIN/MAX;
 * aggregated columns are skipped.
 */
public class SimpleAggCacheMgr {

    // ======================== Public inner types ========================

    /**
     * Holds exact min and max values for a column as strings.
     *
     * <p>Both fields are {@code final} and assigned only in the constructor. The values are
     * stored exactly as passed in; this class performs no parsing, validation or null checks.
     */
    public static class ColumnMinMax {
        // Minimum value in string form, exactly as supplied to the constructor.
        private final String minValue;
        // Maximum value in string form, exactly as supplied to the constructor.
        private final String maxValue;

        /**
         * @param minValue string form of the column minimum
         * @param maxValue string form of the column maximum
         */
        public ColumnMinMax(String minValue, String maxValue) {
            this.minValue = minValue;
            this.maxValue = maxValue;
        }

        /** Returns the minimum value passed to the constructor. */
        public String minValue() {
            return minValue;
        }

        /** Returns the maximum value passed to the constructor. */
        public String maxValue() {
            return maxValue;
        }
    }

    /**
     * Cache key identifying a column by its table ID and column name.
     */
    public static final class ColumnMinMaxKey {
        private final long tableId;
        private final String columnName;

        public ColumnMinMaxKey(long tableId, String columnName) {
            this.tableId = tableId;
            this.columnName = columnName;
        }

        public long getTableId() {
            return tableId;
        }

        public String getColumnName() {
            return columnName;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof ColumnMinMaxKey)) {
                return false;
            }
            ColumnMinMaxKey that = (ColumnMinMaxKey) o;
            return tableId == that.tableId && columnName.equalsIgnoreCase(that.columnName);
        }

        @Override
        public int hashCode() {
            return Objects.hash(tableId, columnName.toLowerCase());
        }

        @Override
        public String toString() {
            return "ColumnMinMaxKey{tableId=" + tableId + ", column=" + columnName + "}";
        }
    }

    /**
     * Cached min/max pair together with the version stamp it was loaded at.
     * Both fields are {@code final} and assigned only in the constructor.
     */
    private static class CacheValue {
        // The loaded min/max values.
        private final ColumnMinMax minMax;
        // Version stamp recorded by the loader; compared against the caller's version on lookup.
        private final long version;

        CacheValue(ColumnMinMax minMax, long version) {
            this.minMax = minMax;
            this.version = version;
        }

        /** Returns the cached min/max pair. */
        ColumnMinMax minMax() {
            return minMax;
        }

        /** Returns the version stamp stored with this value. */
        long version() {
            return version;
        }
    }

    /**
     * Cached row count with version stamp.
     * Both fields are {@code final} and assigned only in the constructor.
     */
    private static class RowCountValue {
        // Row count parsed from the count query result.
        private final long rowCount;
        // Version stamp recorded by the loader; compared against the caller's version on lookup.
        private final long version;

        RowCountValue(long rowCount, long version) {
            this.rowCount = rowCount;
            this.version = version;
        }

        /** Returns the cached row count. */
        long rowCount() {
            return rowCount;
        }

        /** Returns the version stamp stored with this value. */
        long version() {
            return version;
        }
    }

    private static final Logger LOG = LogManager.getLogger(SimpleAggCacheMgr.class);

    // Lazily created shared instance; assigned inside getInstance().
    private static volatile SimpleAggCacheMgr INSTANCE;
    // Optional override: when non-null, internalInstance() returns it instead of INSTANCE.
    private static volatile SimpleAggCacheMgr TEST_INSTANCE;

    // Min/max cache keyed by (table id, column name).
    // Null when the object was created through the protected no-arg constructor.
    private final AsyncLoadingCache<ColumnMinMaxKey, Optional<CacheValue>> cache;
    // Row-count cache keyed by table id.
    // Null when the object was created through the protected no-arg constructor.
    private final AsyncLoadingCache<Long, Optional<RowCountValue>> rowCountCache;

    /**
     * Protected no-arg constructor for test subclassing.
     * Subclasses override {@link #getStats}, {@link #getRowCount}, etc.
     *
     * <p>Both caches are left {@code null}, so the inherited {@link #getStats},
     * {@link #removeStats} and {@link #getRowCount} implementations would throw a
     * {@link NullPointerException} on an instance built this way unless they are overridden.
     */
    protected SimpleAggCacheMgr() {
        this.cache = null;
        this.rowCountCache = null;
    }

    /**
     * Builds the real cache-backed instance.
     *
     * <p>Creates two Caffeine async loading caches: one for min/max values (loaded by
     * {@link CacheLoader}) and one for row counts (loaded by {@link RowCountLoader}).
     * Each cache is capped at {@code Config.stats_cache_size} entries and is configured
     * with the given executor.
     *
     * @param executor executor handed to both caches
     */
    private SimpleAggCacheMgr(ExecutorService executor) {
        this.cache = Caffeine.newBuilder()
                .maximumSize(Config.stats_cache_size)
                .executor(executor)
                .buildAsync(new CacheLoader());
        this.rowCountCache = Caffeine.newBuilder()
                .maximumSize(Config.stats_cache_size)
                .executor(executor)
                .buildAsync(new RowCountLoader());
    }

    /**
     * Returns the shared cache-backed instance, creating it on first use.
     *
     * <p>Creation happens at most once: the volatile {@code INSTANCE} field is checked before
     * and again after taking the class lock, and the instance is built with a daemon thread
     * pool named {@code simple-agg-cache-pool}.
     */
    private static SimpleAggCacheMgr getInstance() {
        SimpleAggCacheMgr current = INSTANCE;
        if (current != null) {
            return current;
        }
        synchronized (SimpleAggCacheMgr.class) {
            current = INSTANCE;
            if (current == null) {
                ExecutorService loaderPool = ThreadPoolManager.newDaemonCacheThreadPool(
                        4, "simple-agg-cache-pool", true);
                current = new SimpleAggCacheMgr(loaderPool);
                INSTANCE = current;
            }
            return current;
        }
    }

    /**
     * Returns the manager to use: the test override when one is currently set,
     * otherwise the shared instance backed by the async-loading caches.
     *
     * <p>The override field is read once into a local so that a concurrent
     * {@link #clearTestInstance()} cannot cause {@code null} to be returned.
     */
    public static SimpleAggCacheMgr internalInstance() {
        SimpleAggCacheMgr override = TEST_INSTANCE;
        return override != null ? override : getInstance();
    }

    /**
     * Override used only in unit tests to inject a mock implementation.
     *
     * <p>While the override is non-null, {@link #internalInstance()} returns it instead of the
     * real cache-backed instance. Passing {@code null} removes the override.
     *
     * @param instance the replacement manager, or {@code null} to remove the override
     */
    @VisibleForTesting
    public static void setTestInstance(SimpleAggCacheMgr instance) {
        TEST_INSTANCE = instance;
    }

    /**
     * Removes any test override so that later {@link #internalInstance()} calls
     * return the real cache-backed instance again.
     */
    @VisibleForTesting
    public static void clearTestInstance() {
        setTestInstance(null);
    }

    /**
     * Get the cached min/max for a column.
     *
     * <p>This method never waits for a load. Calling it starts an asynchronous load when no
     * entry exists for {@code key}; while that load is still running the result is empty.
     *
     * @param key table id and column name to look up
     * @param version version the caller needs; a cached value is served only when its
     *                stamp is greater than or equal to this
     * @return the cached min/max, or empty when the value is still loading, could not be
     *         loaded, or is older than {@code version}
     */
    public Optional<ColumnMinMax> getStats(ColumnMinMaxKey key, long version) {
        CompletableFuture<Optional<CacheValue>> future = cache.get(key);
        if (!future.isDone()) {
            // Load still in flight: report a miss rather than block the caller.
            return Optional.empty();
        }
        try {
            Optional<CacheValue> cacheValue = future.get();
            if (cacheValue.isPresent() && cacheValue.get().version() >= version) {
                return Optional.of(cacheValue.get().minMax());
            }
        } catch (Exception e) {
            LOG.warn("Failed to get MinMax for column: {}, version: {}", key, version, e);
        }
        // Either empty (load failed / version changed during load), stale, or failed:
        // evict so the next call triggers a fresh reload. The removal is conditional on the
        // exact future inspected above, so that a newer entry installed concurrently by
        // another caller (possibly still loading) is not thrown away.
        cache.asMap().remove(key, future);
        return Optional.empty();
    }

    /**
     * Evict the cached stats for a column, if present. Used when we know the data has changed.
     *
     * <p>Only the min/max cache is touched; the row-count cache is not affected.
     *
     * @param key table id and column name whose entry should be dropped
     */
    public void removeStats(ColumnMinMaxKey key) {
        cache.synchronous().invalidate(key);
    }

    /**
     * Get the cached row count for a table.
     *
     * <p>This method never waits for a load. Calling it starts an asynchronous load when no
     * entry exists for {@code tableId}; while that load is still running the result is empty.
     *
     * @param tableId ID of the table to look up
     * @param version version the caller needs; a cached value is served only when its
     *                stamp is greater than or equal to this
     * @return the cached row count, or empty when the value is still loading, could not be
     *         loaded, or is older than {@code version}
     */
    public OptionalLong getRowCount(long tableId, long version) {
        CompletableFuture<Optional<RowCountValue>> future = rowCountCache.get(tableId);
        if (!future.isDone()) {
            // Load still in flight: report a miss rather than block the caller.
            return OptionalLong.empty();
        }
        try {
            Optional<RowCountValue> cached = future.get();
            if (cached.isPresent() && cached.get().version() >= version) {
                return OptionalLong.of(cached.get().rowCount());
            }
        } catch (Exception e) {
            LOG.warn("Failed to get row count for table: {}, version: {}", tableId, version, e);
        }
        // Either empty (load failed / version changed during load), stale, or failed:
        // evict so the next call triggers a fresh reload. The removal is conditional on the
        // exact future inspected above, so that a newer entry installed concurrently by
        // another caller (possibly still loading) is not thrown away.
        rowCountCache.asMap().remove(tableId, future);
        return OptionalLong.empty();
    }

    /**
     * Builds the internal SQL statement that fetches the exact min and max of one column.
     *
     * <p>The column name is passed through {@code StatisticsUtil.escapeColumnName} and wrapped
     * in backticks. The three qualifiers are wrapped in backticks as given, without escaping.
     *
     * @param qualifiers {@code [catalogName, dbName, tableName]}; the first three elements are used
     * @param columnName name of the column to aggregate
     * @return {@code SELECT min(`col`), max(`col`) FROM `catalog`.`db`.`table`}
     */
    @VisibleForTesting
    public static String genMinMaxSql(List<String> qualifiers, String columnName) {
        String escapedColumn = StatisticsUtil.escapeColumnName(columnName);
        return String.format("SELECT min(`%s`), max(`%s`) FROM `%s`.`%s`.`%s`",
                escapedColumn, escapedColumn,
                qualifiers.get(0), qualifiers.get(1), qualifiers.get(2));
    }

    /**
     * Builds the internal SQL statement that fetches the exact row count of a table.
     *
     * <p>The three qualifiers are wrapped in backticks as given, without escaping.
     *
     * @param qualifiers {@code [catalogName, dbName, tableName]}; the first three elements are used
     * @return {@code SELECT count(*) FROM `catalog`.`db`.`table`}
     */
    @VisibleForTesting
    public static String genCountSql(List<String> qualifiers) {
        String dottedName = String.join("`.`",
                qualifiers.get(0), qualifiers.get(1), qualifiers.get(2));
        return "SELECT count(*) FROM `" + dottedName + "`";
    }

    /**
     * Async cache loader that computes exact min/max values by running an internal SQL query.
     *
     * <p>A load yields an empty {@link Optional} instead of failing: any exception raised
     * while loading is logged and converted to empty.
     */
    protected static final class CacheLoader
            implements AsyncCacheLoader<ColumnMinMaxKey, Optional<CacheValue>> {

        @Override
        public @NonNull CompletableFuture<Optional<CacheValue>> asyncLoad(
                @NonNull ColumnMinMaxKey key, @NonNull Executor executor) {
            return CompletableFuture.supplyAsync(() -> loadOrEmpty(key), executor);
        }

        /** Runs {@link #doLoad} and maps any exception to an empty result after logging it. */
        private Optional<CacheValue> loadOrEmpty(ColumnMinMaxKey key) {
            try {
                return doLoad(key);
            } catch (Exception e) {
                LOG.warn("Failed to load MinMax for column: {}", key, e);
                return Optional.empty();
            }
        }

        /**
         * Resolves the table and column, runs the min/max query and wraps the result.
         *
         * @return empty when the table is not an OlapTable, the column is missing or not
         *         eligible, the query yields no row or a null min/max, or the table's visible
         *         version changed while the query ran
         */
        private Optional<CacheValue> doLoad(ColumnMinMaxKey key) throws Exception {
            TableIf resolved = Env.getCurrentInternalCatalog().getTableByTableId(key.getTableId());
            if (!(resolved instanceof OlapTable)) {
                return Optional.empty();
            }
            OlapTable table = (OlapTable) resolved;

            // Eligible columns exist, are numeric or date typed, and are not aggregated.
            Column target = table.getColumn(key.getColumnName());
            boolean eligible = target != null
                    && (target.getType().isNumericType() || target.getType().isDateType())
                    && !target.isAggregated();
            if (!eligible) {
                return Optional.empty();
            }

            // Snapshot the table-level visible version before querying; it becomes the
            // stamp of the cached value and is re-checked once the query has finished.
            long versionBefore = table.getVisibleVersion();
            String sql = genMinMaxSql(table.getFullQualifiers(), target.getName());

            List<ResultRow> resultRows = runInternalQuery(sql);
            if (resultRows == null || resultRows.isEmpty()) {
                return Optional.empty();
            }
            ResultRow firstRow = resultRows.get(0);
            String min = firstRow.get(0);
            String max = firstRow.get(1);
            if (min == null || max == null) {
                return Optional.empty();
            }
            // A version bump during the query means the result may not match either version.
            if (table.getVisibleVersion() != versionBefore) {
                return Optional.empty();
            }
            return Optional.of(new CacheValue(new ColumnMinMax(min, max), versionBefore));
        }

        /** Executes {@code sql} as an internal query on a temporary connect context. */
        private List<ResultRow> runInternalQuery(String sql) throws Exception {
            try (AutoCloseConnectContext ctx = StatisticsUtil.buildConnectContext(false)) {
                ctx.connectContext.getSessionVariable().setPipelineTaskNum("1");
                // The internal SQL is planned by Nereids as well; switch off the rewrite rule
                // that consults this cache so the query cannot recurse back into it.
                ctx.connectContext.getSessionVariable().setDisableNereidsRules(
                        "REWRITE_SIMPLE_AGG_TO_CONSTANT");
                return new StmtExecutor(ctx.connectContext, sql).executeInternalQuery();
            }
        }
    }

    /**
     * Async cache loader that computes exact row counts by running
     * {@code SELECT count(*) FROM table} as an internal query.
     *
     * <p>A load yields an empty {@link Optional} instead of failing: any exception raised
     * while loading is logged and converted to empty.
     */
    protected static final class RowCountLoader
            implements AsyncCacheLoader<Long, Optional<RowCountValue>> {

        @Override
        public @NonNull CompletableFuture<Optional<RowCountValue>> asyncLoad(
                @NonNull Long tableId, @NonNull Executor executor) {
            return CompletableFuture.supplyAsync(() -> loadOrEmpty(tableId), executor);
        }

        /** Runs {@link #doLoad} and maps any exception to an empty result after logging it. */
        private Optional<RowCountValue> loadOrEmpty(Long tableId) {
            try {
                return doLoad(tableId);
            } catch (Exception e) {
                LOG.warn("Failed to load row count for table: {}", tableId, e);
                return Optional.empty();
            }
        }

        /**
         * Resolves the table, runs the count query and wraps the result.
         *
         * @return empty when the table is not an OlapTable, the query yields no row or a null
         *         count, or the table's visible version changed while the query ran
         */
        private Optional<RowCountValue> doLoad(Long tableId) throws Exception {
            TableIf resolved = Env.getCurrentInternalCatalog().getTableByTableId(tableId);
            if (!(resolved instanceof OlapTable)) {
                return Optional.empty();
            }
            OlapTable table = (OlapTable) resolved;

            // Snapshot the visible version before querying; it becomes the stamp of the
            // cached value and is re-checked once the query has finished.
            long versionBefore = table.getVisibleVersion();
            String sql = genCountSql(table.getFullQualifiers());

            List<ResultRow> resultRows = runInternalQuery(sql);
            if (resultRows == null || resultRows.isEmpty()) {
                return Optional.empty();
            }
            String rawCount = resultRows.get(0).get(0);
            if (rawCount == null) {
                return Optional.empty();
            }
            long parsedCount = Long.parseLong(rawCount);
            // A version bump during the query means the count may not match either version.
            if (table.getVisibleVersion() != versionBefore) {
                return Optional.empty();
            }
            return Optional.of(new RowCountValue(parsedCount, versionBefore));
        }

        /** Executes {@code sql} as an internal query on a temporary connect context. */
        private List<ResultRow> runInternalQuery(String sql) throws Exception {
            try (AutoCloseConnectContext ctx = StatisticsUtil.buildConnectContext(false)) {
                ctx.connectContext.getSessionVariable().setPipelineTaskNum("1");
                // The internal SQL is planned by Nereids as well; switch off the rewrite rule
                // that consults this cache so the query cannot recurse back into it.
                ctx.connectContext.getSessionVariable().setDisableNereidsRules(
                        "REWRITE_SIMPLE_AGG_TO_CONSTANT");
                return new StmtExecutor(ctx.connectContext, sql).executeInternalQuery();
            }
        }
    }
}