StatsDerive.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.stats.StatsCalculator;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.AbstractPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalDeferMaterializeOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalDeferMaterializeTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalExcept;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSink;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTVFRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatisticsBuilder;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
/**
 * Stats derive for rbo.
 *
 * <p>Visits a plan tree, computes a {@link Statistics} object for each visited node and stores it on the
 * node through {@code setStatistics}. Most visitors reuse statistics already present on a node unless
 * {@code deepDerive} is set; a few visitors (noted on each method) always recompute.
 */
public class StatsDerive extends PlanVisitor<Statistics, StatsDerive.DeriveContext> implements CustomRewriter {
    /** When true, nodes that already carry statistics are derived again by the cache-aware visitors. */
    private final boolean deepDerive;

    /**
     * Context passed through the visit; it only holds the calculator used to compute per-node statistics.
     */
    public static class DeriveContext {
        StatsCalculator calculator = new StatsCalculator(null);
    }

    /**
     * @param deepDerive if true, re-derive statistics even for nodes whose statistics are already set
     */
    public StatsDerive(boolean deepDerive) {
        super();
        this.deepDerive = deepDerive;
    }

    /**
     * Derives statistics for the whole tree under {@code plan} using a fresh {@link DeriveContext}.
     * The plan itself is returned unchanged; only the statistics attached to its nodes are updated.
     */
    @Override
    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
        plan.accept(this, new DeriveContext());
        return plan;
    }

    /**
     * Default visitor for plans without a dedicated method.
     *
     * <p>Children are visited from the last to the first, so the value finally kept is the statistics of
     * child 0. If the plan has no children, the previously read statistics (possibly null) are kept.
     */
    @Override
    public Statistics visit(Plan plan, DeriveContext context) {
        Statistics result = ((AbstractPlan) plan).getStats();
        if (result == null || deepDerive) {
            for (int i = plan.children().size() - 1; i >= 0; i--) {
                result = plan.children().get(i).accept(this, context);
            }
            ((AbstractPlan) plan).setStatistics(result);
        }
        return result;
    }

    /** Derives project statistics from the child's statistics; reuses cached stats unless deepDerive. */
    @Override
    public Statistics visitLogicalProject(LogicalProject<? extends Plan> project, DeriveContext context) {
        Statistics stats = project.getStats();
        if (stats == null || deepDerive) {
            Statistics childStats = project.child().accept(this, context);
            stats = context.calculator.computeProject(project, childStats);
            project.setStatistics(stats);
        }
        return stats;
    }

    /** A sink takes its child's statistics as its own. Always recomputed. */
    @Override
    public Statistics visitLogicalSink(LogicalSink<? extends Plan> logicalSink, DeriveContext context) {
        Statistics childStats = logicalSink.child().accept(this, context);
        logicalSink.setStatistics(childStats);
        return childStats;
    }

    /** Always recomputed. */
    @Override
    public Statistics visitLogicalEmptyRelation(LogicalEmptyRelation emptyRelation, DeriveContext context) {
        Statistics stats = context.calculator.computeEmptyRelation(emptyRelation);
        emptyRelation.setStatistics(stats);
        return stats;
    }

    /** Always recomputed from the child's statistics. */
    @Override
    public Statistics visitLogicalLimit(LogicalLimit<? extends Plan> limit, DeriveContext context) {
        Statistics childStats = limit.child().accept(this, context);
        Statistics stats = context.calculator.computeLimit(limit, childStats);
        limit.setStatistics(stats);
        return stats;
    }

    /** Always recomputed from the relation's projects. */
    @Override
    public Statistics visitLogicalOneRowRelation(LogicalOneRowRelation oneRowRelation, DeriveContext context) {
        Statistics stats = context.calculator.computeOneRowRelation(oneRowRelation.getProjects());
        oneRowRelation.setStatistics(stats);
        return stats;
    }

    /** Derives aggregate statistics from the child's statistics; reuses cached stats unless deepDerive. */
    @Override
    public Statistics visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, DeriveContext context) {
        Statistics stats = aggregate.getStats();
        if (stats == null || deepDerive) {
            Statistics childStats = aggregate.child().accept(this, context);
            stats = context.calculator.computeAggregate(aggregate, childStats);
            aggregate.setStatistics(stats);
        }
        return stats;
    }

    /** Always recomputed from the child's statistics. */
    @Override
    public Statistics visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat, DeriveContext context) {
        Statistics childStats = repeat.child().accept(this, context);
        Statistics stats = context.calculator.computeRepeat(repeat, childStats);
        repeat.setStatistics(stats);
        return stats;
    }

    /** Derives filter statistics from the child's statistics; reuses cached stats unless deepDerive. */
    @Override
    public Statistics visitLogicalFilter(LogicalFilter<? extends Plan> filter, DeriveContext context) {
        Statistics stats = filter.getStats();
        if (stats == null || deepDerive) {
            Statistics childStats = filter.child().accept(this, context);
            stats = context.calculator.computeFilter(filter, childStats);
            filter.setStatistics(stats);
        }
        return stats;
    }

    /** Derives olap scan statistics; reuses cached stats unless deepDerive. */
    @Override
    public Statistics visitLogicalOlapScan(LogicalOlapScan olapScan, DeriveContext context) {
        Statistics stats = olapScan.getStats();
        if (stats == null || deepDerive) {
            stats = context.calculator.computeOlapScan(olapScan);
            olapScan.setStatistics(stats);
        }
        return stats;
    }

    /** Always recomputed through the same calculator entry point as a plain olap scan. */
    @Override
    public Statistics visitLogicalDeferMaterializeOlapScan(LogicalDeferMaterializeOlapScan olapScan,
            DeriveContext context) {
        Statistics stats = context.calculator.computeOlapScan(olapScan);
        olapScan.setStatistics(stats);
        return stats;
    }

    /** Derives catalog relation statistics; reuses cached stats unless deepDerive. */
    @Override
    public Statistics visitLogicalCatalogRelation(LogicalCatalogRelation relation, DeriveContext context) {
        Statistics stats = relation.getStats();
        if (stats == null || deepDerive) {
            stats = context.calculator.computeCatalogRelation(relation);
            relation.setStatistics(stats);
        }
        return stats;
    }

    /** Always recomputed; the statistics come from the table valued function itself. */
    @Override
    public Statistics visitLogicalTVFRelation(LogicalTVFRelation tvfRelation, DeriveContext context) {
        Statistics stats = tvfRelation.getFunction().computeStats(tvfRelation.getOutput());
        tvfRelation.setStatistics(stats);
        return stats;
    }

    /** A sort takes its child's statistics as its own; reuses cached stats unless deepDerive. */
    @Override
    public Statistics visitLogicalSort(LogicalSort<? extends Plan> sort, DeriveContext context) {
        Statistics stats = sort.getStats();
        if (stats == null || deepDerive) {
            stats = sort.child().accept(this, context);
            sort.setStatistics(stats);
        }
        return stats;
    }

    /** Derives topN statistics from the child's statistics; reuses cached stats unless deepDerive. */
    @Override
    public Statistics visitLogicalTopN(LogicalTopN<? extends Plan> topN, DeriveContext context) {
        Statistics stats = topN.getStats();
        if (stats == null || deepDerive) {
            Statistics childStats = topN.child().accept(this, context);
            stats = context.calculator.computeTopN(topN, childStats);
            topN.setStatistics(stats);
        }
        return stats;
    }

    /** Derives topN statistics from the child's statistics; reuses cached stats unless deepDerive. */
    @Override
    public Statistics visitLogicalDeferMaterializeTopN(LogicalDeferMaterializeTopN<? extends Plan> topN,
            DeriveContext context) {
        Statistics stats = topN.getStats();
        // Same guard as every other cache-aware visitor: derive when nothing is cached OR when a deep
        // derive is requested. Requiring both would return null for a never-derived node in shallow mode.
        if (stats == null || deepDerive) {
            Statistics childStats = topN.child().accept(this, context);
            stats = context.calculator.computeTopN(topN, childStats);
            topN.setStatistics(stats);
        }
        return stats;
    }

    /** Derives partition topN statistics from the child's statistics; reuses cached stats unless deepDerive. */
    @Override
    public Statistics visitLogicalPartitionTopN(LogicalPartitionTopN<? extends Plan> partitionTopN,
            DeriveContext context) {
        Statistics stats = partitionTopN.getStats();
        if (stats == null || deepDerive) {
            Statistics childStats = partitionTopN.child().accept(this, context);
            stats = context.calculator.computePartitionTopN(partitionTopN, childStats);
            partitionTopN.setStatistics(stats);
        }
        return stats;
    }

    /**
     * Derives join statistics from both children; reuses cached stats unless deepDerive.
     * The resulting widthInJoinCluster is the sum of the two children's widths.
     */
    @Override
    public Statistics visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, DeriveContext context) {
        Statistics joinStats = join.getStats();
        if (joinStats == null || deepDerive) {
            Statistics leftStats = join.left().accept(this, context);
            Statistics rightStats = join.right().accept(this, context);
            joinStats = context.calculator.computeJoin(join, leftStats, rightStats);
            joinStats = new StatisticsBuilder(joinStats).setWidthInJoinCluster(
                    leftStats.getWidthInJoinCluster() + rightStats.getWidthInJoinCluster()).build();
            join.setStatistics(joinStats);
        }
        return joinStats;
    }

    /** Always recomputed from the child's statistics and the assert element. */
    @Override
    public Statistics visitLogicalAssertNumRows(
            LogicalAssertNumRows<? extends Plan> assertNumRows, DeriveContext context) {
        Statistics childStats = assertNumRows.child().accept(this, context);
        Statistics stats = context.calculator.computeAssertNumRows(assertNumRows.getAssertNumRowsElement(),
                childStats);
        assertNumRows.setStatistics(stats);
        return stats;
    }

    /** Derives union statistics from all children's statistics; reuses cached stats unless deepDerive. */
    @Override
    public Statistics visitLogicalUnion(
            LogicalUnion union, DeriveContext context) {
        Statistics stats = union.getStats();
        if (stats == null || deepDerive) {
            List<Statistics> childrenStats = new ArrayList<>();
            for (Plan child : union.children()) {
                Statistics childStats = child.accept(this, context);
                childrenStats.add(childStats);
            }
            stats = context.calculator.computeUnion(union, childrenStats);
            union.setStatistics(stats);
        }
        return stats;
    }

    /**
     * Derives except statistics; reuses cached stats unless deepDerive.
     *
     * <p>Every child is visited so that each gets its statistics set, but only the first non-null child
     * statistics are handed to the calculator.
     */
    @Override
    public Statistics visitLogicalExcept(
            LogicalExcept except, DeriveContext context) {
        Statistics stats = except.getStats();
        if (stats == null || deepDerive) {
            // Track the first child's stats in its own variable. Reusing `stats` here would, on a deep
            // derive of an already-derived node, feed the stale Except stats back into the calculator.
            Statistics leftStats = null;
            for (Plan child : except.children()) {
                Statistics childStats = child.accept(this, context);
                if (leftStats == null) {
                    leftStats = childStats;
                }
            }
            stats = context.calculator.computeExcept(except, leftStats);
            except.setStatistics(stats);
        }
        return stats;
    }

    /** Derives intersect statistics from all children's statistics; reuses cached stats unless deepDerive. */
    @Override
    public Statistics visitLogicalIntersect(
            LogicalIntersect intersect, DeriveContext context) {
        Statistics stats = intersect.getStats();
        if (stats == null || deepDerive) {
            List<Statistics> childrenStats = new ArrayList<>();
            for (Plan child : intersect.children()) {
                Statistics childStats = child.accept(this, context);
                childrenStats.add(childStats);
            }
            stats = context.calculator.computeIntersect(intersect, childrenStats);
            intersect.setStatistics(stats);
        }
        return stats;
    }

    /** Derives generate statistics from the child's statistics; reuses cached stats unless deepDerive. */
    @Override
    public Statistics visitLogicalGenerate(LogicalGenerate<? extends Plan> generate, DeriveContext context) {
        Statistics stats = generate.getStats();
        if (stats == null || deepDerive) {
            Statistics childStats = generate.child().accept(this, context);
            stats = context.calculator.computeGenerate(generate, childStats);
            generate.setStatistics(stats);
        }
        return stats;
    }

    /** Derives window statistics from the child's statistics; reuses cached stats unless deepDerive. */
    @Override
    public Statistics visitLogicalWindow(LogicalWindow<? extends Plan> window, DeriveContext context) {
        Statistics stats = window.getStats();
        if (stats == null || deepDerive) {
            Statistics childStats = window.child().accept(this, context);
            stats = context.calculator.computeWindow(window, childStats);
            window.setStatistics(stats);
        }
        return stats;
    }

    /**
     * Always recomputed. Copies the child's statistics with widthInJoinCluster reset to 1, stores them on
     * the producer and registers them in the current statement context under the producer's CTE id so that
     * consumers can look them up.
     */
    @Override
    public Statistics visitLogicalCTEProducer(LogicalCTEProducer<? extends Plan> cteProducer,
            DeriveContext context) {
        Statistics prodStats = cteProducer.child().accept(this, context);
        StatisticsBuilder builder = new StatisticsBuilder(prodStats);
        builder.setWidthInJoinCluster(1);
        Statistics stats = builder.build();
        cteProducer.setStatistics(stats);
        ConnectContext.get().getStatementContext().setProducerStats(cteProducer.getCteId(), stats);
        return stats;
    }

    /**
     * Always recomputed. Looks up the producer statistics registered for this consumer's CTE id and builds
     * consumer statistics with the producer's row count, mapping each consumer output slot to the column
     * statistics of its producer slot. Slots without producer column statistics are skipped.
     *
     * @throws IllegalArgumentException if no producer statistics are registered for the CTE id
     */
    @Override
    public Statistics visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, DeriveContext context) {
        CTEId cteId = cteConsumer.getCteId();
        Statistics prodStats = ConnectContext.get().getStatementContext().getProducerStatsByCteId(cteId);
        // Template overload: the message is only formatted when the check actually fails.
        Preconditions.checkArgument(prodStats != null, "Stats for CTE: %s not found", cteId);
        Statistics consumerStats = new Statistics(prodStats.getRowCount(), 1, new HashMap<>());
        for (Slot slot : cteConsumer.getOutput()) {
            Slot prodSlot = cteConsumer.getProducerSlot(slot);
            ColumnStatistic colStats = prodStats.columnStatistics().get(prodSlot);
            if (colStats == null) {
                continue;
            }
            consumerStats.addColumnStats(slot, colStats);
        }
        cteConsumer.setStatistics(consumerStats);
        return consumerStats;
    }

    /**
     * Always recomputed. Visits the children in order and keeps the statistics of the last one.
     *
     * @throws IllegalArgumentException if no child produced statistics
     */
    @Override
    public Statistics visitLogicalCTEAnchor(LogicalCTEAnchor<? extends Plan, ? extends Plan> cteAnchor,
            DeriveContext context) {
        Statistics childStats = null;
        for (Plan child : cteAnchor.children()) {
            childStats = child.accept(this, context);
        }
        Preconditions.checkArgument(childStats != null, "cteAnchor's childStats is null");
        // use consumer stats
        cteAnchor.setStatistics(childStats);
        return childStats;
    }

    /**
     * used for ut
     *
     * <p>Always recomputed: row count 1 and {@link ColumnStatistic#UNKNOWN} for every output slot.
     */
    @Override
    public Statistics visitLogicalRelation(LogicalRelation relation, DeriveContext context) {
        StatisticsBuilder builder = new StatisticsBuilder();
        builder.setRowCount(1);
        relation.getOutput().forEach(slot -> builder.putColumnStatistics(slot, ColumnStatistic.UNKNOWN));
        Statistics stats = builder.build();
        relation.setStatistics(stats);
        return stats;
    }
}