NormalizeOlapTableStreamScan.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.analysis.TableScanParams;
import org.apache.doris.binlog.BinlogUtils;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.OlapTableWrapper;
import org.apache.doris.catalog.RowBinlogTableWrapper;
import org.apache.doris.catalog.stream.BaseTableStream;
import org.apache.doris.catalog.stream.OlapTableStreamWrapper;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapTableStreamScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.qe.ConnectContext;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Rewrites a {@link LogicalOlapTableStreamScan} into ordinary scans of the base table and/or its row binlog.
 *
 * <p>For a regular (non-reset, non-snapshot) stream scan:
 * <ol>
 *   <li>STREAM_CHANGE_TYPE_VIRTUAL_COLUMN and STREAM_SEQ_VIRTUAL_COLUMN are removed from the scanned
 *       columns and re-created as alias projections that keep the original expr ids;</li>
 *   <li>a {@code delete sign = 0} filter is added when the base table scan exposes the delete sign column;</li>
 *   <li>history partitions are read from the base table, incremental partitions are read from the row
 *       binlog, and both parts are combined with UNION ALL.</li>
 * </ol>
 * Reset and snapshot scans are handled by {@link #makeResetOlapFullScan} and {@link #makeSnapshotScan}.
 */
public class NormalizeOlapTableStreamScan extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalOlapTableStreamScan()
                .thenApply(ctx -> normalize(ctx.root, ctx.cascadesContext))
                .toRule(RuleType.NORMALIZE_OlAP_TABLE_STREAM_SCAN);
    }

    /**
     * Builds the CASE WHEN expression that translates the binlog operation code held by {@code opSlot}
     * into the textual change type; any code that is not listed maps to "UNKNOWN".
     */
    private static Expression buildChangeTypeExpr(Slot opSlot) {
        return new CaseWhen(ImmutableList.of(
                whenOperation(opSlot, BinlogUtils.ROW_BINLOG_APPEND, "APPEND"),
                whenOperation(opSlot, BinlogUtils.ROW_BINLOG_DELETE, "DELETE"),
                whenOperation(opSlot, BinlogUtils.ROW_BINLOG_UPDATE_BEFORE, "UPDATE_BEFORE"),
                whenOperation(opSlot, BinlogUtils.ROW_BINLOG_UPDATE_AFTER, "UPDATE_AFTER")),
                new VarcharLiteral("UNKNOWN"));
    }

    /** One {@code WHEN opSlot = operation THEN changeType} branch of the change type expression. */
    private static WhenClause whenOperation(Slot opSlot, long operation, String changeType) {
        return new WhenClause(new EqualTo(opSlot, new BigIntLiteral(operation)), new VarcharLiteral(changeType));
    }

    /** Returns true when {@code slot} is a {@link SlotReference} whose original column equals {@code column}. */
    private static boolean isBackedByColumn(Slot slot, Column column) {
        if (!(slot instanceof SlotReference)) {
            return false;
        }
        SlotReference reference = (SlotReference) slot;
        return reference.getOriginalColumn().isPresent() && reference.getOriginalColumn().get().equals(column);
    }

    /** Returns true when {@code slot} is the stream change type virtual column. */
    private static boolean isChangeTypeSlot(Slot slot) {
        return isBackedByColumn(slot, Column.STREAM_CHANGE_TYPE_VIRTUAL_COLUMN);
    }

    /** Returns true when {@code slot} is the stream sequence virtual column. */
    private static boolean isSeqSlot(Slot slot) {
        return isBackedByColumn(slot, Column.STREAM_SEQ_VIRTUAL_COLUMN);
    }

    /** Returns the first slot in {@code slots} named {@code name}, or null when there is none. */
    private static Slot findSlotByName(List<Slot> slots, String name) {
        for (Slot slot : slots) {
            if (slot.getName().equals(name)) {
                return slot;
            }
        }
        return null;
    }

    /**
     * Entry point of the rewrite.
     *
     * @param scan the stream scan matched by the rule
     * @param cascadesContext supplies the statement context used to allocate new relation ids
     * @return {@code scan} itself when no partition is selected, otherwise the rewritten plan
     */
    private Plan normalize(LogicalOlapTableStreamScan scan, CascadesContext cascadesContext) {
        List<Long> selectedPartitionIds = scan.getSelectedPartitionIds();
        if (selectedPartitionIds.isEmpty()) {
            return scan;
        }
        if (scan.isReset()) {
            return makeResetOlapFullScan(scan, cascadesContext);
        }
        if (scan.isSnapshot()) {
            return makeSnapshotScan(scan, cascadesContext);
        }
        OlapTableStreamWrapper streamWrapper = (OlapTableStreamWrapper) scan.getTable();
        OlapTable baseTable = streamWrapper.getBaseTable();
        List<Long> historicalPartitionIds =
                ImmutableList.copyOf(streamWrapper.filterHistoryPartitionIds(selectedPartitionIds));
        List<Long> incrementalPartitionIds =
                ImmutableList.copyOf(streamWrapper.filterIncrementalPartitionIds(selectedPartitionIds));
        List<Slot> originSlots = scan.getLogicalProperties().getOutput();
        // newSlots = originSlots - (STREAM_CHANGE_TYPE_VIRTUAL_COLUMN + STREAM_SEQ_VIRTUAL_COLUMN)
        List<Slot> newSlots = ImmutableList.copyOf(originSlots.stream()
                .filter(slot -> !isChangeTypeSlot(slot) && !isSeqSlot(slot))
                .collect(Collectors.toList()));

        // history plan: partitions that are reported as history are read from the base table
        Plan historyPlan = null;
        if (!historicalPartitionIds.isEmpty()) {
            historyPlan = makeHistoryScan(scan, cascadesContext, baseTable, historicalPartitionIds,
                    originSlots, newSlots);
        }

        // incremental plan: remaining partitions are remapped to a scan of the row binlog
        Plan incrementalPlan = null;
        if (!incrementalPartitionIds.isEmpty()) {
            incrementalPlan = makeIncrementalScanFromBinlog(cascadesContext, scan, incrementalPartitionIds,
                    baseTable, streamWrapper.getPartitionOffsets(incrementalPartitionIds),
                    streamWrapper.getStreamScanType(), originSlots, newSlots, true);
        }

        return combineTwoPlan(historyPlan, incrementalPlan, originSlots);
    }

    /**
     * Builds the base table scan for history partitions. Every row is reported with change type "APPEND"
     * and the commit tso column is exposed as the stream sequence.
     *
     * @param originSlots output of the original stream scan, including the virtual columns
     * @param newSlots {@code originSlots} without the two stream virtual columns
     * @throws IllegalArgumentException when the base scan output has no commit tso column or lacks one
     *         of {@code newSlots}
     */
    private Plan makeHistoryScan(LogicalOlapTableStreamScan scan, CascadesContext cascadesContext,
                                 OlapTable baseTable, List<Long> historicalPartitionIds,
                                 List<Slot> originSlots, List<Slot> newSlots) {
        LogicalOlapScan baseScan = newBaseTableScan(scan, cascadesContext, baseTable, historicalPartitionIds);
        List<Slot> baseOutputSlots = baseScan.getLogicalProperties().getOutput();
        Plan plan = filterOutDeletedRows(scan, baseScan, baseOutputSlots);

        Slot tsoSlot = findSlotByName(baseOutputSlots, Column.COMMIT_TSO_COL);
        Preconditions.checkArgument(tsoSlot != null,
                "column %s not found in base table scan output", Column.COMMIT_TSO_COL);

        List<NamedExpression> project = mapToChildOutputSlots(newSlots, baseOutputSlots);
        // re-create the virtual columns under their original expr ids, in the order they appear in originSlots
        for (Slot slot : originSlots) {
            if (isChangeTypeSlot(slot)) {
                project.add(new Alias(slot.getExprId(), new VarcharLiteral("APPEND"),
                        Column.STREAM_CHANGE_TYPE_COL));
            } else if (isSeqSlot(slot)) {
                project.add(new Alias(slot.getExprId(), tsoSlot, Column.STREAM_SEQ_COL));
            }
        }
        return new LogicalProject<>(project, plan);
    }

    /**
     * Wraps {@code plan} in a projection that aliases every output column, by position, to the name of the
     * corresponding union output. Each alias is created without an explicit expr id.
     *
     * @throws IllegalStateException when the two output lists differ in size
     */
    private Plan refreshUnionChildOutputExprIds(Plan plan, List<Slot> unionOutputs) {
        List<Slot> childOutputs = plan.getOutput();
        Preconditions.checkState(childOutputs.size() == unionOutputs.size(),
                "Union child output size %s does not match union output size %s",
                childOutputs.size(), unionOutputs.size());
        List<NamedExpression> project = new ArrayList<>(childOutputs.size());
        for (int i = 0; i < childOutputs.size(); i++) {
            project.add(new Alias(childOutputs.get(i), unionOutputs.get(i).getName()));
        }
        return new LogicalProject<>(project, plan);
    }

    /**
     * Build a projection list that exposes the columns of {@code wantedSlots} (slots taken from the
     * original stream scan output) on top of the rewritten child whose output is {@code childOutput}.
     * Each wanted slot is matched to the child output slot by column name and re-aliased to the
     * original expr id, so that operators above this rewrite still reference the same expr ids.
     *
     * @return a mutable list, callers append further projections to it
     * @throws IllegalArgumentException when a wanted column name is missing from {@code childOutput}
     */
    private List<NamedExpression> mapToChildOutputSlots(List<Slot> wantedSlots, List<Slot> childOutput) {
        // when several child slots share a name, the last one wins
        Map<String, Slot> childSlotByName = new HashMap<>();
        for (Slot slot : childOutput) {
            childSlotByName.put(slot.getName(), slot);
        }
        List<NamedExpression> project = new ArrayList<>(wantedSlots.size());
        for (Slot wanted : wantedSlots) {
            Slot match = childSlotByName.get(wanted.getName());
            Preconditions.checkArgument(match != null,
                    "column %s not found in child output", wanted.getName());
            if (match.getExprId().equals(wanted.getExprId())) {
                project.add(match);
            } else {
                project.add(new Alias(wanted.getExprId(), match, wanted.getName()));
            }
        }
        return project;
    }

    /**
     * Builds a scan of the row binlog of {@code baseTable} for the given partitions.
     *
     * @param offsetMap per-partition offsets handed to the {@link RowBinlogTableWrapper}
     * @param streamScanType passed to the scan as OLAP_INCREMENT_TYPE; APPEND_ONLY additionally filters
     *        the binlog down to append operations
     * @param originSlots output of the original stream scan
     * @param newSlots slots that are projected straight from the binlog output by column name
     * @param isIncremental true: re-create the stream virtual columns from the binlog operation and
     *        timestamp columns; false: keep only DELETE and UPDATE_BEFORE rows (used to rebuild a snapshot)
     * @throws IllegalArgumentException when a binlog column required by the requested output is missing
     */
    private Plan makeIncrementalScanFromBinlog(CascadesContext cascadesContext, LogicalOlapTableStreamScan scan,
                                               List<Long> selectedPartitionIds,
                                               OlapTable baseTable, Map<Long, Pair<Long, Long>> offsetMap,
                                               BaseTableStream.StreamScanType streamScanType, List<Slot> originSlots,
                                               List<Slot> newSlots, boolean isIncremental) {
        // remap scan from binlog
        RowBinlogTableWrapper table =
                new RowBinlogTableWrapper(baseTable, offsetMap);
        Map<String, String> scanParams = new HashMap<>();
        scanParams.put(OlapScanNode.OLAP_INCREMENT_TYPE, streamScanType.toString());
        LogicalOlapScan newScan = new LogicalOlapScan(cascadesContext.getStatementContext().getNextRelationId(),
                table, scan.qualified(), selectedPartitionIds, scan.getSelectedTabletIds(),
                new ArrayList<>(), scan.getTableSample(), ImmutableList.of())
                .withTableScanParams(new TableScanParams(TableScanParams.INCREMENTAL_READ, scanParams,
                        Lists.newArrayList()));
        Plan plan = newScan;
        List<Slot> binlogOutputSlots = newScan.getLogicalProperties().getOutput();
        // binlog columns the stream virtual columns are derived from; either may be absent (null)
        Slot opSlot = findSlotByName(binlogOutputSlots, Column.BINLOG_OPERATION_COL);
        Slot seqSlot = findSlotByName(binlogOutputSlots, Column.BINLOG_TIMESTAMP_COL);

        if (streamScanType.equals(BaseTableStream.StreamScanType.APPEND_ONLY)) {
            // filter append-only operation if needed
            Preconditions.checkArgument(opSlot != null,
                    "column %s not found in binlog scan output", Column.BINLOG_OPERATION_COL);
            plan = new LogicalFilter<>(ImmutableSet.of(new EqualTo(opSlot,
                    new BigIntLiteral(BinlogUtils.ROW_BINLOG_APPEND))), plan);
        }
        // pass-through projection of every binlog column, kept from the original plan shape
        List<NamedExpression> binlogProject = binlogOutputSlots.stream()
                .map(NamedExpression.class::cast).collect(Collectors.toList());
        plan = new LogicalProject<>(binlogProject, plan);

        List<NamedExpression> project = mapToChildOutputSlots(newSlots, binlogOutputSlots);
        if (isIncremental) {
            // replace stream virtual column with alias slot reference
            for (Slot slot : originSlots) {
                if (isChangeTypeSlot(slot)) {
                    // fail with a clear message instead of passing null into the expression tree
                    Preconditions.checkArgument(opSlot != null,
                            "column %s not found in binlog scan output", Column.BINLOG_OPERATION_COL);
                    project.add(new Alias(slot.getExprId(), buildChangeTypeExpr(opSlot),
                            Column.STREAM_CHANGE_TYPE_COL));
                } else if (isSeqSlot(slot)) {
                    Preconditions.checkArgument(seqSlot != null,
                            "column %s not found in binlog scan output", Column.BINLOG_TIMESTAMP_COL);
                    project.add(new Alias(slot.getExprId(), seqSlot, Column.STREAM_SEQ_COL));
                }
            }
        } else {
            // only filter delete & update before rows for building before snapshot image
            Preconditions.checkArgument(opSlot != null,
                    "column %s not found in binlog scan output", Column.BINLOG_OPERATION_COL);
            Expression opFilter = new InPredicate(opSlot, ImmutableList.of(
                    new BigIntLiteral(BinlogUtils.ROW_BINLOG_DELETE),
                    new BigIntLiteral(BinlogUtils.ROW_BINLOG_UPDATE_BEFORE)));
            plan = new LogicalFilter<>(ImmutableSet.of(opFilter), plan);
        }
        return new LogicalProject<>(project, plan);
    }

    /**
     * Reset scan: reads all selected partitions directly from the base table.
     */
    private Plan makeResetOlapFullScan(LogicalOlapTableStreamScan scan, CascadesContext cascadesContext) {
        // make olap scan on base table
        OlapTableStreamWrapper streamWrapper = (OlapTableStreamWrapper) scan.getTable();
        OlapTable baseTable = streamWrapper.getBaseTable();
        // NOTE(review): unlike the snapshot and history paths, the result is not re-aliased to the expr ids
        // of the stream scan output - confirm that parents of a reset scan can still resolve their slots.
        return makeOlapScanOnBaseTable(scan, cascadesContext, baseTable, scan.getSelectedPartitionIds());
    }

    /**
     * Snapshot scan. Partitions reported by {@code filterNormalSnapshotPartitionIds} are read from the base
     * table as they are. Every other selected partition is rebuilt as the union of a base table scan
     * restricted by the history offsets and the DELETE / UPDATE_BEFORE rows taken from the binlog.
     */
    private Plan makeSnapshotScan(LogicalOlapTableStreamScan scan, CascadesContext cascadesContext) {
        List<Long> selectedPartitionIds = scan.getSelectedPartitionIds();
        OlapTableStreamWrapper streamWrapper = (OlapTableStreamWrapper) scan.getTable();
        OlapTable baseTable = streamWrapper.getBaseTable();
        List<Long> normalPartitionIds = streamWrapper.filterNormalSnapshotPartitionIds(selectedPartitionIds);
        Set<Long> normalPartitionIdSet = ImmutableSet.copyOf(normalPartitionIds);
        List<Long> rebuildPartitionIds =
                ImmutableList.copyOf(selectedPartitionIds.stream()
                        .filter(id -> !normalPartitionIdSet.contains(id)).collect(Collectors.toList()));
        List<Slot> originSlots = scan.getLogicalProperties().getOutput();

        Plan normalPlan = null;
        if (!normalPartitionIds.isEmpty()) {
            Plan baseScanPlan = makeOlapScanOnBaseTable(scan, cascadesContext, baseTable, normalPartitionIds);
            // project filter invisible slots to match rebuild Plan, and keep the original
            // stream scan expr ids so that parent operators still reference them
            List<NamedExpression> project = mapToChildOutputSlots(originSlots, baseScanPlan.getOutput());
            normalPlan = new LogicalProject<>(project, baseScanPlan);
        }

        Plan rebuildPlan = null;
        if (!rebuildPartitionIds.isEmpty()) {
            // base table scan part
            // build base table offset
            // for row commit tso <= consumption tso we scan from base table
            Map<Long, Pair<Long, Long>> partitionOffsetMap =
                    streamWrapper.getHistoryPartitionOffsets(rebuildPartitionIds);
            OlapTableWrapper table =
                    new OlapTableWrapper(baseTable, partitionOffsetMap);
            // NOTE(review): this part is not projected onto originSlots, while the union below requires equal
            // output sizes and aliases columns by position - confirm the base scan output lines up.
            Plan basePartPlan = makeOlapScanOnBaseTable(scan, cascadesContext, table, rebuildPartitionIds);
            // binlog scan part
            // we rebuild by add back updated & deleted rows from binlog
            Plan binlogPartPlan = makeIncrementalScanFromBinlog(cascadesContext, scan, rebuildPartitionIds,
                    baseTable, streamWrapper.getPartitionOffsets(rebuildPartitionIds),
                    BaseTableStream.StreamScanType.MIN_DELTA, originSlots, originSlots, false);
            rebuildPlan = combineTwoPlan(basePartPlan, binlogPartPlan, originSlots);
        }
        return combineTwoPlan(normalPlan, rebuildPlan, originSlots);
    }

    /**
     * Scans {@code partitionIds} of {@code baseTable} and, when the scan exposes the delete sign column,
     * filters out rows whose delete sign is not 0.
     */
    private Plan makeOlapScanOnBaseTable(LogicalOlapTableStreamScan scan, CascadesContext cascadesContext,
                                         OlapTable baseTable, List<Long> partitionIds) {
        LogicalOlapScan baseScan = newBaseTableScan(scan, cascadesContext, baseTable, partitionIds);
        return filterOutDeletedRows(scan, baseScan, baseScan.getLogicalProperties().getOutput());
    }

    /**
     * Creates a scan of {@code table} under a fresh relation id, reusing the qualifier, selected tablets
     * and table sample of the stream scan.
     */
    private LogicalOlapScan newBaseTableScan(LogicalOlapTableStreamScan scan, CascadesContext cascadesContext,
                                             OlapTable table, List<Long> partitionIds) {
        return new LogicalOlapScan(cascadesContext.getStatementContext().getNextRelationId(),
                table, scan.qualified(), partitionIds, scan.getSelectedTabletIds(),
                new ArrayList<>(), scan.getTableSample(), ImmutableList.of());
    }

    /**
     * Adds a {@code delete sign = 0} filter on top of {@code baseScan} when {@code baseOutputSlots} contains
     * the delete sign column; otherwise returns {@code baseScan} untouched. Pre-aggregation is switched off
     * on the scan when the stream's table does not enable unique key merge on write.
     */
    private Plan filterOutDeletedRows(LogicalOlapTableStreamScan scan, LogicalOlapScan baseScan,
                                      List<Slot> baseOutputSlots) {
        Slot deleteSlot = findSlotByName(baseOutputSlots, Column.DELETE_SIGN);
        if (deleteSlot == null) {
            return baseScan;
        }
        Plan filterInput = baseScan;
        if (!scan.getTable().getEnableUniqueKeyMergeOnWrite()) {
            filterInput = baseScan.withPreAggStatus(PreAggStatus.off(
                    Column.DELETE_SIGN + " is used as conjuncts."));
        }
        Expression conjunct = new EqualTo(deleteSlot, new TinyIntLiteral((byte) 0));
        return new LogicalFilter<>(ImmutableSet.of(conjunct), filterInput);
    }

    /**
     * Combines two optional plans: an empty relation over {@code originSlots} when both are null, the single
     * non-null plan when only one is present, otherwise a UNION ALL of both.
     */
    private Plan combineTwoPlan(Plan plan0, Plan plan1, List<Slot> originSlots) {
        if (plan0 != null && plan1 != null) {
            return makeUnionPlan(plan0, plan1, originSlots);
        }
        if (plan0 != null) {
            return plan0;
        }
        if (plan1 != null) {
            return plan1;
        }
        // NOTE(review): this is the only relation id taken from the thread-local ConnectContext instead of
        // the CascadesContext passed to the rule - confirm both always share one statement context.
        return new LogicalEmptyRelation(ConnectContext.get().getStatementContext().getNextRelationId(),
                originSlots);
    }

    /**
     * Builds {@code child0 UNION ALL child1} whose outputs are {@code originSlots}. Both children are first
     * wrapped by {@link #refreshUnionChildOutputExprIds} and are matched to the union output by position.
     */
    private Plan makeUnionPlan(Plan child0, Plan child1, List<Slot> originSlots) {
        child0 = refreshUnionChildOutputExprIds(child0, originSlots);
        child1 = refreshUnionChildOutputExprIds(child1, originSlots);
        List<Plan> children = Lists.newArrayList(child0, child1);
        List<NamedExpression> unionOutputs = originSlots.stream()
                .map(NamedExpression.class::cast)
                .collect(Collectors.toList());
        List<List<SlotReference>> childrenOutputs = children.stream()
                .map(plan -> plan.getOutput().stream()
                        .map(slot -> (SlotReference) slot.toSlot())
                        .collect(Collectors.toList()))
                .collect(Collectors.toList());
        return new LogicalUnion(Qualifier.ALL, unionOutputs, childrenOutputs, ImmutableList.of(), false, children);
    }
}