// LogicalCheckPolicy.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.mysql.privilege.AccessControllerManager;
import org.apache.doris.mysql.privilege.DataMaskPolicy;
import org.apache.doris.mysql.privilege.RowFilterPolicy;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.SqlCacheContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.UnboundAlias;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.PropagateFuncDeps;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/**
 * Logical Check Policy.
 *
 * <p>A unary logical plan node that exposes exactly the output of its child and holds no
 * expressions of its own. {@link #findPolicy(LogicalRelation, CascadesContext)} looks up the
 * row-filter and data-mask policies that apply to a relation for the current user.
 */
public class LogicalCheckPolicy<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYPE>
        implements PropagateFuncDeps {

    /**
     * Creates a check-policy node on top of {@code child}.
     *
     * @param child the single child plan
     */
    public LogicalCheckPolicy(CHILD_TYPE child) {
        super(PlanType.LOGICAL_CHECK_POLICY, child);
    }

    /**
     * Creates a check-policy node with an explicit group expression and logical properties.
     *
     * @param groupExpression optional group expression passed to the super constructor
     * @param logicalProperties optional logical properties passed to the super constructor
     * @param child the single child plan
     */
    public LogicalCheckPolicy(Optional<GroupExpression> groupExpression,
            Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) {
        super(PlanType.LOGICAL_CHECK_POLICY, groupExpression, logicalProperties, child);
    }

    @Override
    public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
        return visitor.visitLogicalCheckPolicy(this, context);
    }

    /** This node holds no expressions, so the result is always an empty list. */
    @Override
    public List<? extends Expression> getExpressions() {
        return ImmutableList.of();
    }

    /** The output of this node is the output of its child, unchanged. */
    @Override
    public List<Slot> computeOutput() {
        return child().getOutput();
    }

    @Override
    public String toString() {
        return Utils.toSqlString("LogicalCheckPolicy");
    }

    /** Two nodes are equal when they have the same runtime class and equal children. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        // Wildcard instead of a raw type: the child's static type is irrelevant for equality.
        LogicalCheckPolicy<?> that = (LogicalCheckPolicy<?>) o;
        return child().equals(that.child());
    }

    /** Consistent with {@link #equals(Object)}: derived from the child only. */
    @Override
    public int hashCode() {
        return child().hashCode();
    }

    @Override
    public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
        return new LogicalCheckPolicy<>(groupExpression, Optional.of(getLogicalProperties()), child());
    }

    @Override
    public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpression,
            Optional<LogicalProperties> logicalProperties, List<Plan> children) {
        Preconditions.checkArgument(children.size() == 1);
        return new LogicalCheckPolicy<>(groupExpression, logicalProperties, children.get(0));
    }

    @Override
    public Plan withChildren(List<Plan> children) {
        Preconditions.checkArgument(children.size() == 1);
        return new LogicalCheckPolicy<>(children.get(0));
    }

    /**
     * find related policy for logicalRelation.
     *
     * <p>Returns {@link RelatedPolicy#NO_POLICY} when the relation is not a {@link CatalogRelation}
     * or when the current user is a root or admin user. Otherwise a data-mask policy is evaluated
     * for every output slot of the relation and the row-filter policies of the table are evaluated
     * and merged. When the statement has a {@link SqlCacheContext}, every evaluated policy
     * (including absent data masks) is recorded in it.
     *
     * @param logicalRelation include tableName and dbName
     * @param cascadesContext include information about user and policy
     * @return the merged row filter (empty when no row policy applies) and the projection list
     *         carrying the data masks (empty when no output slot has a data mask)
     */
    public RelatedPolicy findPolicy(LogicalRelation logicalRelation, CascadesContext cascadesContext) {
        if (!(logicalRelation instanceof CatalogRelation)) {
            return RelatedPolicy.NO_POLICY;
        }

        ConnectContext connectContext = cascadesContext.getConnectContext();
        AccessControllerManager accessManager = connectContext.getEnv().getAccessManager();
        UserIdentity currentUserIdentity = connectContext.getCurrentUserIdentity();
        if (currentUserIdentity.isRootUser() || currentUserIdentity.isAdminUser()) {
            return RelatedPolicy.NO_POLICY;
        }

        CatalogRelation catalogRelation = (CatalogRelation) logicalRelation;
        String ctlName = catalogRelation.getDatabase().getCatalog().getName();
        String dbName = catalogRelation.getDatabase().getFullName();
        String tableName = catalogRelation.getTable().getName();

        NereidsParser nereidsParser = new NereidsParser();
        List<Slot> relationOutput = logicalRelation.getOutput();
        ImmutableList.Builder<NamedExpression> dataMasks
                = ImmutableList.builderWithExpectedSize(relationOutput.size());

        StatementContext statementContext = cascadesContext.getStatementContext();
        Optional<SqlCacheContext> sqlCacheContext = statementContext.getSqlCacheContext();

        boolean hasDataMask = false;
        for (Slot slot : relationOutput) {
            Optional<DataMaskPolicy> dataMaskPolicy = accessManager.evalDataMaskPolicy(
                    currentUserIdentity, ctlName, dbName, tableName, slot.getName());
            if (dataMaskPolicy.isPresent()) {
                Expression unboundExpr = nereidsParser.parseExpression(dataMaskPolicy.get().getMaskTypeDef());
                // Drop an alias written in the mask definition; the masked column is re-aliased
                // below with the original slot's name and qualifier.
                Expression childOfAlias
                        = unboundExpr instanceof UnboundAlias ? unboundExpr.child(0) : unboundExpr;
                Alias alias = new Alias(
                        StatementScopeIdGenerator.newExprId(),
                        ImmutableList.of(childOfAlias),
                        slot.getName(), slot.getQualifier(), false
                );
                dataMasks.add(alias);
                hasDataMask = true;
            } else {
                // Unmasked columns pass through unchanged so the projection keeps every output slot.
                dataMasks.add(slot);
            }
            // Recorded for every slot, whether or not a mask is present.
            sqlCacheContext.ifPresent(cacheContext -> cacheContext.addDataMaskPolicy(
                    ctlName, dbName, tableName, slot.getName(), dataMaskPolicy));
        }

        List<? extends RowFilterPolicy> rowPolicies = accessManager.evalRowFilterPolicies(
                currentUserIdentity, ctlName, dbName, tableName);
        sqlCacheContext.ifPresent(cacheContext -> cacheContext.setRowFilterPolicy(
                ctlName, dbName, tableName, rowPolicies));

        Optional<Expression> rowPolicyFilter = Optional.empty();
        if (!CollectionUtils.isEmpty(rowPolicies)) {
            rowPolicyFilter = Optional.ofNullable(mergeRowPolicy(rowPolicies));
        }
        Optional<List<NamedExpression>> dataMaskProjects = Optional.empty();
        if (hasDataMask) {
            dataMaskProjects = Optional.of(dataMasks.build());
        }
        return new RelatedPolicy(rowPolicyFilter, dataMaskProjects);
    }

    /**
     * Merges row-filter policies into a single predicate.
     *
     * <p>PERMISSIVE filters are combined with OR, RESTRICTIVE filters with AND; when both kinds
     * are present the result is {@code AND(restrictive...) AND OR(permissive...)}.
     *
     * @param policies the policies to merge; the only caller passes a non-empty list
     * @return the merged predicate
     * @throws AnalysisException if a policy's filter expression cannot be obtained
     * @throws IllegalStateException if a policy has a filter type other than PERMISSIVE or RESTRICTIVE
     */
    private Expression mergeRowPolicy(List<? extends RowFilterPolicy> policies) {
        List<Expression> orList = new ArrayList<>();
        List<Expression> andList = new ArrayList<>();
        for (RowFilterPolicy policy : policies) {
            Expression wherePredicate;
            try {
                wherePredicate = policy.getFilterExpression();
            } catch (org.apache.doris.common.AnalysisException e) {
                // Translate the legacy checked exception into the Nereids one, keeping the cause.
                throw new AnalysisException(e.getMessage(), e);
            }
            switch (policy.getFilterType()) {
                case PERMISSIVE:
                    orList.add(wherePredicate);
                    break;
                case RESTRICTIVE:
                    andList.add(wherePredicate);
                    break;
                default:
                    throw new IllegalStateException("Invalid operator");
            }
        }

        // Guard clauses cover every combination, so no fall-through null return is needed.
        if (andList.isEmpty()) {
            return ExpressionUtils.or(orList);
        }
        if (orList.isEmpty()) {
            return ExpressionUtils.and(andList);
        }
        return new And(ExpressionUtils.and(andList), ExpressionUtils.or(orList));
    }

    /** RelatedPolicy: the row filter and data-mask projections found for a relation. */
    public static class RelatedPolicy {
        /** Shared instance with neither a row filter nor data-mask projections. */
        public static final RelatedPolicy NO_POLICY = new RelatedPolicy(Optional.empty(), Optional.empty());

        /** Merged row-filter predicate, if any row policy applies. */
        public final Optional<Expression> rowPolicyFilter;
        /** Projection list with masked columns aliased to their original names, if any mask applies. */
        public final Optional<List<NamedExpression>> dataMaskProjects;

        public RelatedPolicy(Optional<Expression> rowPolicyFilter, Optional<List<NamedExpression>> dataMaskProjects) {
            this.rowPolicyFilter = rowPolicyFilter;
            this.dataMaskProjects = dataMaskProjects;
        }
    }
}