SqlBlockRuleMgr.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.blockrule;

import org.apache.doris.analysis.AlterSqlBlockRuleStmt;
import org.apache.doris.analysis.CreateSqlBlockRuleStmt;
import org.apache.doris.analysis.DropSqlBlockRuleStmt;
import org.apache.doris.analysis.ShowSqlBlockRuleStmt;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.Config;
import org.apache.doris.common.DdlException;
import org.apache.doris.common.UserException;
import org.apache.doris.common.io.Text;
import org.apache.doris.common.io.Writable;
import org.apache.doris.common.util.SqlBlockUtil;
import org.apache.doris.metric.MetricRepo;
import org.apache.doris.mysql.privilege.Auth;
import org.apache.doris.persist.gson.GsonUtils;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.gson.annotations.SerializedName;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.stream.Collectors;

/**
 * Manages the {@link SqlBlockRule}s known to this frontend.
 *
 * <p>Rules are stored in a concurrent map keyed by rule name. The create/alter/drop entry points
 * hold the write lock while they validate the request, mutate the map and write the edit log.
 * Lookups, matching and the replay methods access the concurrent map without taking the lock.
 */
public class SqlBlockRuleMgr implements Writable {
    private static final Logger LOG = LogManager.getLogger(SqlBlockRuleMgr.class);

    // Fair lock; only its write side is used, to serialize the create/alter/drop paths.
    private ReentrantReadWriteLock lock = new ReentrantReadWriteLock(true);

    // Rule name -> rule. This is the only state written out by write(DataOutput).
    @SerializedName(value = "nameToSqlBlockRuleMap")
    private Map<String, SqlBlockRule> nameToSqlBlockRuleMap = Maps.newConcurrentMap();

    private void writeLock() {
        lock.writeLock().lock();
    }

    private void writeUnlock() {
        lock.writeLock().unlock();
    }

    /**
     * Tells whether a rule with the given name is registered.
     *
     * @param name rule name to look up
     * @return {@code true} if the map contains a rule under {@code name}
     */
    public boolean existRule(String name) {
        return nameToSqlBlockRuleMap.containsKey(name);
    }

    /**
     * Returns the rules selected by a SHOW statement.
     *
     * @param stmt statement whose rule name is used as the filter
     * @return see {@link #getSqlBlockRule(String)}
     */
    public List<SqlBlockRule> getSqlBlockRule(ShowSqlBlockRuleStmt stmt) throws AnalysisException {
        return getSqlBlockRule(stmt.getRuleName());
    }

    /**
     * Returns rules by name.
     *
     * @param ruleName rule to fetch; when null or empty, every rule is returned
     * @return a new list holding the named rule (empty if it does not exist), or all rules
     */
    public List<SqlBlockRule> getSqlBlockRule(String ruleName) throws AnalysisException {
        if (StringUtils.isEmpty(ruleName)) {
            return Lists.newArrayList(nameToSqlBlockRuleMap.values());
        }
        // Single lookup: a containsKey()/get() pair could observe a concurrent drop in between
        // and hand back a list containing null.
        SqlBlockRule found = nameToSqlBlockRuleMap.get(ruleName);
        List<SqlBlockRule> result = Lists.newArrayList();
        if (found != null) {
            result.add(found);
        }
        return result;
    }

    /**
     * Rejects rules whose numeric limits are negative.
     *
     * @param sqlBlockRule rule to validate
     * @throws DdlException if partition_num, tablet_num or cardinality is below zero
     */
    private static void verifyLimitations(SqlBlockRule sqlBlockRule) throws DdlException {
        if (sqlBlockRule.getPartitionNum() < 0) {
            throw new DdlException("the value of partition_num can't be a negative");
        }
        if (sqlBlockRule.getTabletNum() < 0) {
            throw new DdlException("the value of tablet_num can't be a negative");
        }
        if (sqlBlockRule.getCardinality() < 0) {
            throw new DdlException("the value of cardinality can't be a negative");
        }
    }

    /**
     * Creates a rule from a CREATE statement.
     *
     * @param stmt the create statement
     * @throws UserException if the rule already exists (and IF NOT EXISTS was not given) or is invalid
     */
    public void createSqlBlockRule(CreateSqlBlockRuleStmt stmt) throws UserException {
        createSqlBlockRule(SqlBlockRule.fromCreateStmt(stmt), stmt.isIfNotExists());
    }

    /**
     * Registers a new rule and writes it to the edit log, all under the write lock.
     *
     * @param sqlBlockRule rule to add
     * @param isIfNotExists when {@code true}, an already existing rule of the same name is
     *                      left untouched and the call returns silently
     * @throws UserException if the name is taken and {@code isIfNotExists} is {@code false},
     *                       or if a numeric limit is negative
     */
    public void createSqlBlockRule(SqlBlockRule sqlBlockRule, boolean isIfNotExists) throws UserException {
        writeLock();
        try {
            String ruleName = sqlBlockRule.getName();
            if (existRule(ruleName)) {
                if (isIfNotExists) {
                    return;
                }
                throw new DdlException("the sql block rule " + ruleName + " already create");
            }
            verifyLimitations(sqlBlockRule);
            unprotectedAdd(sqlBlockRule);
            Env.getCurrentEnv().getEditLog().logCreateSqlBlockRule(sqlBlockRule);
        } finally {
            writeUnlock();
        }
    }

    /**
     * Applies a create operation read back from the edit log; no validation, no logging to the edit log.
     *
     * @param sqlBlockRule rule to put into the map
     */
    public void replayCreate(SqlBlockRule sqlBlockRule) {
        unprotectedAdd(sqlBlockRule);
        LOG.info("replay create sql block rule: {}", sqlBlockRule);
    }

    /**
     * Alters a rule from an ALTER statement.
     *
     * @param stmt the alter statement
     * @throws DdlException if the rule does not exist or the merged rule is invalid
     */
    public void alterSqlBlockRule(AlterSqlBlockRuleStmt stmt) throws AnalysisException, DdlException {
        alterSqlBlockRule(SqlBlockRule.fromAlterStmt(stmt));
    }

    /**
     * Replaces an existing rule, under the write lock.
     *
     * <p>Every property of {@code sqlBlockRule} that still carries its "not set" marker
     * (or is {@code null} for the global/enable flags) is first filled in from the currently
     * stored rule, so only the explicitly given properties change. The merged rule is then
     * validated, stored and written to the edit log.
     *
     * @param sqlBlockRule rule carrying the new property values; it is mutated in place
     * @throws DdlException if no rule with that name exists or a numeric limit is negative
     * @throws AnalysisException if {@link SqlBlockUtil#checkAlterValidate} rejects the merged rule
     */
    public void alterSqlBlockRule(SqlBlockRule sqlBlockRule) throws AnalysisException, DdlException {
        writeLock();
        try {
            String ruleName = sqlBlockRule.getName();
            SqlBlockRule originRule = nameToSqlBlockRuleMap.get(ruleName);
            if (originRule == null) {
                throw new DdlException("the sql block rule " + ruleName + " not exist");
            }

            // Inherit everything the ALTER statement left unspecified.
            if (sqlBlockRule.getSql().equals(CreateSqlBlockRuleStmt.STRING_NOT_SET)) {
                sqlBlockRule.setSql(originRule.getSql());
            }
            if (sqlBlockRule.getSqlHash().equals(CreateSqlBlockRuleStmt.STRING_NOT_SET)) {
                sqlBlockRule.setSqlHash(originRule.getSqlHash());
            }
            if (sqlBlockRule.getPartitionNum().equals(AlterSqlBlockRuleStmt.LONG_NOT_SET)) {
                sqlBlockRule.setPartitionNum(originRule.getPartitionNum());
            }
            if (sqlBlockRule.getTabletNum().equals(AlterSqlBlockRuleStmt.LONG_NOT_SET)) {
                sqlBlockRule.setTabletNum(originRule.getTabletNum());
            }
            if (sqlBlockRule.getCardinality().equals(AlterSqlBlockRuleStmt.LONG_NOT_SET)) {
                sqlBlockRule.setCardinality(originRule.getCardinality());
            }
            if (sqlBlockRule.getGlobal() == null) {
                sqlBlockRule.setGlobal(originRule.getGlobal());
            }
            if (sqlBlockRule.getEnable() == null) {
                sqlBlockRule.setEnable(originRule.getEnable());
            }

            verifyLimitations(sqlBlockRule);
            SqlBlockUtil.checkAlterValidate(sqlBlockRule);

            unprotectedUpdate(sqlBlockRule);
            Env.getCurrentEnv().getEditLog().logAlterSqlBlockRule(sqlBlockRule);
        } finally {
            writeUnlock();
        }
    }

    /**
     * Applies an alter operation read back from the edit log.
     *
     * @param sqlBlockRule the already merged rule to store
     */
    public void replayAlter(SqlBlockRule sqlBlockRule) {
        unprotectedUpdate(sqlBlockRule);
        LOG.info("replay alter sql block rule: {}", sqlBlockRule);
    }

    // Stores the rule under its name, replacing any previous entry. Takes no lock itself.
    private void unprotectedUpdate(SqlBlockRule sqlBlockRule) {
        nameToSqlBlockRuleMap.put(sqlBlockRule.getName(), sqlBlockRule);
    }

    // Stores the rule under its name. Takes no lock itself.
    private void unprotectedAdd(SqlBlockRule sqlBlockRule) {
        nameToSqlBlockRuleMap.put(sqlBlockRule.getName(), sqlBlockRule);
    }

    /**
     * Drops the rules named by a DROP statement.
     *
     * @param stmt the drop statement
     * @throws DdlException if a rule is missing and IF EXISTS was not given
     */
    public void dropSqlBlockRule(DropSqlBlockRuleStmt stmt) throws DdlException {
        dropSqlBlockRule(stmt.getRuleNames(), stmt.isIfExists());
    }

    /**
     * Removes the given rules and writes the drop to the edit log, under the write lock.
     *
     * <p>All names are checked before anything is removed, so a missing rule (without
     * {@code isIfExists}) aborts the whole request. The complete name list, including names
     * skipped because of {@code isIfExists}, is what gets removed and logged; removing an
     * absent name is a no-op on the map.
     *
     * @param ruleNames names of the rules to drop
     * @param isIfExists when {@code true}, names without a matching rule are tolerated
     * @throws DdlException if a name has no rule and {@code isIfExists} is {@code false}
     */
    public void dropSqlBlockRule(List<String> ruleNames, boolean isIfExists) throws DdlException {
        writeLock();
        try {
            for (String ruleName : ruleNames) {
                if (!existRule(ruleName)) {
                    if (isIfExists) {
                        continue;
                    }
                    throw new DdlException("the sql block rule " + ruleName + " not exist");
                }
            }
            unprotectedDrop(ruleNames);
            Env.getCurrentEnv().getEditLog().logDropSqlBlockRule(ruleNames);
        } finally {
            writeUnlock();
        }
    }

    /**
     * Applies a drop operation read back from the edit log.
     *
     * @param ruleNames names of the rules to remove
     */
    public void replayDrop(List<String> ruleNames) {
        unprotectedDrop(ruleNames);
        LOG.info("replay drop sql block ruleNames: {}", ruleNames);
    }

    /**
     * Removes the named rules from the map without locking or edit logging.
     *
     * @param ruleNames names to remove; names with no rule are ignored
     */
    public void unprotectedDrop(List<String> ruleNames) {
        ruleNames.forEach(name -> nameToSqlBlockRuleMap.remove(name));
    }

    /**
     * Checks a statement against the global rules and the rules bound to {@code user}.
     *
     * <p>The check is skipped for the root/admin user when
     * {@code Config.sql_block_rule_ignore_admin} is set, and for internal sessions.
     *
     * @param originSql statement text matched against rule regexes
     * @param sqlHash hash compared with rule hashes
     * @param user user whose bound rules are consulted
     * @throws AnalysisException if an enabled rule matches the statement
     */
    public void matchSql(String originSql, String sqlHash, String user) throws AnalysisException {
        if (Config.sql_block_rule_ignore_admin && (Auth.ROOT_USER.equals(user) || Auth.ADMIN_USER.equals(user))) {
            return;
        }
        ConnectContext context = ConnectContext.get();
        if (context != null && context.getSessionVariable().internalSession) {
            return;
        }
        // match global rule
        List<SqlBlockRule> globalRules =
                nameToSqlBlockRuleMap.values().stream().filter(SqlBlockRule::getGlobal).collect(Collectors.toList());
        for (SqlBlockRule rule : globalRules) {
            matchSql(rule, originSql, sqlHash);
        }
        // match user rule
        String[] bindSqlBlockRules = Env.getCurrentEnv().getAuth().getSqlBlockRules(user);
        for (String ruleName : bindSqlBlockRules) {
            SqlBlockRule rule = nameToSqlBlockRuleMap.get(ruleName);
            if (rule == null) {
                // The user is bound to a rule name that is not (or no longer) registered.
                continue;
            }
            matchSql(rule, originSql, sqlHash);
        }
    }

    /**
     * Checks a statement against one rule. The hash is tried first, then the regex.
     *
     * @throws AnalysisException if the rule is enabled and its hash equals {@code sqlHash}
     *                           or its pattern is found in {@code originSql}
     */
    private void matchSql(SqlBlockRule rule, String originSql, String sqlHash) throws AnalysisException {
        if (!rule.getEnable()) {
            return;
        }
        boolean hashConfigured = StringUtils.isNotEmpty(rule.getSqlHash())
                && !SqlBlockUtil.STRING_DEFAULT.equals(rule.getSqlHash());
        if (hashConfigured && rule.getSqlHash().equals(sqlHash)) {
            MetricRepo.COUNTER_HIT_SQL_BLOCK_RULE.increase(1L);
            throw new AnalysisException("sql match hash sql block rule: " + rule.getName());
        }
        boolean regexConfigured = StringUtils.isNotEmpty(rule.getSql())
                && !SqlBlockUtil.STRING_DEFAULT.equals(rule.getSql());
        if (regexConfigured && rule.getSqlPattern() != null && rule.getSqlPattern().matcher(originSql).find()) {
            MetricRepo.COUNTER_HIT_SQL_BLOCK_RULE.increase(1L);
            throw new AnalysisException("sql match regex sql block rule: " + rule.getName());
        }
    }

    /**
     * Checks scan sizes against the global rules and the rules bound to {@code user}.
     *
     * <p>The check is skipped for internal sessions. A missing connect context is tolerated,
     * the same way {@link #matchSql(String, String, String)} tolerates it.
     *
     * @param partitionNum number of partitions the statement touches
     * @param tabletNum number of tablets the statement touches
     * @param cardinality cardinality of the statement
     * @param user user whose bound rules are consulted
     * @throws AnalysisException if an enabled rule's limit is exceeded
     */
    public void checkLimitations(Long partitionNum, Long tabletNum, Long cardinality, String user)
            throws AnalysisException {
        ConnectContext context = ConnectContext.get();
        if (context != null && context.getSessionVariable().internalSession) {
            return;
        }
        // match global rule
        for (SqlBlockRule rule : nameToSqlBlockRuleMap.values()) {
            if (rule.getGlobal()) {
                checkLimitations(rule, partitionNum, tabletNum, cardinality);
            }
        }
        // match user rule
        String[] bindSqlBlockRules = Env.getCurrentEnv().getAuth().getSqlBlockRules(user);
        for (String ruleName : bindSqlBlockRules) {
            SqlBlockRule rule = nameToSqlBlockRuleMap.get(ruleName);
            if (rule == null) {
                // The user is bound to a rule name that is not (or no longer) registered.
                continue;
            }
            checkLimitations(rule, partitionNum, tabletNum, cardinality);
        }
    }

    /**
     * Checks scan sizes against one rule. A limit of 0 is treated as "not set"; limits are
     * tested in the order partition_num, tablet_num, cardinality and the first one exceeded
     * is reported.
     *
     * @throws AnalysisException if the rule is enabled and one of its limits is exceeded
     */
    private void checkLimitations(SqlBlockRule rule, Long partitionNum, Long tabletNum, Long cardinality)
            throws AnalysisException {
        if (rule.getPartitionNum() == 0 && rule.getTabletNum() == 0 && rule.getCardinality() == 0) {
            // The rule carries no numeric limit at all.
            return;
        }
        if (!rule.getEnable()) {
            return;
        }
        if (exceedsLimit(rule.getPartitionNum(), partitionNum)) {
            MetricRepo.COUNTER_HIT_SQL_BLOCK_RULE.increase(1L);
            throw new AnalysisException(
                    "sql hits sql block rule: " + rule.getName() + ", reach partition_num : "
                            + rule.getPartitionNum());
        }
        if (exceedsLimit(rule.getTabletNum(), tabletNum)) {
            MetricRepo.COUNTER_HIT_SQL_BLOCK_RULE.increase(1L);
            throw new AnalysisException("sql hits sql block rule: " + rule.getName() + ", reach tablet_num : "
                    + rule.getTabletNum());
        }
        if (exceedsLimit(rule.getCardinality(), cardinality)) {
            MetricRepo.COUNTER_HIT_SQL_BLOCK_RULE.increase(1L);
            throw new AnalysisException("sql hits sql block rule: " + rule.getName() + ", reach cardinality : "
                    + rule.getCardinality());
        }
    }

    // True when the limit is set (non-zero) and the actual value is strictly greater than it.
    private static boolean exceedsLimit(Long limit, Long actual) {
        return limit != 0 && limit < actual;
    }

    /**
     * Writes this manager as a single JSON string produced by {@code GsonUtils.GSON}.
     *
     * @param out destination
     * @throws IOException if writing fails
     */
    @Override
    public void write(DataOutput out) throws IOException {
        Text.writeString(out, GsonUtils.GSON.toJson(this));
    }

    /**
     * Reads a manager previously written by {@link #write(DataOutput)}.
     *
     * @param in source
     * @return the deserialized manager
     * @throws IOException if reading fails
     */
    public static SqlBlockRuleMgr read(DataInput in) throws IOException {
        String json = Text.readString(in);
        return GsonUtils.GSON.fromJson(json, SqlBlockRuleMgr.class);
    }
}