AnalyzeCTE.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CTEContext;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nullable;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTE;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRecursiveCte;
import org.apache.doris.nereids.trees.plans.logical.LogicalRecursiveCteRecursiveChild;
import org.apache.doris.nereids.trees.plans.logical.LogicalRecursiveCteScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.logical.ProjectProcessor;
import org.apache.doris.nereids.types.DataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Register CTE, includes checking columnAliases, checking CTE name, analyzing each CTE and store the
* analyzed logicalPlan of CTE's query in CTEContext;
* A LogicalProject node will be added to the root of the initial logicalPlan if there exist columnAliases.
* Node LogicalCTE will be eliminated after registering.
*/
public class AnalyzeCTE extends OneAnalysisRuleFactory {
@Override
public Rule build() {
return logicalCTE().thenApply(ctx -> {
LogicalCTE<Plan> logicalCTE = ctx.root;
// step 0. check duplicate cte name
Set<String> uniqueAlias = Sets.newHashSet();
List<String> aliases = logicalCTE.getAliasQueries().stream()
.map(LogicalSubQueryAlias::getAlias)
.collect(Collectors.toList());
for (String alias : aliases) {
if (uniqueAlias.contains(alias)) {
throw new AnalysisException("CTE name [" + alias + "] cannot be used more than once.");
}
uniqueAlias.add(alias);
}
// step 1. analyzed all cte plan
Pair<CTEContext, List<LogicalCTEProducer<Plan>>> result = analyzeCte(logicalCTE, ctx.cascadesContext);
CascadesContext outerCascadesCtx = CascadesContext.newContextWithCteContext(
ctx.cascadesContext, logicalCTE.child(), result.first, Optional.empty(), ImmutableList.of());
outerCascadesCtx.withPlanProcess(ctx.cascadesContext.showPlanProcess(), () -> {
outerCascadesCtx.newAnalyzer().analyze();
});
ctx.cascadesContext.setLeadingDisableJoinReorder(outerCascadesCtx.isLeadingDisableJoinReorder());
Plan root = outerCascadesCtx.getRewritePlan();
ctx.cascadesContext.addPlanProcesses(outerCascadesCtx.getPlanProcesses());
// should construct anchor from back to front, because the cte behind depends on the front
for (int i = result.second.size() - 1; i >= 0; i--) {
root = new LogicalCTEAnchor<>(result.second.get(i).getCteId(), result.second.get(i), root);
}
return root;
}).toRule(RuleType.ANALYZE_CTE);
}
/**
* register and store CTEs in CTEContext
*/
private Pair<CTEContext, List<LogicalCTEProducer<Plan>>> analyzeCte(
LogicalCTE<Plan> logicalCTE, CascadesContext cascadesContext) {
CTEContext outerCteCtx = cascadesContext.getCteContext();
List<LogicalSubQueryAlias<Plan>> aliasQueries = logicalCTE.getAliasQueries();
List<LogicalCTEProducer<Plan>> cteProducerPlans = new ArrayList<>();
for (LogicalSubQueryAlias<Plan> aliasQuery : aliasQueries) {
// we should use a chain to ensure visible of cte
if (aliasQuery.isRecursiveCte() && logicalCTE.isRecursiveCte()) {
Pair<CTEContext, LogicalCTEProducer<Plan>> result = analyzeRecursiveCte(aliasQuery, outerCteCtx,
cascadesContext);
outerCteCtx = result.first;
cteProducerPlans.add(result.second);
} else {
LogicalPlan parsedCtePlan = (LogicalPlan) aliasQuery.child();
CascadesContext innerCascadesCtx = CascadesContext.newContextWithCteContext(
cascadesContext, parsedCtePlan, outerCteCtx, Optional.empty(), ImmutableList.of());
innerCascadesCtx.withPlanProcess(cascadesContext.showPlanProcess(), () -> {
innerCascadesCtx.newAnalyzer().analyze();
});
cascadesContext.addPlanProcesses(innerCascadesCtx.getPlanProcesses());
LogicalPlan analyzedCtePlan = (LogicalPlan) innerCascadesCtx.getRewritePlan();
checkColumnAlias(aliasQuery, analyzedCtePlan.getOutput());
CTEId cteId = StatementScopeIdGenerator.newCTEId();
LogicalSubQueryAlias<Plan> logicalSubQueryAlias = aliasQuery
.withChildren(ImmutableList.of(analyzedCtePlan));
outerCteCtx = new CTEContext(cteId, logicalSubQueryAlias, outerCteCtx);
outerCteCtx.setAnalyzedPlan(logicalSubQueryAlias);
cteProducerPlans.add(new LogicalCTEProducer<>(cteId, logicalSubQueryAlias));
}
}
return Pair.of(outerCteCtx, cteProducerPlans);
}
private Pair<CTEContext, LogicalCTEProducer<Plan>> analyzeRecursiveCte(LogicalSubQueryAlias<Plan> aliasQuery,
CTEContext outerCteCtx, CascadesContext cascadesContext) {
Preconditions.checkArgument(aliasQuery.isRecursiveCte(), "alias query must be recursive cte");
LogicalPlan parsedCtePlan = (LogicalPlan) aliasQuery.child();
if (!(parsedCtePlan instanceof LogicalUnion) || parsedCtePlan.children().size() != 2) {
throw new AnalysisException(String.format("recursive cte must be union, don't support %s",
parsedCtePlan.getClass().getSimpleName()));
}
// analyze anchor child, its output list will be recursive cte temp table's schema
LogicalPlan anchorChild = (LogicalPlan) parsedCtePlan.child(0);
CascadesContext innerAnchorCascadesCtx = CascadesContext.newContextWithCteContext(
cascadesContext, anchorChild, outerCteCtx, Optional.of(aliasQuery.getAlias()), ImmutableList.of());
innerAnchorCascadesCtx.withPlanProcess(cascadesContext.showPlanProcess(), () -> {
innerAnchorCascadesCtx.newAnalyzer().analyze();
});
cascadesContext.addPlanProcesses(innerAnchorCascadesCtx.getPlanProcesses());
LogicalPlan analyzedAnchorChild = (LogicalPlan) innerAnchorCascadesCtx.getRewritePlan();
Set<LogicalRecursiveCteScan> recursiveCteScans = analyzedAnchorChild
.collect(LogicalRecursiveCteScan.class::isInstance);
for (LogicalRecursiveCteScan cteScan : recursiveCteScans) {
if (cteScan.getTable().getName().equalsIgnoreCase(aliasQuery.getAlias())) {
throw new AnalysisException(
String.format("recursive reference to query %s must not appear within its non-recursive term",
aliasQuery.getAlias()));
}
}
checkColumnAlias(aliasQuery, analyzedAnchorChild.getOutput());
// make all output nullable
analyzedAnchorChild = forceOutputNullable(analyzedAnchorChild,
aliasQuery.getColumnAliases().orElse(ImmutableList.of()));
// analyze recursive child
LogicalPlan recursiveChild = (LogicalPlan) parsedCtePlan.child(1);
CascadesContext innerRecursiveCascadesCtx = CascadesContext.newContextWithCteContext(
cascadesContext, recursiveChild, outerCteCtx, Optional.of(aliasQuery.getAlias()),
analyzedAnchorChild.getOutput());
innerRecursiveCascadesCtx.withPlanProcess(cascadesContext.showPlanProcess(), () -> {
innerRecursiveCascadesCtx.newAnalyzer().analyze();
});
cascadesContext.addPlanProcesses(innerRecursiveCascadesCtx.getPlanProcesses());
LogicalPlan analyzedRecursiveChild = (LogicalPlan) innerRecursiveCascadesCtx.getRewritePlan();
List<LogicalRecursiveCteScan> recursiveCteScanList = analyzedRecursiveChild
.collectToList(LogicalRecursiveCteScan.class::isInstance);
if (recursiveCteScanList.size() > 1) {
throw new AnalysisException(String.format("recursive reference to query %s must not appear more than once",
aliasQuery.getAlias()));
}
List<Slot> anchorChildOutputs = analyzedAnchorChild.getOutput();
List<DataType> anchorChildOutputTypes = new ArrayList<>(anchorChildOutputs.size());
for (Slot slot : anchorChildOutputs) {
anchorChildOutputTypes.add(slot.getDataType());
}
List<Slot> recursiveChildOutputs = analyzedRecursiveChild.getOutput();
for (int i = 0; i < recursiveChildOutputs.size(); ++i) {
if (!recursiveChildOutputs.get(i).getDataType().equals(anchorChildOutputTypes.get(i))) {
throw new AnalysisException(String.format("%s recursive child's %d column's datatype in select list %s "
+ "is different from anchor child's output datatype %s, please add cast manually "
+ "to get expect datatype", aliasQuery.getAlias(), i + 1,
recursiveChildOutputs.get(i).getDataType(), anchorChildOutputTypes.get(i)));
}
}
analyzedRecursiveChild = new LogicalRecursiveCteRecursiveChild<>(aliasQuery.getAlias(),
forceOutputNullable(analyzedRecursiveChild, ImmutableList.of()));
// create LogicalRecursiveCte
LogicalUnion logicalUnion = (LogicalUnion) parsedCtePlan;
LogicalRecursiveCte analyzedCtePlan = new LogicalRecursiveCte(aliasQuery.getAlias(),
logicalUnion.getQualifier() == SetOperation.Qualifier.ALL,
ImmutableList.of(analyzedAnchorChild, analyzedRecursiveChild));
List<List<NamedExpression>> childrenProjections = analyzedCtePlan.collectChildrenProjections();
int childrenProjectionSize = childrenProjections.size();
ImmutableList.Builder<List<SlotReference>> childrenOutputs = ImmutableList
.builderWithExpectedSize(childrenProjectionSize);
ImmutableList.Builder<Plan> newChildren = ImmutableList.builderWithExpectedSize(childrenProjectionSize);
for (int i = 0; i < childrenProjectionSize; i++) {
Plan newChild;
Plan child = analyzedCtePlan.child(i);
if (childrenProjections.get(i).stream().allMatch(SlotReference.class::isInstance)) {
newChild = child;
} else {
List<NamedExpression> parentProject = childrenProjections.get(i);
newChild = ProjectProcessor.tryProcessProject(parentProject, child)
.orElseGet(() -> new LogicalProject<>(parentProject, child));
}
newChildren.add(newChild);
childrenOutputs.add((List<SlotReference>) (List) newChild.getOutput());
}
analyzedCtePlan = analyzedCtePlan.withChildrenAndTheirOutputs(newChildren.build(), childrenOutputs.build());
List<NamedExpression> newOutputs = analyzedCtePlan.buildNewOutputs();
analyzedCtePlan = analyzedCtePlan.withNewOutputs(newOutputs);
CTEId cteId = StatementScopeIdGenerator.newCTEId();
LogicalSubQueryAlias<Plan> logicalSubQueryAlias = aliasQuery.withChildren(ImmutableList.of(analyzedCtePlan));
outerCteCtx = new CTEContext(cteId, logicalSubQueryAlias, outerCteCtx);
outerCteCtx.setAnalyzedPlan(logicalSubQueryAlias);
LogicalCTEProducer<Plan> cteProducer = new LogicalCTEProducer<>(cteId, logicalSubQueryAlias);
return Pair.of(outerCteCtx, cteProducer);
}
private LogicalPlan forceOutputNullable(LogicalPlan logicalPlan, List<String> aliasNames) {
List<Slot> oldOutputs = logicalPlan.getOutput();
int size = oldOutputs.size();
List<NamedExpression> newOutputs = new ArrayList<>(oldOutputs.size());
if (!aliasNames.isEmpty()) {
for (int i = 0; i < size; ++i) {
newOutputs.add(new Alias(new Nullable(oldOutputs.get(i)), aliasNames.get(i)));
}
} else {
for (Slot slot : oldOutputs) {
newOutputs.add(new Alias(new Nullable(slot), slot.getName()));
}
}
return new LogicalProject<>(newOutputs, logicalPlan);
}
/**
* check columnAliases' size and name
*/
private void checkColumnAlias(LogicalSubQueryAlias<Plan> aliasQuery, List<Slot> outputSlots) {
if (aliasQuery.getColumnAliases().isPresent()) {
List<String> columnAlias = aliasQuery.getColumnAliases().get();
// if the size of columnAlias is smaller than outputSlots' size, we will replace the corresponding number
// of front slots with columnAlias.
if (columnAlias.size() > outputSlots.size()) {
throw new AnalysisException("CTE [" + aliasQuery.getAlias() + "] returns "
+ columnAlias.size() + " columns, but " + outputSlots.size() + " labels were specified."
+ " The number of column labels must be smaller or equal to the number of returned columns.");
}
Set<String> names = new HashSet<>();
// column alias cannot be used more than once
columnAlias.forEach(alias -> {
if (names.contains(alias.toLowerCase())) {
throw new AnalysisException("Duplicated CTE column alias:"
+ " [" + alias.toLowerCase() + "] in CTE [" + aliasQuery.getAlias() + "]");
}
names.add(alias);
});
}
}
}