UpdateCommand.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans.commands;
import org.apache.doris.analysis.StmtType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.nereids.analyzer.UnboundAlias;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.analyzer.UnboundTableSinkCreator;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.plans.Explainable;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.commands.info.DMLCommandType;
import org.apache.doris.nereids.trees.plans.commands.insert.InsertIntoTableCommand;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.RelationUtil;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.StmtExecutor;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import javax.annotation.Nullable;
/**
 * update command
 * the update statement is rewritten (conceptually) into an insert over the same unique-key table,
 * and executed through InsertIntoTableCommand. The two forms are handled as:
 * case 1:
 * update t1 set v1 = v1 + 1 where k1 = 1 and k2 = 2;
 * =>
 * insert into t1 select k1, k2, v1 + 1, v2, ... from t1 where k1 = 1 and k2 = 2
 * case 2:
 * update t1 set t1.c1 = t2.c1, t1.c3 = t2.c3 * 100
 * from t2 inner join t3 on t2.id = t3.id
 * where t1.id = t2.id;
 * =>
 * insert into t1 select t1.id, t2.c1, t1.c2, t2.c3 * 100, ... from t1 join t2 inner join t3 on t2.id = t3.id
 * where t1.id = t2.id
 * Every column that is not assigned is read back from the target table (or filled with its ON UPDATE
 * default). When the table qualifies for partial update, only the key columns and the assigned columns
 * are written instead, e.g. insert into t1 (k1, k2, v1) select k1, k2, v1 + 1 from t1 where ...
 */
public class UpdateCommand extends Command implements ForwardWithSync, Explainable {
    private final List<EqualTo> assignments;
    private final List<String> nameParts;
    private final @Nullable String tableAlias;
    private final LogicalPlan logicalQuery;
    // resolved and assigned by checkTable(), which completeQueryPlan() always runs first
    private OlapTable targetTable;
    private final Optional<LogicalPlan> cte;

    /**
     * constructor
     *
     * @param nameParts qualified name of the target table; copied
     * @param tableAlias alias of the target table in the statement, or null when none is given
     * @param assignments the SET clause as {@code column = expression} pairs; copied
     * @param logicalQuery the plan producing the rows to update; must not be null
     * @param cte an optional CTE plan that will be placed above the generated projection
     */
    public UpdateCommand(List<String> nameParts, @Nullable String tableAlias, List<EqualTo> assignments,
            LogicalPlan logicalQuery, Optional<LogicalPlan> cte) {
        super(PlanType.UPDATE_COMMAND);
        this.nameParts = Utils.copyRequiredList(nameParts);
        this.assignments = Utils.copyRequiredList(assignments);
        this.tableAlias = tableAlias;
        this.logicalQuery = Objects.requireNonNull(logicalQuery, "logicalQuery is required in update command");
        this.cte = cte;
    }

    @Override
    public void run(ConnectContext ctx, StmtExecutor executor) throws Exception {
        // NOTE: update command is executed as insert command, so txn insert can support it
        new InsertIntoTableCommand(completeQueryPlan(ctx, logicalQuery), Optional.empty(), Optional.empty(),
                Optional.empty()).run(ctx, executor);
    }

    /**
     * add LogicalOlapTableSink node, public for test.
     *
     * <p>Builds a projection over {@code logicalQuery} that yields the new row values, optionally places
     * it under the CTE plan, and wraps the result in an unbound table sink targeting this table.
     *
     * @param ctx the current connection context
     * @param logicalQuery the plan producing the rows to update
     * @return the unbound sink plan to be executed as an insert
     * @throws AnalysisException if the target table is not an OLAP unique-key table, the session is in
     *         debug mode, a key column is assigned, or an assignment names an invalid or unknown column
     */
    @VisibleForTesting
    public LogicalPlan completeQueryPlan(ConnectContext ctx, LogicalPlan logicalQuery) {
        checkTable(ctx);

        // Both maps hold assigned column name -> new value and compare names case-insensitively.
        // colNameToExpression is drained while matching schema columns so leftovers can be reported as
        // unknown; partialUpdateColNameToExpression stays intact for the key check and the partial path.
        Map<String, Expression> colNameToExpression = Maps.newTreeMap(String.CASE_INSENSITIVE_ORDER);
        Map<String, Expression> partialUpdateColNameToExpression = Maps.newTreeMap(String.CASE_INSENSITIVE_ORDER);
        for (EqualTo equalTo : assignments) {
            List<String> columnNameParts = ((UnboundSlot) equalTo.left()).getNameParts();
            checkAssignmentColumn(ctx, columnNameParts);
            String columnName = columnNameParts.get(columnNameParts.size() - 1);
            colNameToExpression.put(columnName, equalTo.right());
            partialUpdateColNameToExpression.put(columnName, equalTo.right());
        }

        // check if any key in update clause
        if (targetTable.getFullSchema().stream().filter(Column::isKey)
                .anyMatch(column -> partialUpdateColNameToExpression.containsKey(column.getName()))) {
            throw new AnalysisException("Only value columns of unique table could be updated");
        }

        String tableName = tableAlias != null ? tableAlias : targetTable.getName();
        // When the sequence map column is assigned, the real (hidden) sequence column must be set to the
        // same value. Resolve it before walking the schema so the result does not depend on column order.
        Expression sequenceMapColumnValue = getSequenceMapColumnAssignment(colNameToExpression);

        List<NamedExpression> selectItems = Lists.newArrayList();
        for (Column column : targetTable.getFullSchema()) {
            // Skip hidden columns, except the sequence column, which is always written: when no sequence
            // map column value is available (e.g. the sequence was set in the stream load phase), its
            // current value is queried back from the table.
            if (!column.isVisible() && !column.isSequenceColumn()) {
                continue;
            }
            String columnName = column.getName();
            // right-hand sides are children of an EqualTo and never null, so null means "not assigned"
            Expression assigned = colNameToExpression.remove(columnName);
            if (assigned != null) {
                selectItems.add(toSelectItem(assigned));
            } else if (column.isSequenceColumn() && sequenceMapColumnValue != null) {
                selectItems.add(new UnboundAlias(sequenceMapColumnValue, columnName));
            } else if (column.hasOnUpdateDefaultValue()) {
                Expression defaultValueExpression = new NereidsParser().parseExpression(
                        column.getOnUpdateDefaultValueExpr().toSqlWithoutTbl());
                selectItems.add(new UnboundAlias(defaultValueExpression, columnName));
            } else {
                selectItems.add(new UnboundSlot(tableName, columnName));
            }
        }
        if (!colNameToExpression.isEmpty()) {
            throw new AnalysisException("unknown column in assignment list: "
                    + String.join(", ", colNameToExpression.keySet()));
        }

        boolean isPartialUpdate = targetTable.getEnableUniqueKeyMergeOnWrite()
                && selectItems.size() < targetTable.getColumns().size()
                && targetTable.getSequenceCol() == null
                && partialUpdateColNameToExpression.size() <= targetTable.getFullSchema().size() * 3 / 10
                && !targetTable.isUniqKeyMergeOnWriteWithClusterKeys();

        List<String> partialUpdateColNames = new ArrayList<>();
        List<NamedExpression> partialUpdateSelectItems = new ArrayList<>();
        if (isPartialUpdate) {
            // partial update writes only key columns plus the assigned columns
            for (Column column : targetTable.getFullSchema()) {
                // case-insensitive map: a direct lookup matches regardless of identifier case
                Expression assigned = partialUpdateColNameToExpression.get(column.getName());
                if (column.isKey() || assigned != null) {
                    Expression expr = assigned != null ? assigned : new UnboundSlot(tableName, column.getName());
                    partialUpdateSelectItems.add(toSelectItem(expr));
                    partialUpdateColNames.add(column.getName());
                }
            }
        }

        logicalQuery = new LogicalProject<>(isPartialUpdate ? partialUpdateSelectItems : selectItems, logicalQuery);
        if (cte.isPresent()) {
            logicalQuery = ((LogicalPlan) cte.get().withChildren(logicalQuery));
        }
        // make UnboundTableSink
        return UnboundTableSinkCreator.createUnboundTableSink(this.nameParts,
                isPartialUpdate ? partialUpdateColNames : ImmutableList.of(), ImmutableList.of(),
                false, ImmutableList.of(), isPartialUpdate, DMLCommandType.UPDATE, logicalQuery);
    }

    /**
     * Returns the value assigned to the table's sequence map column, or null when the table has no
     * sequence column, has no sequence map column, or the statement does not assign it.
     */
    private @Nullable Expression getSequenceMapColumnAssignment(Map<String, Expression> colNameToExpression) {
        if (!targetTable.hasSequenceCol() || targetTable.getSequenceMapCol() == null) {
            return null;
        }
        return colNameToExpression.get(targetTable.getSequenceMapCol());
    }

    /**
     * Turns an assigned value into a projection item: a bare slot is used as-is, any other expression
     * is wrapped in an (unnamed) UnboundAlias.
     */
    private static NamedExpression toSelectItem(Expression expr) {
        return expr instanceof UnboundSlot ? (NamedExpression) expr : new UnboundAlias(expr);
    }

    /**
     * Validates the qualifier of an assignment target column.
     *
     * <p>Unqualified names are accepted. {@code tbl.col} and {@code db.tbl.col} must refer to the target
     * table (by alias when one is given, otherwise by table name, plus database when present). A
     * database-qualified name is rejected when the table has an alias, and names with more than three
     * parts are rejected.
     *
     * @throws AnalysisException if the column qualifier does not match the target table
     */
    private void checkAssignmentColumn(ConnectContext ctx, List<String> columnNameParts) {
        if (columnNameParts.size() <= 1) {
            return;
        }
        String dbName = null;
        String tableName = null;
        if (columnNameParts.size() == 3) {
            dbName = columnNameParts.get(0);
            tableName = columnNameParts.get(1);
        } else if (columnNameParts.size() == 2) {
            tableName = columnNameParts.get(0);
        } else {
            throw new AnalysisException("column in assignment list is invalid, " + String.join(".", columnNameParts));
        }
        if (dbName != null && this.tableAlias != null) {
            throw new AnalysisException("column in assignment list is invalid, " + String.join(".", columnNameParts));
        }
        List<String> tableQualifier = RelationUtil.getQualifierName(ctx, nameParts);
        if (!ExpressionAnalyzer.sameTableName(tableAlias == null ? tableQualifier.get(2) : tableAlias, tableName)
                || (dbName != null
                && !ExpressionAnalyzer.compareDbNameIgnoreClusterName(tableQualifier.get(1), dbName))) {
            throw new AnalysisException("column in assignment list is invalid, " + String.join(".", columnNameParts));
        }
    }

    /**
     * Resolves the target table into {@link #targetTable} and checks that it can be updated.
     *
     * @throws AnalysisException if the session is in debug mode, or the target is not an OLAP table
     *         with unique keys
     */
    private void checkTable(ConnectContext ctx) {
        if (ctx.getSessionVariable().isInDebugMode()) {
            throw new AnalysisException("Update is forbidden since current session is in debug mode."
                    + " Please check the following session variables: "
                    + ctx.getSessionVariable().printDebugModeVariables());
        }
        List<String> tableQualifier = RelationUtil.getQualifierName(ctx, nameParts);
        TableIf table = RelationUtil.getTable(tableQualifier, ctx.getEnv());
        if (!(table instanceof OlapTable)) {
            throw new AnalysisException("target table in update command should be an olapTable");
        }
        targetTable = ((OlapTable) table);
        if (targetTable.getType() != Table.TableType.OLAP
                || targetTable.getKeysType() != KeysType.UNIQUE_KEYS) {
            throw new AnalysisException("Only unique table could be updated.");
        }
    }

    /** Returns the rewritten insert plan so EXPLAIN shows what the update will execute. */
    @Override
    public Plan getExplainPlan(ConnectContext ctx) {
        return completeQueryPlan(ctx, logicalQuery);
    }

    /** Returns the original (not yet rewritten) query plan of this update. */
    public LogicalPlan getLogicalQuery() {
        return logicalQuery;
    }

    @Override
    public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
        return visitor.visitUpdateCommand(this, context);
    }

    @Override
    public StmtType stmtType() {
        return StmtType.UPDATE;
    }
}