MergeIntoCommand.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans.commands.merge;
import org.apache.doris.analysis.ColumnDef.DefaultValue;
import org.apache.doris.analysis.StmtType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.analyzer.UnboundAlias;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.analyzer.UnboundStar;
import org.apache.doris.nereids.analyzer.UnboundTableSinkCreator;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.parser.LogicalPlanBuilderAssistant;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.DefaultValueSlot;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Now;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Explainable;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.commands.Command;
import org.apache.doris.nereids.trees.plans.commands.ForwardWithSync;
import org.apache.doris.nereids.trees.plans.commands.UpdateCommand;
import org.apache.doris.nereids.trees.plans.commands.info.DMLCommandType;
import org.apache.doris.nereids.trees.plans.commands.insert.InsertIntoTableCommand;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.RelationUtil;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.StmtExecutor;
import org.apache.doris.thrift.TPartialUpdateNewRowPolicy;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
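/**
 * Command for MERGE INTO. It only supports merge-on-write unique key olap tables and is executed
 * as an insert into the target table: the source is left outer joined with the target, every
 * joined row is labeled with the clause it falls into, and the inserted row is projected
 * according to that clause (update, delete or insert).
 */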
public class MergeIntoCommand extends Command implements ForwardWithSync, Explainable {
// Alias of the generated column that holds, for each joined row, the index of the selected clause.
private static final String BRANCH_LABEL = "__DORIS_MERGE_INTO_BRANCH_LABEL__";
// Name parts of the target table, as given to the constructor.
private final List<String> targetNameParts;
// Optional alias of the target table.
private final Optional<String> targetAlias;
// Qualifier used to reference target columns in the generated plan:
// the alias when one is present, otherwise the target name parts.
private final List<String> targetNameInPlan;
// Optional CTE plan; when present it becomes the parent of the generated query plan.
private final Optional<LogicalPlan> cte;
// Source plan; it is the left child of the generated outer join.
private final LogicalPlan source;
// ON condition of the merge; it is the join condition of the generated outer join.
private final Expression onClause;
// Clauses for matched rows, in declaration order; each one is either a delete or an update.
private final List<MergeMatchedClause> matchedClauses;
// Clauses for not matched rows, in declaration order; each one describes the row to insert.
private final List<MergeNotMatchedClause> notMatchedClauses;
/**
 * Creates a merge into command. Every argument must be non-null; list arguments are copied into
 * immutable lists.
 *
 * @param targetNameParts name parts of the target table
 * @param targetAlias optional alias of the target table
 * @param cte optional CTE plan wrapped around the generated query
 * @param source plan producing the source rows
 * @param onClause condition used to join source and target
 * @param matchedClauses clauses applied to matched rows
 * @param notMatchedClauses clauses applied to not matched rows
 */
public MergeIntoCommand(List<String> targetNameParts, Optional<String> targetAlias,
        Optional<LogicalPlan> cte, LogicalPlan source,
        Expression onClause, List<MergeMatchedClause> matchedClauses,
        List<MergeNotMatchedClause> notMatchedClauses) {
    super(PlanType.MERGE_INTO_COMMAND);

    Objects.requireNonNull(targetNameParts, "targetNameParts should not be null");
    this.targetNameParts = Utils.fastToImmutableList(targetNameParts);

    Objects.requireNonNull(targetAlias, "targetAlias should not be null");
    this.targetAlias = targetAlias;
    // target columns are referenced through the alias when there is one
    this.targetNameInPlan = targetAlias.isPresent()
            ? ImmutableList.of(targetAlias.get())
            : ImmutableList.copyOf(targetNameParts);

    Objects.requireNonNull(cte, "cte should not be null");
    this.cte = cte;
    Objects.requireNonNull(source, "source should not be null");
    this.source = source;
    Objects.requireNonNull(onClause, "onClause should not be null");
    this.onClause = onClause;

    Objects.requireNonNull(matchedClauses, "matchedClauses should not be null");
    this.matchedClauses = Utils.fastToImmutableList(matchedClauses);
    Objects.requireNonNull(notMatchedClauses, "notMatchedClauses should not be null");
    this.notMatchedClauses = Utils.fastToImmutableList(notMatchedClauses);
}
/**
 * Executes the merge by building the complete query plan and running it as an insert command.
 */
@Override
public void run(ConnectContext ctx, StmtExecutor executor) throws Exception {
    LogicalPlan mergePlan = completeQueryPlan(ctx);
    // NOTE(review): the meaning of the boolean argument is defined by InsertIntoTableCommand; verify there
    InsertIntoTableCommand insertCommand = new InsertIntoTableCommand(mergePlan, Optional.empty(),
            Optional.empty(), Optional.empty(), true, Optional.empty());
    insertCommand.run(ctx, executor);
}
/**
 * Dispatches this command to {@code visitMergeIntoCommand} of the given visitor and returns its result.
 */
@Override
public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
return visitor.visitMergeIntoCommand(this, context);
}
/**
 * Returns the plan to explain, which is the same complete query plan that {@code run} executes.
 */
@Override
public Plan getExplainPlan(ConnectContext ctx) {
return completeQueryPlan(ctx);
}
/**
 * Resolves the target table and checks that it can be used by merge into.
 *
 * @param ctx connect context used to qualify and look up the table
 * @return the target table
 * @throws AnalysisException if the table is not an olap table with unique key merge-on-write enabled
 */
private OlapTable getTargetTable(ConnectContext ctx) {
    List<String> qualifiedName = RelationUtil.getQualifierName(ctx, targetNameParts);
    TableIf resolved = RelationUtil.getTable(qualifiedName, ctx.getEnv(), Optional.empty());
    if (resolved instanceof OlapTable) {
        OlapTable olapTable = (OlapTable) resolved;
        if (olapTable.getEnableUniqueKeyMergeOnWrite()) {
            return olapTable;
        }
    }
    throw new AnalysisException("merge into command only support MOW unique key olapTable");
}
/**
 * Returns the statement type of this command, {@code StmtType.MERGE_INTO}.
 */
@Override
public StmtType stmtType() {
return StmtType.MERGE_INTO;
}
/**
 * Generates the base plan: source LEFT OUTER JOIN target ON onClause (equivalently, target right
 * outer join source). The target relation is wrapped by the policy check and, when an alias was
 * given, by a sub query alias.
 */
private LogicalPlan generateBasePlan() {
    UnboundRelation targetRelation = new UnboundRelation(
            StatementScopeIdGenerator.newRelationId(), targetNameParts);
    LogicalPlan target = LogicalPlanBuilderAssistant.withCheckPolicy(targetRelation);
    if (targetAlias.isPresent()) {
        target = new LogicalSubQueryAlias<>(targetAlias.get(), target);
    }
    // the source is the preserved (left) side; the ON clause is passed as a non-hash join condition
    return new LogicalJoin<>(JoinType.LEFT_OUTER_JOIN,
            ImmutableList.of(), ImmutableList.of(onClause),
            source, target, JoinReorderContext.EMPTY);
}
/**
 * Generates the branch label column, which indicates which clause a joined row falls into.
 *
 * <p>Matched clauses are numbered {@code 0 .. matchedClauses.size() - 1} and not matched clauses
 * continue from {@code matchedClauses.size()}. Within each group the first clause whose case
 * predicate holds wins; a clause without case predicate always applies and therefore is only
 * allowed in the last position. The label is NULL when no clause applies.
 *
 * <p>A row is treated as matched when the target side of the outer join is present, which is
 * detected by the target's delete sign slot being non-null. Re-evaluating the ON clause on the
 * joined row is not reliable: a null-tolerant condition (for example a null-safe equal, or an OR
 * over source-only columns) can be true on a null-extended row and would route a not matched
 * source row to a matched clause.
 *
 * @return the label expression aliased as the branch label column
 * @throws AnalysisException if a clause other than the last one of its group has no case predicate
 */
private NamedExpression generateBranchLabel() {
    // build the nested IFs from the last clause to the first, so the first clause ends up outermost
    Expression matchedLabel = new NullLiteral(IntegerType.INSTANCE);
    for (int i = matchedClauses.size() - 1; i >= 0; i--) {
        MergeMatchedClause clause = matchedClauses.get(i);
        if (i != matchedClauses.size() - 1 && !clause.getCasePredicate().isPresent()) {
            throw new AnalysisException("Only the last matched clause could without case predicate.");
        }
        Expression currentResult = new IntegerLiteral(i);
        if (clause.getCasePredicate().isPresent()) {
            matchedLabel = new If(clause.getCasePredicate().get(), currentResult, matchedLabel);
        } else {
            matchedLabel = currentResult;
        }
    }
    Expression notMatchedLabel = new NullLiteral(IntegerType.INSTANCE);
    for (int i = notMatchedClauses.size() - 1; i >= 0; i--) {
        MergeNotMatchedClause clause = notMatchedClauses.get(i);
        if (i != notMatchedClauses.size() - 1 && !clause.getCasePredicate().isPresent()) {
            throw new AnalysisException("Only the last not matched clause could without case predicate.");
        }
        // not matched labels continue after the matched ones
        Expression currentResult = new IntegerLiteral(i + matchedClauses.size());
        if (clause.getCasePredicate().isPresent()) {
            notMatchedLabel = new If(clause.getCasePredicate().get(), currentResult, notMatchedLabel);
        } else {
            notMatchedLabel = currentResult;
        }
    }
    // assumes the hidden delete sign column is non-null for every existing target row, so it is
    // NULL only on rows null-extended by the outer join -- TODO confirm against the column definition
    List<String> targetDeleteSignNameParts = Lists.newArrayList(targetNameInPlan);
    targetDeleteSignNameParts.add(Column.DELETE_SIGN);
    Expression targetRowExists = new Not(new IsNull(new UnboundSlot(targetDeleteSignNameParts)));
    return new UnboundAlias(new If(targetRowExists, matchedLabel, notMatchedLabel), BRANCH_LABEL);
}
/**
 * Generates the projection of a matched delete clause: the delete sign is set to 1 and every
 * other written column keeps the value of the target row.
 *
 * @param columns columns of the target table, hidden columns included
 * @return one expression per written column, in the order of {@code columns}
 */
private List<Expression> generateDeleteProjection(List<Column> columns) {
    ImmutableList.Builder<Expression> projection = ImmutableList.builder();
    for (Column column : columns) {
        if (column.isDeleteSignColumn()) {
            // mark the row as deleted
            projection.add(new TinyIntLiteral(((byte) 1)));
            continue;
        }
        boolean hiddenAndNotSequence = !column.isVisible() && !column.isSequenceColumn();
        if (hiddenAndNotSequence || column.isGeneratedColumn()) {
            // these columns are not written
            continue;
        }
        List<String> targetColumnNameParts = ImmutableList.<String>builder()
                .addAll(targetNameInPlan)
                .add(column.getName())
                .build();
        projection.add(new UnboundSlot(targetColumnNameParts));
    }
    return projection.build();
}
/**
 * Generates the projection of a matched update clause. Assigned columns take the assigned
 * expression, columns with an on-update default take that default, and the remaining written
 * columns keep the value of the target row. Every value is cast to the column type.
 *
 * @param clause the matched update clause
 * @param columns columns of the target table, hidden columns included
 * @param targetTable the target table, used in error messages
 * @param ctx connect context used to check the assignment columns
 * @return one expression per written column, in the order of {@code columns}
 * @throws AnalysisException on a duplicate or unknown assignment column, or on an assignment to
 *         a key column or a generated column
 */
private List<Expression> generateUpdateProjection(MergeMatchedClause clause,
        List<Column> columns, OlapTable targetTable, ConnectContext ctx) {
    // assignments keyed by column name, case-insensitively
    Map<String, Expression> assignedValues = Maps.newTreeMap(String.CASE_INSENSITIVE_ORDER);
    for (EqualTo assignment : clause.getAssignments()) {
        List<String> assignedNameParts = ((UnboundSlot) assignment.left()).getNameParts();
        UpdateCommand.checkAssignmentColumn(ctx, assignedNameParts, targetNameParts, targetAlias.orElse(null));
        String assignedColumn = assignedNameParts.get(assignedNameParts.size() - 1);
        Expression previous = assignedValues.put(assignedColumn, assignment.right());
        if (previous != null) {
            throw new AnalysisException("Duplicate column name in update: " + assignedColumn);
        }
    }

    ImmutableList.Builder<Expression> projection = ImmutableList.builder();
    for (Column column : columns) {
        String columnName = column.getName();
        DataType columnType = DataType.fromCatalogType(column.getType());
        if (assignedValues.containsKey(columnName)) {
            if (column.isKey()) {
                throw new AnalysisException("Only value columns of unique table could be updated");
            }
            if (column.isGeneratedColumn()) {
                throw new AnalysisException("The value specified for generated column '"
                        + columnName + "' in table '" + targetTable.getName() + "' is not allowed.");
            }
            // consume the assignment so that leftovers can be reported as unknown columns
            projection.add(new Cast(assignedValues.remove(columnName), columnType));
            continue;
        }
        boolean written = column.isVisible() || column.isDeleteSignColumn() || column.isSequenceColumn();
        if (column.isGeneratedColumn() || !written) {
            // these columns are not written
            continue;
        }
        if (column.hasOnUpdateDefaultValue()) {
            Expression onUpdateValue = new NereidsParser().parseExpression(
                    column.getOnUpdateDefaultValueExpr().toSqlWithoutTbl());
            projection.add(new Cast(onUpdateValue, columnType));
            continue;
        }
        // not assigned: keep the value of the target row
        List<String> targetColumnNameParts = Lists.newArrayList(targetNameInPlan);
        targetColumnNameParts.add(columnName);
        projection.add(new Cast(new UnboundSlot(targetColumnNameParts), columnType));
    }

    if (!assignedValues.isEmpty()) {
        throw new AnalysisException("unknown column in assignment list: "
                + String.join(", ", assignedValues.keySet()));
    }
    return projection.build();
}
/**
 * Generates the projection of a not matched insert clause without column list. The i-th value of
 * the row is paired with the i-th column of {@code columns}; values for generated columns must
 * be DEFAULT and are dropped. A delete sign of 0 is appended, followed by the sequence value
 * when the table has a sequence column.
 *
 * @param clause the not matched clause
 * @param columns columns of the target table, hidden columns included
 * @param targetTable the target table, used in error messages
 * @param hasSequenceCol whether the target table has a sequence column
 * @param seqColumnIndex index in {@code columns} of the column the sequence column maps to, or negative
 * @param seqMappingColInTable the column the sequence column maps to, if any
 * @param seqColType type of the sequence column; read only when {@code hasSequenceCol} is true
 * @return the expressions to insert
 * @throws AnalysisException if the sequence value cannot be determined or a generated column gets
 *         an explicit value
 */
private List<Expression> generateInsertWithoutColListProjection(MergeNotMatchedClause clause,
        List<Column> columns, OlapTable targetTable, boolean hasSequenceCol, int seqColumnIndex,
        Optional<Column> seqMappingColInTable, Optional<Type> seqColType) {
    if (hasSequenceCol && seqColumnIndex < 0) {
        boolean mappedColumnDefaultsToNow = seqMappingColInTable
                .map(Column::getDefaultValue)
                .filter(defaultValue -> defaultValue.equalsIgnoreCase(DefaultValue.CURRENT_TIMESTAMP))
                .isPresent();
        if (!mappedColumnDefaultsToNow) {
            throw new AnalysisException("Table " + targetTable.getName()
                    + " has sequence column, need to specify the sequence column");
        }
    }

    ImmutableList.Builder<Expression> projection = ImmutableList.builder();
    // stays now() unless the row provides a value for the mapped column
    Expression sequenceValue = new Now();
    int valueCount = clause.getRow().size();
    for (int i = 0; i < valueCount; i++) {
        Column column = columns.get(i);
        DataType columnType = DataType.fromCatalogType(column.getType());
        NamedExpression rowItem = clause.getRow().get(i);
        // an alias wrapping the value is dropped, only the value itself is used
        Expression value = (rowItem instanceof Alias || rowItem instanceof UnboundAlias)
                ? rowItem.child(0)
                : rowItem;
        if (column.isGeneratedColumn()) {
            if (value instanceof DefaultValueSlot) {
                continue;
            }
            throw new AnalysisException("The value specified for generated column '"
                    + column.getName()
                    + "' in table '" + targetTable.getName() + "' is not allowed.");
        }
        Expression castedValue = new Cast(value, columnType);
        if (i == seqColumnIndex) {
            sequenceValue = castedValue;
        }
        projection.add(castedValue);
    }
    // delete sign
    projection.add(new TinyIntLiteral(((byte) 0)));
    // sequence column
    if (hasSequenceCol) {
        projection.add(new Cast(sequenceValue, seqColType.map(DataType::fromCatalogType).get()));
    }
    return projection.build();
}
/**
 * Generates the projection of a not matched insert clause with an explicit column list. Listed
 * columns take the given value, the other visible columns take their default value (or NULL when
 * they have none and allow it). The delete sign is appended (0 unless given explicitly), followed
 * by the sequence value when the table has a sequence column.
 *
 * @param clause the not matched clause; its column names and row are paired by position
 * @param columns columns of the target table, hidden columns included
 * @param targetTable the target table, used in error messages
 * @param hasSequenceCol whether the target table has a sequence column
 * @param seqColumnName name of the column the sequence column maps to; when null or empty the
 *        hidden sequence column name is used instead
 * @param seqMappingColInTable the column the sequence column maps to, if any
 * @param seqColType type of the sequence column; read only when {@code hasSequenceCol} is true
 * @return the expressions to insert
 * @throws AnalysisException on duplicate or unknown column names, on an explicit value for a
 *         generated column, on a missing sequence value, on a column without value and default,
 *         or when a default value cannot be built
 */
private List<Expression> generateInsertWithColListProjection(MergeNotMatchedClause clause,
        List<Column> columns, OlapTable targetTable, boolean hasSequenceCol,
        String seqColumnName, Optional<Column> seqMappingColInTable, Optional<Type> seqColType) {
    ImmutableList.Builder<Expression> builder = ImmutableList.builder();
    // values keyed by column name, case-insensitively; an alias wrapping a value is dropped
    Map<String, Expression> colNameToExpression = Maps.newTreeMap(String.CASE_INSENSITIVE_ORDER);
    for (int i = 0; i < clause.getColNames().size(); i++) {
        String targetColumnName = clause.getColNames().get(i);
        NamedExpression rowItem = clause.getRow().get(i);
        if (rowItem instanceof Alias || rowItem instanceof UnboundAlias) {
            colNameToExpression.put(targetColumnName, rowItem.child(0));
        } else {
            colNameToExpression.put(targetColumnName, rowItem);
        }
    }
    // a duplicate name overwrote an entry, so the map is smaller than the column list
    if (colNameToExpression.size() != clause.getColNames().size()) {
        throw new AnalysisException("insert has duplicate column names");
    }
    if (hasSequenceCol) {
        if (seqColumnName == null || seqColumnName.isEmpty()) {
            seqColumnName = Column.SEQUENCE_COL;
        }
        // the sequence value may only be omitted when the mapped column defaults to CURRENT_TIMESTAMP
        if (!colNameToExpression.containsKey(seqColumnName)
                && (!seqMappingColInTable.isPresent() || seqMappingColInTable.get().getDefaultValue() == null
                || !seqMappingColInTable.get().getDefaultValue()
                .equalsIgnoreCase(DefaultValue.CURRENT_TIMESTAMP))) {
            throw new AnalysisException("Table " + targetTable.getName()
                    + " has sequence column, need to specify the sequence column");
        }
    }
    for (Column column : columns) {
        DataType type = DataType.fromCatalogType(column.getType());
        if (column.isGeneratedColumn()) {
            if (colNameToExpression.containsKey(column.getName())) {
                if (!(colNameToExpression.get(column.getName()) instanceof DefaultValueSlot)) {
                    throw new AnalysisException("The value specified for generated column '"
                            + column.getName() + "' in table '" + targetTable.getName() + "' is not allowed.");
                }
                colNameToExpression.remove(column.getName());
            }
            continue;
        } else if (!column.isVisible()) {
            // hidden columns are handled after the loop (delete sign, sequence) or not written
            continue;
        } else if (colNameToExpression.containsKey(column.getName())) {
            builder.add(new Cast(colNameToExpression.get(column.getName()), type));
            // the value of the sequence-mapped column is kept, it is reused for the sequence column
            if (!column.getName().equalsIgnoreCase(seqColumnName)) {
                colNameToExpression.remove(column.getName());
            }
        } else {
            if (!column.hasDefaultValue()) {
                if (!column.isAllowNull() && !column.isAutoInc()) {
                    throw new AnalysisException("Column has no default value,"
                            + " column=" + column.getName());
                }
                builder.add(new NullLiteral(type));
            } else {
                builder.add(buildColumnDefaultExpression(column, type));
            }
        }
    }
    // delete sign
    builder.add(colNameToExpression.getOrDefault(Column.DELETE_SIGN, new TinyIntLiteral(((byte) 0))));
    colNameToExpression.remove(Column.DELETE_SIGN);
    // sequence column: explicit hidden column first, then the mapped column, otherwise now()
    if (hasSequenceCol) {
        Expression forSeqCol;
        if (colNameToExpression.containsKey(Column.SEQUENCE_COL)) {
            forSeqCol = colNameToExpression.get(Column.SEQUENCE_COL);
            colNameToExpression.remove(Column.SEQUENCE_COL);
            colNameToExpression.remove(seqColumnName);
        } else if (colNameToExpression.containsKey(seqColumnName)) {
            forSeqCol = colNameToExpression.get(seqColumnName);
            colNameToExpression.remove(seqColumnName);
        } else {
            forSeqCol = new Now();
        }
        builder.add(new Cast(forSeqCol, seqColType.map(DataType::fromCatalogType).get()));
    }
    // whatever is left was not consumed by any column of the table
    if (!colNameToExpression.isEmpty()) {
        throw new AnalysisException("unknown column in target table: "
                + String.join(", ", colNameToExpression.keySet()));
    }
    return builder.build();
}

/**
 * Builds the default value expression of a column that has a default value.
 *
 * @param column the column, which must have a default value
 * @param type nereids type of the column
 * @return the default value converted to {@code type}
 * @throws AnalysisException if the default value cannot be parsed or cast; the original
 *         exception is kept as the cause
 */
private Expression buildColumnDefaultExpression(Column column, DataType type) {
    try {
        // it comes from the original planner, if default value expression is
        // null, we use the literal string of the default value, or it may be
        // default value function, like CURRENT_TIMESTAMP.
        if (column.getDefaultValueExpr() == null) {
            return Literal.of(column.getDefaultValue()).checkedCastWithFallback(type);
        }
        Expression unboundDefaultValue = new NereidsParser().parseExpression(
                column.getDefaultValueExpr().toSqlWithoutTbl());
        if (unboundDefaultValue instanceof UnboundAlias) {
            unboundDefaultValue = ((UnboundAlias) unboundDefaultValue).child();
        }
        return new Cast(unboundDefaultValue, type);
    } catch (Exception e) {
        // keep the caught exception itself as the cause so its stack trace is not lost
        throw new AnalysisException(e.getMessage(), e);
    }
}
/**
 * Combines the per-clause projections into the final output columns. For every output column a
 * nested IF on the branch label selects the expression of the clause the row falls into; the
 * j-th projection is selected when the branch label equals j, and NULL when no label matches.
 *
 * @param colNames names of the output columns, one per projected expression
 * @param finalProjections one projection per clause, indexed by branch label
 * @return the aliased output expressions, in the order of {@code colNames}
 * @throws AnalysisException if there is no projection at all, if the projections differ in
 *         width, or if their width differs from the number of column names
 */
private List<NamedExpression> generateFinalProjections(List<String> colNames,
        List<List<Expression>> finalProjections) {
    // fail with an analysis error rather than an IndexOutOfBoundsException on get(0)
    if (finalProjections.isEmpty()) {
        throw new AnalysisException("merge into should have at least one matched or not matched clause");
    }
    int width = finalProjections.get(0).size();
    for (List<Expression> projection : finalProjections) {
        if (projection.size() != width) {
            throw new AnalysisException("Column count doesn't match each other");
        }
    }
    // every projected expression needs exactly one output column name
    if (width != colNames.size()) {
        throw new AnalysisException("Column count doesn't match target column count, expected "
                + colNames.size() + " but was " + width);
    }
    ImmutableList.Builder<NamedExpression> outputProjectionsBuilder = ImmutableList.builder();
    for (int i = 0; i < width; i++) {
        Expression project = new NullLiteral();
        for (int j = 0; j < finalProjections.size(); j++) {
            project = new If(new EqualTo(new UnboundSlot(BRANCH_LABEL), new IntegerLiteral(j)),
                    finalProjections.get(j).get(i), project);
        }
        outputProjectionsBuilder.add(new UnboundAlias(project, colNames.get(i)));
    }
    return outputProjectionsBuilder.build();
}
/**
 * Builds the complete merge into plan: the base join, a project adding the branch label and the
 * target's hidden columns, a filter dropping rows that no clause applies to, a project selecting
 * the per-clause values, the optional CTE and finally the table sink on the target table.
 *
 * @param ctx connect context used to resolve the target table
 * @return the plan ending with the unbound table sink
 * @throws AnalysisException if the target table or one of the clauses is invalid
 */
private LogicalPlan completeQueryPlan(ConnectContext ctx) {
    OlapTable targetTable = getTargetTable(ctx);
    List<Column> columns = targetTable.getBaseSchema(true);

    // compute sequence column info: name, position and definition of the mapped column
    boolean hasSequenceCol = targetTable.hasSequenceCol();
    String seqColName = null;
    int seqColumnIndex = -1;
    Optional<Column> seqMappingColInTable = Optional.empty();
    if (hasSequenceCol) {
        seqColName = targetTable.getSequenceMapCol();
        if (seqColName != null) {
            for (int i = 0; i < columns.size(); i++) {
                Column candidate = columns.get(i);
                if (candidate.getName().equalsIgnoreCase(seqColName)) {
                    seqColumnIndex = i;
                    seqMappingColInTable = Optional.of(candidate);
                    break;
                }
            }
        }
    }
    if (seqColumnIndex != -1 && !seqMappingColInTable.isPresent()) {
        throw new AnalysisException("sequence column is not contained in"
                + " target table " + targetTable.getName());
    }

    // generate base plan
    LogicalPlan plan = generateBasePlan();

    // project everything plus the branch label and the target's delete sign and sequence column
    ImmutableList.Builder<NamedExpression> labeledOutputs = ImmutableList.builder();
    labeledOutputs.add(new UnboundStar(ImmutableList.of()));
    labeledOutputs.add(generateBranchLabel());
    List<String> deleteSignNameParts = ImmutableList.<String>builder()
            .addAll(targetNameInPlan)
            .add(Column.DELETE_SIGN)
            .build();
    labeledOutputs.add(new UnboundSlot(deleteSignNameParts));
    if (hasSequenceCol) {
        List<String> sequenceNameParts = ImmutableList.<String>builder()
                .addAll(targetNameInPlan)
                .add(Column.SEQUENCE_COL)
                .build();
        labeledOutputs.add(new UnboundSlot(sequenceNameParts));
    }
    plan = new LogicalProject<>(labeledOutputs.build(), plan);

    // remove all rows that are not used for update, delete and insert
    Expression hasBranch = new Not(new IsNull(new UnboundSlot(BRANCH_LABEL)));
    plan = new LogicalFilter<>(ImmutableSet.of(hasBranch), plan);

    // one projection per clause, in branch label order: matched clauses first
    List<List<Expression>> finalProjections = Lists.newArrayList();
    for (MergeMatchedClause clause : matchedClauses) {
        finalProjections.add(clause.isDelete()
                ? generateDeleteProjection(columns)
                : generateUpdateProjection(clause, columns, targetTable, ctx));
    }
    // then the not matched clauses
    long visibleColumnCount = columns.stream().filter(Column::isVisible).count();
    for (MergeNotMatchedClause clause : notMatchedClauses) {
        boolean withoutColList = clause.getColNames().isEmpty();
        // without a column list the row has to provide one value per visible column
        long expectedValueCount = withoutColList ? visibleColumnCount : clause.getColNames().size();
        if (expectedValueCount != clause.getRow().size()) {
            throw new AnalysisException("Column count doesn't match value count");
        }
        Optional<Type> seqColType = Optional.ofNullable(targetTable.getSequenceType());
        if (withoutColList) {
            finalProjections.add(generateInsertWithoutColListProjection(clause, columns, targetTable,
                    hasSequenceCol, seqColumnIndex, seqMappingColInTable, seqColType));
        } else {
            finalProjections.add(generateInsertWithColListProjection(clause, columns, targetTable,
                    hasSequenceCol, seqColName, seqMappingColInTable, seqColType));
        }
    }

    // sink columns: visible non-generated columns plus the delete sign and the sequence column
    ImmutableList.Builder<String> sinkColumnNames = ImmutableList.builder();
    for (Column column : columns) {
        boolean writableVisible = column.isVisible() && !column.isGeneratedColumn();
        if (writableVisible || column.isDeleteSignColumn() || column.isSequenceColumn()) {
            sinkColumnNames.add(column.getName());
        }
    }
    List<String> colNames = sinkColumnNames.build();
    plan = new LogicalProject<>(generateFinalProjections(colNames, finalProjections), plan);

    // TODO 6, 7, 8, 9 strict mode
    // 6. add a set of new columns used for group by: if(mark = 1, target keys + mark, insert keys + mark)
    // 7. add window node, partition by group by key, order by 1, row number, count(update) as uc, max(delete) as dc
    // 8. get row_number = 1
    // 9. assert_true(uc <= 1 and (uc = 0 || dc = 0) (optional)

    if (cte.isPresent()) {
        plan = (LogicalPlan) cte.get().withChildren(plan);
    }
    return UnboundTableSinkCreator.createUnboundTableSink(targetNameParts, colNames, ImmutableList.of(),
            false, ImmutableList.of(), false, TPartialUpdateNewRowPolicy.APPEND,
            DMLCommandType.INSERT, plan);
}
}