AuditLogHelper.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.qe;

import org.apache.doris.analysis.NativeInsertStmt;
import org.apache.doris.analysis.Queriable;
import org.apache.doris.analysis.QueryStmt;
import org.apache.doris.analysis.SelectStmt;
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.analysis.StmtType;
import org.apache.doris.analysis.ValueList;
import org.apache.doris.catalog.Env;
import org.apache.doris.cloud.qe.ComputeGroupException;
import org.apache.doris.cluster.ClusterNamespace;
import org.apache.doris.common.Config;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.datasource.CatalogIf;
import org.apache.doris.datasource.InternalCatalog;
import org.apache.doris.metric.MetricRepo;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.nereids.analyzer.UnboundOneRowRelation;
import org.apache.doris.nereids.analyzer.UnboundTableSink;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.InlineTable;
import org.apache.doris.nereids.trees.plans.commands.NeedAuditEncryption;
import org.apache.doris.nereids.trees.plans.commands.insert.InsertIntoTableCommand;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.plugin.AuditEvent;
import org.apache.doris.plugin.AuditEvent.AuditEventBuilder;
import org.apache.doris.plugin.AuditEvent.EventType;
import org.apache.doris.qe.QueryState.MysqlStateType;
import org.apache.doris.service.FrontendOptions;

import com.google.common.base.Strings;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Optional;

public class AuditLogHelper {

    private static final Logger LOG = LogManager.getLogger(AuditLogHelper.class);

    /**
     * Best-effort entry point for audit logging.
     * <p>
     * Delegates to {@code logAuditLogImpl} and makes sure nothing it throws reaches the caller.
     * Writing the audit log may write to a doris internal table, which can fail; such a failure
     * must not affect the query process, so it is only reported as a warning.
     *
     * @param ctx connection context, passed through unchanged
     * @param origStmt original statement text, passed through unchanged
     * @param parsedStmt parsed statement, passed through unchanged
     * @param statistics query statistics, passed through unchanged
     * @param printFuzzyVariables passed through unchanged
     */
    public static void logAuditLog(ConnectContext ctx, String origStmt, StatementBase parsedStmt,
            org.apache.doris.proto.Data.PQueryStatistics statistics, boolean printFuzzyVariables) {
        try {
            logAuditLogImpl(ctx, origStmt, parsedStmt, statistics, printFuzzyVariables);
        } catch (Throwable auditFailure) {
            // Intentionally broad: audit logging is auxiliary and must never fail the query.
            LOG.warn("Failed to write audit log.", auditFailure);
        }
    }

    /**
     * Prepares a statement text for the audit log.
     * <p>
     * A statement recognised by {@code handleInsertStmt} as a multi-row insert, e.g.
     * {@code insert into tbl values (1), (2), (3)}, is shortened there and annotated with its
     * total row count, giving something like {@code insert into tbl values (1), (2 ...}.
     * Any other statement is truncated to {@code GlobalVariable.auditPluginMaxSqlLength} bytes
     * with a trailing marker, and its newline, tab and carriage-return characters are replaced
     * by the two-character sequences {@code \n}, {@code \t} and {@code \r}.
     *
     * @param origStmt statement text, may be null
     * @param parsedStmt parsed form of the statement, only used to detect inserts
     * @return the processed text, or null when {@code origStmt} is null
     */
    public static String handleStmt(String origStmt, StatementBase parsedStmt) {
        if (origStmt == null) {
            return null;
        }
        // Inserts get their own treatment; everything else falls through to the generic path.
        return handleInsertStmt(origStmt, parsedStmt).orElseGet(() -> {
            int sqlLimit = GlobalVariable.auditPluginMaxSqlLength;
            String marker = " ... /* truncated. audit_plugin_max_sql_length=" + sqlLimit + " */";
            String shortened = truncateByBytes(origStmt, sqlLimit, marker);
            return shortened.replace("\n", "\\n")
                    .replace("\t", "\\t")
                    .replace("\r", "\\r");
        });
    }

    /**
     * Handles the audit text of an insert that carries literal value rows.
     * <p>
     * The row count is taken from the value list of a {@code NativeInsertStmt} (old planner) or
     * from {@code countValues} over the children of an {@code UnboundTableSink} (nereids
     * planner). When at least one row is found, the text is truncated to the smaller of
     * {@code auditPluginMaxInsertStmtLength} and {@code auditPluginMaxSqlLength} (never below
     * zero), a marker with the row count is appended, and newline, tab and carriage-return
     * characters are escaped.
     *
     * @param origStmt statement text to shorten
     * @param parsedStmt parsed statement used to find the value rows
     * @return the processed text, or an empty Optional when no value rows were found
     */
    private static Optional<String> handleInsertStmt(String origStmt, StatementBase parsedStmt) {
        int valueRows = 0;
        // old planner: rows come from the value list of the insert's select statement
        if (parsedStmt instanceof NativeInsertStmt) {
            QueryStmt source = ((NativeInsertStmt) parsedStmt).getQueryStmt();
            if (source instanceof SelectStmt) {
                ValueList values = ((SelectStmt) source).getValueList();
                if (values != null && values.getRows() != null) {
                    valueRows = values.getRows().size();
                }
            }
        }
        // nereids planner: rows are counted from the plans below the unbound table sink
        if (parsedStmt instanceof LogicalPlanAdapter) {
            LogicalPlan command = ((LogicalPlanAdapter) parsedStmt).getLogicalPlan();
            if (command instanceof InsertIntoTableCommand) {
                LogicalPlan sink = ((InsertIntoTableCommand) command).getLogicalQuery();
                if (sink instanceof UnboundTableSink) {
                    valueRows = countValues(sink.children());
                }
            }
        }
        if (valueRows <= 0) {
            // Not an insert with literal rows; the caller applies the generic handling.
            return Optional.empty();
        }

        int insertLimit = Math.max(0,
                Math.min(GlobalVariable.auditPluginMaxInsertStmtLength, GlobalVariable.auditPluginMaxSqlLength));
        String marker = " ... /* total " + valueRows
                + " rows, truncated. audit_plugin_max_insert_stmt_length=" + insertLimit + " */";
        String shortened = truncateByBytes(origStmt, insertLimit, marker);
        return Optional.of(shortened.replace("\n", "\\n")
                .replace("\t", "\\t")
                .replace("\r", "\\r"));
    }

    /**
     * Shortens {@code str} so that its UTF-8 encoding takes at most {@code maxLen} bytes and
     * appends {@code suffix} to mark the cut.
     * <p>
     * The text is always measured and cut in UTF-8, independent of the JVM default charset, so
     * the bytes handed to the UTF-8 decoder are guaranteed to be UTF-8. A multi-byte character
     * that straddles the cut position is dropped as a whole rather than emitted half-decoded.
     *
     * @param str text to shorten, must not be null
     * @param maxLen byte budget for the kept prefix; a negative value is treated as 0
     * @param suffix marker appended only when the text was actually truncated; it does not count
     *         against {@code maxLen}
     * @return {@code str} itself when it fits into the budget, otherwise the kept prefix followed
     *         by {@code suffix}
     */
    private static String truncateByBytes(String str, int maxLen, String suffix) {
        Charset utf8 = StandardCharsets.UTF_8;
        // Encode once, explicitly in UTF-8: this is the real byte length we compare against.
        byte[] encoded = str.getBytes(utf8);
        // A negative limit would make ByteBuffer.wrap throw; treat it as "keep nothing".
        int limit = Math.max(0, maxLen);
        if (limit >= encoded.length) {
            return str;
        }
        // IGNORE makes the decoder silently skip the partial character left at the cut position.
        CharsetDecoder decoder = utf8.newDecoder().onMalformedInput(CodingErrorAction.IGNORE);
        ByteBuffer input = ByteBuffer.wrap(encoded, 0, limit);
        // Decoding never yields more chars than input bytes, so `limit` chars is always enough.
        CharBuffer output = CharBuffer.allocate(limit);
        decoder.decode(input, output, true);
        decoder.flush(output);
        output.flip();
        return output.toString() + suffix;
    }

    /**
     * Counts the literal rows of an insert such as
     * {@code insert into tbl values (1), (2), (3)}.
     * <p>
     * Each {@code UnboundOneRowRelation} counts as one row, an {@code InlineTable} contributes
     * the size of its constant expression list, and a {@code LogicalUnion} contributes the rows
     * found recursively among its own children. Any other plan contributes nothing.
     *
     * @param children plans to inspect, may be null
     * @return the total number of rows found, 0 when {@code children} is null
     */
    private static int countValues(List<Plan> children) {
        if (children == null) {
            return 0;
        }
        return children.stream().mapToInt(node -> {
            if (node instanceof UnboundOneRowRelation) {
                return 1;
            }
            if (node instanceof InlineTable) {
                return ((InlineTable) node).getConstantExprsList().size();
            }
            if (node instanceof LogicalUnion) {
                return countValues(node.children());
            }
            return 0;
        }).sum();
    }

    /**
     * Builds the audit event for a finished statement and submits it.
     * <p>
     * The event is assembled on the builder owned by {@code ctx} from the connection state and
     * the given statistics. When the statement is a query and {@code MetricRepo} is initialised,
     * query counters and latency metrics are updated as well. The finished event is handed to
     * {@code WorkloadRuntimeStatusMgr.submitFinishQueryToAudit}.
     *
     * @param ctx connection context the statement ran in
     * @param origStmt original statement text
     * @param parsedStmt parsed statement; checked for null in some places but not all (see the
     *         review note on the slow-query branch)
     * @param statistics query statistics, may be null, in which case all statistic fields are 0
     * @param printFuzzyVariables when false the fuzzy-variables field is set to an empty string,
     *         otherwise to {@code ctx.getSessionVariable().printFuzzyVariables()}
     */
    private static void logAuditLogImpl(ConnectContext ctx, String origStmt, StatementBase parsedStmt,
            org.apache.doris.proto.Data.PQueryStatistics statistics, boolean printFuzzyVariables) {
        // Elapsed time from the start recorded on the context until now; also used below for
        // the latency metrics and the slow query check.
        long endTime = System.currentTimeMillis();
        long elapseMs = endTime - ctx.getStartTime();
        CatalogIf catalog = ctx.getCurrentCatalog();
        // The cloud cluster is only looked up in cloud mode. A failed lookup is logged and
        // leaves the name empty, which is reported as "UNKNOWN" in the event below.
        String cloudCluster = "";
        try {
            if (Config.isCloudMode()) {
                cloudCluster = ctx.getCloudCluster(false);
            }
        } catch (ComputeGroupException e) {
            LOG.warn("Failed to get cloud cluster", e);
        }
        String cluster = Config.isCloudMode() ? cloudCluster : "";

        AuditEventBuilder auditEventBuilder = ctx.getAuditEventBuilder();
        // ATTN: MUST reset, otherwise, the same AuditEventBuilder instance will be used in the next query.
        auditEventBuilder.reset();
        // Fields available for every statement. Missing values fall back to defaults: internal
        // catalog name, error code 0, empty error message, zeroed statistics, query id "NaN".
        auditEventBuilder
                .setTimestamp(ctx.getStartTime())
                .setClientIp(ctx.getClientIP())
                .setUser(ClusterNamespace.getNameFromFullName(ctx.getQualifiedUser()))
                .setSqlHash(ctx.getSqlHash())
                .setEventType(EventType.AFTER_QUERY)
                .setCtl(catalog == null ? InternalCatalog.INTERNAL_CATALOG_NAME : catalog.getName())
                .setDb(ClusterNamespace.getNameFromFullName(ctx.getDatabase()))
                .setState(ctx.getState().toString())
                .setErrorCode(ctx.getState().getErrorCode() == null ? 0 : ctx.getState().getErrorCode().getCode())
                .setErrorMessage((ctx.getState().getErrorMessage() == null ? "" :
                        ctx.getState().getErrorMessage().replace("\n", " ").replace("\t", " ")))
                .setQueryTime(elapseMs)
                .setScanBytes(statistics == null ? 0 : statistics.getScanBytes())
                .setScanRows(statistics == null ? 0 : statistics.getScanRows())
                .setSpillWriteBytesToLocalStorage(statistics == null ? 0 :
                        statistics.getSpillWriteBytesToLocalStorage())
                .setSpillReadBytesFromLocalStorage(statistics == null ? 0 :
                        statistics.getSpillReadBytesFromLocalStorage())
                .setCpuTimeMs(statistics == null ? 0 : statistics.getCpuMs())
                .setPeakMemoryBytes(statistics == null ? 0 : statistics.getMaxPeakMemoryBytes())
                .setReturnRows(ctx.getReturnRows())
                .setStmtId(ctx.getStmtId())
                .setQueryId(ctx.queryId() == null ? "NaN" : DebugUtil.printId(ctx.queryId()))
                .setCloudCluster(Strings.isNullOrEmpty(cluster) ? "UNKNOWN" : cluster)
                .setWorkloadGroup(ctx.getWorkloadGroupName())
                .setFuzzyVariables(!printFuzzyVariables ? "" : ctx.getSessionVariable().printFuzzyVariables())
                .setCommandType(ctx.getCommand().toString());

        if (ctx.getState().isQuery()) {
            // Metrics are only touched once MetricRepo is initialised. Global and per-user
            // counters skip internal sessions; the per-cluster "query all" counter does not.
            if (MetricRepo.isInit) {
                if (!ctx.getSessionVariable().internalSession) {
                    MetricRepo.COUNTER_QUERY_ALL.increase(1L);
                    MetricRepo.USER_COUNTER_QUERY_ALL.getOrAdd(ctx.getQualifiedUser()).increase(1L);
                }
                // NOTE(review): this repeats the cluster lookup done at the top of the method,
                // but here a failure returns before the event is built and submitted, so no
                // audit record is emitted for this statement. Confirm this is intended.
                try {
                    if (Config.isCloudMode()) {
                        cloudCluster = ctx.getCloudCluster(false);
                    }
                } catch (ComputeGroupException e) {
                    LOG.warn("Failed to get cloud cluster", e);
                    return;
                }
                MetricRepo.increaseClusterQueryAll(cloudCluster);
                if (ctx.getState().getStateType() == MysqlStateType.ERR
                        && ctx.getState().getErrType() != QueryState.ErrType.ANALYSIS_ERR) {
                    // err query: analysis errors are excluded from the error counters
                    if (!ctx.getSessionVariable().internalSession) {
                        MetricRepo.COUNTER_QUERY_ERR.increase(1L);
                        MetricRepo.USER_COUNTER_QUERY_ERR.getOrAdd(ctx.getQualifiedUser()).increase(1L);
                        MetricRepo.increaseClusterQueryErr(cloudCluster);
                    }
                } else if (ctx.getState().getStateType() == MysqlStateType.OK
                        || ctx.getState().getStateType() == MysqlStateType.EOF) {
                    // ok query: record latency for non-internal sessions
                    if (!ctx.getSessionVariable().internalSession) {
                        MetricRepo.HISTO_QUERY_LATENCY.update(elapseMs);
                        MetricRepo.USER_HISTO_QUERY_LATENCY.getOrAdd(ctx.getQualifiedUser()).update(elapseMs);
                        MetricRepo.updateClusterQueryLatency(cloudCluster, elapseMs);
                    }

                    // Slow query: attach an MD5 digest of the statement and bump the slow counter.
                    // Because this sits inside the MetricRepo.isInit branch, the digest is only
                    // set when metrics are initialised.
                    // NOTE(review): the cast is unchecked. A null parsedStmt or one that is not
                    // a Queriable makes this line throw, and the exception leaves this method
                    // before the event is submitted. Confirm callers guarantee a Queriable here.
                    if (elapseMs > Config.qe_slow_log_ms) {
                        String sqlDigest = DigestUtils.md5Hex(((Queriable) parsedStmt).toDigest());
                        auditEventBuilder.setSqlDigest(sqlDigest);
                        MetricRepo.COUNTER_QUERY_SLOW.increase(1L);
                    }
                }
            }
            auditEventBuilder.setIsQuery(true)
                    .setScanBytesFromLocalStorage(
                            statistics == null ? 0 : statistics.getScanBytesFromLocalStorage())
                    .setScanBytesFromRemoteStorage(
                            statistics == null ? 0 : statistics.getScanBytesFromRemoteStorage());
        } else {
            auditEventBuilder.setIsQuery(false);
        }
        auditEventBuilder.setIsNereids(ctx.getState().isNereids);

        auditEventBuilder.setFeIp(FrontendOptions.getLocalHostAddress());

        // Choose the text recorded as the statement. For an analysis error the error message
        // is recorded in place of the original statement text.
        boolean isAnalysisErr = ctx.getState().getStateType() == MysqlStateType.ERR
                && ctx.getState().getErrType() == QueryState.ErrType.ANALYSIS_ERR;
        String encryptSql = isAnalysisErr ? ctx.getState().getErrorMessage() : origStmt;
        // We put origin query stmt at the end of audit log, for parsing the log more convenient.
        // A plan implementing NeedAuditEncryption supplies its own text derived from origStmt;
        // this replaces the choice above, including the analysis error case. For other
        // statements, a non-query statement that asks for audit encryption is logged via toSql().
        if (parsedStmt instanceof LogicalPlanAdapter) {
            LogicalPlan logicalPlan = ((LogicalPlanAdapter) parsedStmt).getLogicalPlan();
            if ((logicalPlan instanceof NeedAuditEncryption)) {
                encryptSql = ((NeedAuditEncryption) logicalPlan).geneEncryptionSQL(origStmt);
            }
        } else {
            if (!ctx.getState().isQuery() && (parsedStmt != null && parsedStmt.needAuditEncryption())) {
                encryptSql = parsedStmt.toSql();
            }
        }
        auditEventBuilder.setStmt(handleStmt(encryptSql, parsedStmt));
        auditEventBuilder.setStmtType(getStmtType(parsedStmt));

        // On a non-master frontend that forwarded the statement to the master, the state comes
        // from the proxy result, and a non-zero proxy status code also replaces the error code
        // and error message.
        if (!Env.getCurrentEnv().isMaster()) {
            if (ctx.executor != null && ctx.executor.isForwardToMaster()) {
                auditEventBuilder.setState(ctx.executor.getProxyStatus());
                int proxyStatusCode = ctx.executor.getProxyStatusCode();
                if (proxyStatusCode != 0) {
                    auditEventBuilder.setErrorCode(proxyStatusCode);
                    auditEventBuilder.setErrorMessage(ctx.executor.getProxyErrMsg());
                }
            }
        }
        // A COM_STMT_PREPARE command without an error code is always recorded with state OK.
        if (ctx.getCommand() == MysqlCommand.COM_STMT_PREPARE && ctx.getState().getErrorCode() == null) {
            auditEventBuilder.setState(String.valueOf(MysqlStateType.OK));
        }
        AuditEvent event = auditEventBuilder.build();
        Env.getCurrentEnv().getWorkloadRuntimeStatusMgr().submitFinishQueryToAudit(event);
        if (LOG.isDebugEnabled()) {
            LOG.debug("submit audit event: {}", event.queryId);
        }
    }

    /**
     * Resolves the statement type name recorded in the audit event.
     * <p>
     * A null statement is reported as {@code OTHER} and an explain statement as {@code EXPLAIN}.
     * For a {@code LogicalPlanAdapter} the type is taken from the wrapped logical plan,
     * otherwise from the statement itself.
     *
     * @param stmt parsed statement, may be null
     * @return the name of the resolved statement type
     */
    private static String getStmtType(StatementBase stmt) {
        String typeName;
        if (stmt == null) {
            typeName = StmtType.OTHER.name();
        } else if (stmt.isExplain()) {
            typeName = StmtType.EXPLAIN.name();
        } else if (stmt instanceof LogicalPlanAdapter) {
            typeName = ((LogicalPlanAdapter) stmt).getLogicalPlan().stmtType().name();
        } else {
            typeName = stmt.stmtType().name();
        }
        return typeName;
    }
}