VariantSubPathPruning.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.rewrite.ColumnPruning.PruneContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.types.VariantType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
/**
* Prunes the sub paths of variant-typed slots.
* For example, if variant slot v in table t has two sub paths, 'c1' and 'c2',
* then after this rule "select v['c1'] from t" only scans the sub path 'c1' of v, which reduces scan time.
*
* This rule does all of its work with two components. The Collector traverses the plan from the top down,
* collecting every element_at function applied to a variant type and recording the required path from
* the original variant slot to that element_at. The Replacer traverses the plan from the bottom up,
* generating slots for the required sub paths on scan, union, and cte consumer nodes.
* It then replaces each element_at with the corresponding slot.
*/
public class VariantSubPathPruning extends DefaultPlanRewriter<PruneContext> implements CustomRewriter {
/**
 * Runs the collector over the plan and, only when it recorded at least one
 * element_at on a variant slot, runs the replacer over the same plan.
 *
 * @param plan the plan to rewrite
 * @param jobContext not used by this rule
 * @return the original plan when nothing was collected, otherwise the rewritten plan
 */
@Override
public Plan rewriteRoot(Plan plan, JobContext jobContext) {
    Context pruneContext = new Context();
    plan.accept(VariantSubPathCollector.INSTANCE, pruneContext);
    // no element_at was recorded, so there is nothing for the replacer to do
    boolean nothingCollected = pruneContext.elementAtToSubPathMap.isEmpty();
    return nothingCollected ? plan : plan.accept(VariantSubPathReplacer.INSTANCE, pruneContext);
}
/**
 * State shared by the collector and the replacer. The first group of maps is filled
 * while collecting, the second group while replacing.
 */
private static class Context {
    // ---- used by the collector ----

    // slot -> the expression it was defined from. for example, Alias(c1, a1) puts a1 -> c1 into this map
    private final Map<Slot, Expression> slotToOriginalExprMap = Maps.newHashMap();
    // element_at -> root slot plus sub path. for example, element_at(c1, 'a') as c2 + element_at(c2, 'b')
    // puts element_at(c2, 'b') -> {c1, ['a', 'b']} and element_at(c1, 'a') -> {c1, ['a']} into this map
    private final Map<ElementAt, Pair<SlotReference, List<String>>> elementAtToSubPathMap = Maps.newHashMap();
    // slot -> every sub path requested from it. for example, element_at(c1, 'a') as c2 + element_at(c2, 'b')
    // puts c1 -> [['a'], ['a', 'b']] into this map
    private final Map<SlotReference, Set<List<String>>> slotToSubPathsMap = Maps.newHashMap();
    // element_at -> cte consumer slot it was written against, so the replacer can pick the
    // right slot when it rewrites the consumer
    private final Map<ElementAt, SlotReference> elementAtToCteConsumer = Maps.newHashMap();

    // ---- used by the replacer ----

    // element_at -> the slot that replaces it
    private final Map<ElementAt, SlotReference> elementAtToSlotMap = Maps.newHashMap();
    // element_at -> slots of the sub paths that have the element_at's path as a strict prefix
    // (the exactly matched sub path is kept in elementAtToSlotMap instead).
    // for example, with element_at(c1, 'a') as c2 + element_at(c2, 'b'),
    // if element_at(c1, 'a') -> s1 and element_at(c2, 'b') -> s2, then
    // this map holds element_at(c1, 'a') -> {['a', 'b'] -> s2}.
    // it is used when replacing element_at in a project: an upper node may need a deeper sub path,
    // so every slot that can be derived from the element_at must be put into the project list.
    private final Map<ElementAt, Map<List<String>, SlotReference>> elementAtToSlotsMap = Maps.newHashMap();
    // same idea as elementAtToSlotsMap, but keyed by variant slot: slot -> {sub path -> slot}
    private final Map<Slot, Map<List<String>, SlotReference>> slotToSlotsMap = Maps.newHashMap();

    /**
     * Records that {@code slot} is defined by {@code expression} and re-roots every
     * already recorded element_at whose root is {@code slot}.
     */
    public void putSlotToOriginal(Slot slot, Expression expression) {
        slotToOriginalExprMap.put(slot, expression);
        // rewrite the entries that currently start from `slot`:
        //   element_at(3, c) -> 3, ['c']
        //   + slot3 -> element_at(1, b) -> 1, ['b']
        //   ==> element_at(3, c) -> 1, ['b', 'c']
        for (Map.Entry<ElementAt, Pair<SlotReference, List<String>>> recorded
                : elementAtToSubPathMap.entrySet()) {
            Pair<SlotReference, List<String>> current = recorded.getValue();
            if (!slot.equals(current.first)) {
                continue;
            }
            if (expression instanceof ElementAt) {
                // prepend the path of the defining element_at and move to its root slot
                Pair<SlotReference, List<String>> origin = elementAtToSubPathMap.get(expression);
                List<String> fullPath = Lists.newArrayList(origin.second);
                fullPath.addAll(current.second);
                recorded.setValue(Pair.of(origin.first, fullPath));
                slotToSubPathsMap.computeIfAbsent(origin.first, k -> Sets.newHashSet()).add(fullPath);
            } else if (expression instanceof Slot) {
                // plain slot alias: same path, different root
                recorded.setValue(Pair.of((SlotReference) expression, current.second));
            }
        }
        // a slot alias also inherits every sub path requested on the alias
        if (expression instanceof SlotReference && slotToSubPathsMap.containsKey((SlotReference) slot)) {
            slotToSubPathsMap.computeIfAbsent((SlotReference) expression, k -> Sets.newHashSet())
                    .addAll(slotToSubPathsMap.get(slot));
        }
    }

    /**
     * Records {@code elementAt} with its root slot and path, and registers the path on the root slot.
     * When {@code parent} is not null, every sub path already requested on {@code parent} is appended
     * to this path and registered on the root slot as well.
     */
    public void putElementAtToSubPath(ElementAt elementAt,
            Pair<SlotReference, List<String>> pair, Slot parent) {
        elementAtToSubPathMap.put(elementAt, pair);
        Set<List<String>> rootPaths = slotToSubPathsMap.computeIfAbsent(pair.first, k -> Sets.newHashSet());
        rootPaths.add(pair.second);
        if (parent == null) {
            return;
        }
        // note: this also registers an empty entry for parent when it had none yet
        Set<List<String>> pathsOnParent = slotToSubPathsMap.computeIfAbsent(
                (SlotReference) parent, k -> Sets.newHashSet());
        for (List<String> pathOnParent : pathsOnParent) {
            List<String> combined = Lists.newArrayList(pair.second);
            combined.addAll(pathOnParent);
            rootPaths.add(combined);
        }
    }

    /** Records every entry of the given map through putElementAtToSubPath, without a parent slot. */
    public void putAllElementAtToSubPath(Map<ElementAt, Pair<SlotReference, List<String>>> elementAtToSubPathMap) {
        elementAtToSubPathMap.forEach((elementAt, pair) -> putElementAtToSubPath(elementAt, pair, null));
    }
}
private static class VariantSubPathReplacer extends DefaultPlanRewriter<Context> {
public static VariantSubPathReplacer INSTANCE = new VariantSubPathReplacer();
/**
 * Hands the collected sub paths of each variant output column to the scan and records
 * the sub path slots exposed by the resulting scan in the context.
 */
@Override
public Plan visitLogicalOlapScan(LogicalOlapScan olapScan, Context context) {
    Map<String, Set<List<String>>> subPathsByColumnName = Maps.newHashMap();
    for (Slot output : olapScan.getOutput()) {
        if (!(output.getDataType() instanceof VariantType)) {
            continue;
        }
        SlotReference variantSlot = (SlotReference) output;
        if (!context.slotToSubPathsMap.containsKey(variantSlot)) {
            continue;
        }
        // only slots that still carry their original column can be keyed by column name
        if (variantSlot.getOriginalColumn().isPresent()) {
            subPathsByColumnName.put(variantSlot.getOriginalColumn().get().getName(),
                    context.slotToSubPathsMap.get(output));
        }
    }
    LogicalOlapScan prunedScan = olapScan.withColToSubPathsMap(subPathsByColumnName);
    generateElementAtMaps(context, prunedScan.getSubPathToSlotMap());
    return prunedScan;
}
/**
 * Rewrites a non-DISTINCT union so that, next to every variant output column, it also outputs
 * one extra column per requested sub path.
 *
 * For each variant output the sub path slots already generated by the regular children are
 * appended to the children outputs; for constant branches the constant is wrapped in nested
 * element_at calls instead. A fresh output slot is created per sub path and registered in the
 * context so that upper nodes can replace their element_at with it.
 *
 * Fixes compared with the previous version:
 * 1. the element_at chain built for constant branches is now nested root-first, matching the
 *    root-first order in which sub paths are recorded (['a', 'b'] means const['a']['b']);
 *    it used to be nested from the last segment to the first.
 * 2. a constants-only union whose variant output has no requested sub path no longer fails
 *    with a NullPointerException.
 *
 * @throws AnalysisException when a later child offers a sub path the first child did not
 */
@Override
public Plan visitLogicalUnion(LogicalUnion union, Context context) {
    union = (LogicalUnion) this.visit(union, context);
    if (union.getQualifier() == Qualifier.DISTINCT) {
        // NOTE(review): the children were already rewritten by this.visit above; this call goes
        // through the default visitor again with the rewritten union — confirm that is intended.
        return super.visitLogicalUnion(union, context);
    }
    List<List<SlotReference>> regularChildrenOutputs
            = Lists.newArrayListWithExpectedSize(union.getRegularChildrenOutputs().size());
    List<List<NamedExpression>> constExprs
            = Lists.newArrayListWithExpectedSize(union.getConstantExprsList().size());
    for (int i = 0; i < union.getRegularChildrenOutputs().size(); i++) {
        regularChildrenOutputs.add(Lists.newArrayListWithExpectedSize(union.getOutput().size() * 2));
    }
    for (int i = 0; i < union.getConstantExprsList().size(); i++) {
        constExprs.add(Lists.newArrayListWithExpectedSize(union.getOutput().size() * 2));
    }
    List<NamedExpression> outputs = Lists.newArrayListWithExpectedSize(union.getOutput().size() * 2);

    Map<Slot, Map<List<String>, SlotReference>> oriSlotToSubPathToSlot = Maps.newHashMap();
    for (int i = 0; i < union.getOutput().size(); i++) {
        // put back original slot
        for (int j = 0; j < regularChildrenOutputs.size(); j++) {
            regularChildrenOutputs.get(j).add(union.getRegularChildOutput(j).get(i));
        }
        for (int j = 0; j < constExprs.size(); j++) {
            constExprs.get(j).add(union.getConstantExprsList().get(j).get(i));
        }
        outputs.add(union.getOutputs().get(i));
        // if not variant, no need to process
        if (!union.getOutput().get(i).getDataType().isVariantType()) {
            continue;
        }
        // put new slots generated by sub path push down
        Map<List<String>, List<SlotReference>> subPathSlots = Maps.newHashMap();
        for (int j = 0; j < regularChildrenOutputs.size(); j++) {
            List<SlotReference> regularChildOutput = union.getRegularChildOutput(j);
            Expression output = regularChildOutput.get(i);
            if (!context.slotToSlotsMap.containsKey(output)
                    || !context.slotToSubPathsMap.containsKey(outputs.get(i))) {
                // no sub path request for this column
                continue;
            }
            // find sub path generated by union children
            Expression key = output;
            while (context.slotToOriginalExprMap.containsKey(key)) {
                key = context.slotToOriginalExprMap.get(key);
            }
            List<String> subPathByChildren = Collections.emptyList();
            if (key instanceof ElementAt) {
                // this means need to find common sub path of its slots.
                subPathByChildren = context.elementAtToSubPathMap.get(key).second;
            }
            for (Map.Entry<List<String>, SlotReference> entry : context.slotToSlotsMap.get(output).entrySet()) {
                List<SlotReference> slotsForSubPath;
                // remove subPath generated by children,
                // because context only record sub path generated by parent
                List<String> parentPaths = entry.getKey()
                        .subList(subPathByChildren.size(), entry.getKey().size());
                if (!context.slotToSubPathsMap.get(outputs.get(i)).contains(parentPaths)) {
                    continue;
                }
                if (j == 0) {
                    // first child, need to put entry to subPathToSlots
                    slotsForSubPath = subPathSlots.computeIfAbsent(parentPaths, k -> Lists.newArrayList());
                } else {
                    // other children must find the entry created by the first child; otherwise it is a bug
                    if (!subPathSlots.containsKey(parentPaths)) {
                        throw new AnalysisException("push down variant sub path failed."
                                + " cannot find sub path for child " + j + "."
                                + " Sub path set is " + subPathSlots.keySet());
                    }
                    slotsForSubPath = subPathSlots.get(parentPaths);
                }
                slotsForSubPath.add(entry.getValue());
            }
        }
        if (regularChildrenOutputs.isEmpty()) {
            // use output sub paths exprs to generate subPathSlots.
            // the output may have no requested sub path at all, in which case nothing is pushed down
            Set<List<String>> requestedSubPaths = context.slotToSubPathsMap.get(outputs.get(i));
            if (requestedSubPaths != null) {
                for (List<String> subPath : requestedSubPaths) {
                    subPathSlots.put(subPath, ImmutableList.of((SlotReference) outputs.get(i).toSlot()));
                }
            }
        }
        for (Map.Entry<List<String>, List<SlotReference>> entry : subPathSlots.entrySet()) {
            for (int j = 0; j < regularChildrenOutputs.size(); j++) {
                regularChildrenOutputs.get(j).add(entry.getValue().get(j));
            }
            for (int j = 0; j < constExprs.size(); j++) {
                NamedExpression constExpr = union.getConstantExprsList().get(j).get(i);
                Expression pushDownExpr;
                if (constExpr instanceof Alias) {
                    pushDownExpr = ((Alias) constExpr).child();
                } else {
                    pushDownExpr = constExpr;
                }
                // sub paths are stored root-first, so the first segment must be the innermost element_at
                for (String pathSegment : entry.getKey()) {
                    pushDownExpr = new ElementAt(pushDownExpr, new VarcharLiteral(pathSegment));
                }
                constExprs.get(j).add(new Alias(pushDownExpr));
            }
            SlotReference outputSlot = new SlotReference(StatementScopeIdGenerator.newExprId(),
                    entry.getValue().get(0).getName(), VariantType.INSTANCE,
                    true, ImmutableList.of());
            outputs.add(outputSlot);
            // update element to slot map
            Map<List<String>, SlotReference> s = oriSlotToSubPathToSlot.computeIfAbsent(
                    (Slot) outputs.get(i), k -> Maps.newHashMap());
            s.put(entry.getKey(), outputSlot);
        }
    }
    generateElementAtMaps(context, oriSlotToSubPathToSlot);

    return union.withNewOutputsChildrenAndConstExprsList(outputs, union.children(),
            regularChildrenOutputs, constExprs);
}
/**
 * Keeps every projection of the one-row relation and appends, right after each of them,
 * the projections produced by pushing the requested sub paths down onto it.
 */
@Override
public Plan visitLogicalOneRowRelation(LogicalOneRowRelation oneRowRelation, Context context) {
    List<NamedExpression> originalProjects = oneRowRelation.getProjects();
    ImmutableList.Builder<NamedExpression> expanded
            = ImmutableList.builderWithExpectedSize(originalProjects.size());
    for (NamedExpression original : originalProjects) {
        List<NamedExpression> pushedDown = pushDownToProject(context, original);
        expanded.add(original).addAll(pushedDown);
    }
    return oneRowRelation.withProjects(expanded.build());
}
/**
 * Rewrites a project after its child has been rewritten.
 *
 * Every original projection is kept (possibly with its element_at replaced by a sub path slot).
 * Extra projections are added in two ways: sub path expressions pushed down onto a projection
 * via pushDownToProject, and, for variant projections that are a slot or a replaced element_at,
 * the deeper sub path slots already known for the expression the projection originates from.
 */
@Override
public Plan visitLogicalProject(LogicalProject<? extends Plan> project, Context context) {
// rewrite the child first so the context maps are populated by the nodes below
project = (LogicalProject<? extends Plan>) this.visit(project, context);
ImmutableList.Builder<NamedExpression> newProjections
= ImmutableList.builderWithExpectedSize(project.getProjects().size());
for (NamedExpression projection : project.getProjects()) {
// only variant projections may need their derived sub path slots appended below;
// the branches that push expressions down themselves switch this off
boolean addOthers = projection.getDataType().isVariantType();
if (projection instanceof SlotReference) {
newProjections.add(projection);
} else {
// anything that is not a SlotReference is cast to Alias here
Expression child = ((Alias) projection).child();
NamedExpression newProjection;
if (child instanceof SlotReference) {
newProjection = projection;
} else if (child instanceof ElementAt) {
if (context.elementAtToSlotMap.containsKey(child)) {
// the whole element_at has a replacement slot: alias that slot instead
newProjection = (NamedExpression) projection
.withChildren(context.elementAtToSlotMap.get(child));
} else {
addOthers = false;
newProjection = projection;
// the element_at is not a plain slot/literal chain, so try to push the
// requested sub paths down onto this projection
if (extractSlotToSubPathPair((ElementAt) child) == null) {
newProjections.addAll(pushDownToProject(context, projection));
}
}
} else {
addOthers = false;
newProjection = (NamedExpression) ExpressionUtils.replace(
projection, context.elementAtToSlotMap);
// try to push the requested sub paths down onto this projection;
// the pushed-down projections are added before the projection itself
newProjections.addAll(pushDownToProject(context, projection));
}
newProjections.add(newProjection);
}
if (addOthers) {
// follow the recorded slot -> original expression links as far as they go
Expression key = projection.toSlot();
while (key instanceof Slot && context.slotToOriginalExprMap.containsKey(key)) {
key = context.slotToOriginalExprMap.get(key);
}
// expose every deeper sub path slot known for the origin, and remember them
// under this projection's slot so upper nodes can find them
if (key instanceof ElementAt && context.elementAtToSlotsMap.containsKey(key)) {
newProjections.addAll(context.elementAtToSlotsMap.get(key).values());
context.slotToSlotsMap.put(projection.toSlot(), context.elementAtToSlotsMap.get(key));
} else if (key instanceof Slot && context.slotToSlotsMap.containsKey(key)) {
newProjections.addAll(context.slotToSlotsMap.get(key).values());
context.slotToSlotsMap.put(projection.toSlot(), context.slotToSlotsMap.get(key));
}
}
}
return project.withProjects(newProjections.build());
}
/**
 * Rebuilds the consumer/producer slot maps of a cte consumer. For every variant consumer slot
 * whose producer slot already has sub path slots, a consumer slot is generated per sub path
 * slot, and the recorded element_at expressions are mapped to those consumer slots.
 * A consumer without any variant producer slot is returned unchanged.
 */
@Override
public Plan visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, Context context) {
    boolean hasVariantProducerSlot = cteConsumer.getProducerToConsumerOutputMap().keySet().stream()
            .map(ExpressionTrait::getDataType).anyMatch(VariantType.class::isInstance);
    if (!hasVariantProducerSlot) {
        return cteConsumer;
    }
    Map<Slot, Slot> newConsumerToProducer = Maps.newHashMap();
    Multimap<Slot, Slot> newProducerToConsumer = LinkedHashMultimap.create();
    Map<Slot, Map<List<String>, SlotReference>> consumerSubPathSlots = Maps.newHashMap();
    for (Map.Entry<Slot, Slot> mapping : cteConsumer.getConsumerToProducerOutputMap().entrySet()) {
        Slot consumerSlot = mapping.getKey();
        Slot producerSlot = mapping.getValue();
        // the existing pair is always carried over
        newConsumerToProducer.put(consumerSlot, producerSlot);
        newProducerToConsumer.put(producerSlot, consumerSlot);
        boolean isVariant = consumerSlot.getDataType() instanceof VariantType;
        if (!isVariant || !context.slotToSlotsMap.containsKey(producerSlot)) {
            continue;
        }
        // mirror every sub path slot of the producer with a freshly generated consumer slot
        Map<List<String>, SlotReference> generated = Maps.newHashMap();
        for (Map.Entry<List<String>, SlotReference> producerSubPath
                : context.slotToSlotsMap.get(producerSlot).entrySet()) {
            SlotReference producerSubPathSlot = producerSubPath.getValue();
            SlotReference consumerSubPathSlot = LogicalCTEConsumer.generateConsumerSlot(
                    cteConsumer.getName(), producerSubPathSlot);
            newConsumerToProducer.put(consumerSubPathSlot, producerSubPathSlot);
            newProducerToConsumer.put(producerSubPathSlot, consumerSubPathSlot);
            generated.put(producerSubPath.getKey(), consumerSubPathSlot);
        }
        context.slotToSlotsMap.put(consumerSlot, generated);
        consumerSubPathSlots.put(consumerSlot, generated);
    }
    for (Entry<ElementAt, Pair<SlotReference, List<String>>> recorded
            : context.elementAtToSubPathMap.entrySet()) {
        ElementAt elementAt = recorded.getKey();
        List<String> requestedPath = recorded.getValue().second;
        // an element_at remembered against a consumer slot is looked up by that consumer slot,
        // otherwise by its recorded root slot
        SlotReference lookupSlot = context.elementAtToCteConsumer.containsKey(elementAt)
                ? context.elementAtToCteConsumer.get(elementAt)
                : recorded.getValue().first;
        if (!consumerSubPathSlots.containsKey(lookupSlot)) {
            continue;
        }
        Map<List<String>, SlotReference> candidates = consumerSubPathSlots.get(lookupSlot);
        // exactly matching sub path
        context.elementAtToSlotMap.put(elementAt, candidates.get(requestedPath));
        // strictly longer sub paths that start with the requested one
        int requestedDepth = requestedPath.size();
        for (Map.Entry<List<String>, SlotReference> candidate : candidates.entrySet()) {
            List<String> candidatePath = candidate.getKey();
            if (candidatePath.size() > requestedDepth
                    && candidatePath.subList(0, requestedDepth).equals(requestedPath)) {
                context.elementAtToSlotsMap.computeIfAbsent(elementAt, e -> Maps.newHashMap())
                        .put(candidatePath, candidate.getValue());
            }
        }
    }
    return cteConsumer.withTwoMaps(newConsumerToProducer, newProducerToConsumer);
}
/**
 * Rewrites the child, then replaces every recorded element_at inside the filter conjuncts
 * with its sub path slot.
 */
@Override
public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, Context context) {
    LogicalFilter<? extends Plan> visited = (LogicalFilter<? extends Plan>) this.visit(filter, context);
    ImmutableSet.Builder<Expression> rewrittenConjuncts
            = ImmutableSet.builderWithExpectedSize(visited.getConjuncts().size());
    for (Expression original : visited.getConjuncts()) {
        Expression rewritten = ExpressionUtils.replace(original, context.elementAtToSlotMap);
        rewrittenConjuncts.add(rewritten);
    }
    return visited.withConjuncts(rewrittenConjuncts.build());
}
/**
 * Rewrites both children, then replaces every recorded element_at inside the hash, other and
 * mark join conjuncts with its sub path slot.
 */
@Override
public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Context context) {
    LogicalJoin<? extends Plan, ? extends Plan> visited
            = (LogicalJoin<? extends Plan, ? extends Plan>) this.visit(join, context);

    ImmutableList.Builder<Expression> rewrittenHash
            = ImmutableList.builderWithExpectedSize(visited.getHashJoinConjuncts().size());
    ImmutableList.Builder<Expression> rewrittenOther
            = ImmutableList.builderWithExpectedSize(visited.getOtherJoinConjuncts().size());
    ImmutableList.Builder<Expression> rewrittenMark
            = ImmutableList.builderWithExpectedSize(visited.getMarkJoinConjuncts().size());

    for (Expression hashConjunct : visited.getHashJoinConjuncts()) {
        rewrittenHash.add(ExpressionUtils.replace(hashConjunct, context.elementAtToSlotMap));
    }
    for (Expression otherConjunct : visited.getOtherJoinConjuncts()) {
        rewrittenOther.add(ExpressionUtils.replace(otherConjunct, context.elementAtToSlotMap));
    }
    for (Expression markConjunct : visited.getMarkJoinConjuncts()) {
        rewrittenMark.add(ExpressionUtils.replace(markConjunct, context.elementAtToSlotMap));
    }

    return visited.withJoinConjuncts(rewrittenHash.build(), rewrittenOther.build(),
            rewrittenMark.build(), visited.getJoinReorderContext());
}
/**
 * Rewrites the child, then replaces every recorded element_at inside the sort keys
 * with its sub path slot.
 */
@Override
public Plan visitLogicalSort(LogicalSort<? extends Plan> sort, Context context) {
    LogicalSort<? extends Plan> visited = (LogicalSort<? extends Plan>) this.visit(sort, context);
    ImmutableList.Builder<OrderKey> rewrittenKeys
            = ImmutableList.builderWithExpectedSize(visited.getOrderKeys().size());
    for (OrderKey key : visited.getOrderKeys()) {
        Expression rewrittenExpr = ExpressionUtils.replace(key.getExpr(), context.elementAtToSlotMap);
        rewrittenKeys.add(key.withExpression(rewrittenExpr));
    }
    return visited.withOrderKeys(rewrittenKeys.build());
}
/**
 * Rewrites the child, then replaces every recorded element_at inside the topN order keys
 * with its sub path slot.
 */
@Override
public Plan visitLogicalTopN(LogicalTopN<? extends Plan> topN, Context context) {
    LogicalTopN<? extends Plan> visited = (LogicalTopN<? extends Plan>) this.visit(topN, context);
    ImmutableList.Builder<OrderKey> rewrittenKeys
            = ImmutableList.builderWithExpectedSize(visited.getOrderKeys().size());
    for (OrderKey key : visited.getOrderKeys()) {
        Expression rewrittenExpr = ExpressionUtils.replace(key.getExpr(), context.elementAtToSlotMap);
        rewrittenKeys.add(key.withExpression(rewrittenExpr));
    }
    return visited.withOrderKeys(rewrittenKeys.build());
}
/**
 * Rewrites the child, then replaces every recorded element_at inside the order keys and
 * the partition keys of the partition topN with its sub path slot.
 */
@Override
public Plan visitLogicalPartitionTopN(LogicalPartitionTopN<? extends Plan> partitionTopN, Context context) {
    LogicalPartitionTopN<? extends Plan> visited
            = (LogicalPartitionTopN<? extends Plan>) this.visit(partitionTopN, context);

    ImmutableList.Builder<OrderExpression> rewrittenOrderKeys
            = ImmutableList.builderWithExpectedSize(visited.getOrderKeys().size());
    for (OrderExpression original : visited.getOrderKeys()) {
        OrderKey originalKey = original.getOrderKey();
        Expression rewrittenExpr = ExpressionUtils.replace(originalKey.getExpr(), context.elementAtToSlotMap);
        rewrittenOrderKeys.add(new OrderExpression(originalKey.withExpression(rewrittenExpr)));
    }

    ImmutableList.Builder<Expression> rewrittenPartitionKeys
            = ImmutableList.builderWithExpectedSize(visited.getPartitionKeys().size());
    for (Expression original : visited.getPartitionKeys()) {
        rewrittenPartitionKeys.add(ExpressionUtils.replace(original, context.elementAtToSlotMap));
    }

    return visited.withPartitionKeysAndOrderKeys(rewrittenPartitionKeys.build(), rewrittenOrderKeys.build());
}
/**
 * Rewrites the child, then replaces every recorded element_at inside the generator functions
 * with its sub path slot.
 */
@Override
public Plan visitLogicalGenerate(LogicalGenerate<? extends Plan> generate, Context context) {
    LogicalGenerate<? extends Plan> visited = (LogicalGenerate<? extends Plan>) this.visit(generate, context);
    ImmutableList.Builder<Function> rewrittenGenerators
            = ImmutableList.builderWithExpectedSize(visited.getGenerators().size());
    for (Function original : visited.getGenerators()) {
        Function rewritten = (Function) ExpressionUtils.replace(original, context.elementAtToSlotMap);
        rewrittenGenerators.add(rewritten);
    }
    return visited.withGenerators(rewrittenGenerators.build());
}
/**
 * Rewrites the child, then replaces every recorded element_at inside the window expressions
 * with its sub path slot.
 */
@Override
public Plan visitLogicalWindow(LogicalWindow<? extends Plan> window, Context context) {
    LogicalWindow<? extends Plan> visited = (LogicalWindow<? extends Plan>) this.visit(window, context);
    ImmutableList.Builder<NamedExpression> rewrittenWindowExprs
            = ImmutableList.builderWithExpectedSize(visited.getWindowExpressions().size());
    for (NamedExpression original : visited.getWindowExpressions()) {
        NamedExpression rewritten
                = (NamedExpression) ExpressionUtils.replace(original, context.elementAtToSlotMap);
        rewrittenWindowExprs.add(rewritten);
    }
    return visited.withExpressionsAndChild(rewrittenWindowExprs.build(), visited.child());
}
/**
 * Rewrites the child, then replaces every recorded element_at inside the output expressions
 * and the group by expressions of the aggregate with its sub path slot.
 */
@Override
public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Context context) {
    LogicalAggregate<? extends Plan> visited
            = (LogicalAggregate<? extends Plan>) this.visit(aggregate, context);

    ImmutableList.Builder<NamedExpression> rewrittenOutputs
            = ImmutableList.builderWithExpectedSize(visited.getOutputExpressions().size());
    for (NamedExpression original : visited.getOutputExpressions()) {
        NamedExpression rewritten
                = (NamedExpression) ExpressionUtils.replace(original, context.elementAtToSlotMap);
        rewrittenOutputs.add(rewritten);
    }

    ImmutableList.Builder<Expression> rewrittenGroupBy
            = ImmutableList.builderWithExpectedSize(visited.getGroupByExpressions().size());
    for (Expression original : visited.getGroupByExpressions()) {
        rewrittenGroupBy.add(ExpressionUtils.replace(original, context.elementAtToSlotMap));
    }

    return visited.withGroupByAndOutput(rewrittenGroupBy.build(), rewrittenOutputs.build());
}
/**
 * Builds one extra projection per sub path requested on the given variant projection.
 * Each extra projection is an alias over nested element_at calls applied to the projection's
 * first child; the resulting slots are registered in the context through generateElementAtMaps.
 *
 * Fix compared with the previous version: sub paths are recorded root-first
 * (['a', 'b'] stands for child['a']['b']), so the first segment has to become the innermost
 * element_at. The nesting used to be built from the last segment to the first, which produced
 * child['b']['a'] for paths with more than one segment.
 *
 * @param context shared rule state; its slot and element_at maps are updated
 * @param projection the projection to push sub paths onto
 * @return the extra projections, or an empty list when the projection is not of variant type
 *         or has no requested sub path
 */
private List<NamedExpression> pushDownToProject(Context context, NamedExpression projection) {
    if (!projection.getDataType().isVariantType()
            || !context.slotToSubPathsMap.containsKey((SlotReference) projection.toSlot())) {
        return Collections.emptyList();
    }
    List<NamedExpression> newProjections = Lists.newArrayList();
    Expression child = projection.child(0);
    Map<List<String>, SlotReference> subPathToSlot = Maps.newHashMap();
    Set<List<String>> subPaths = context.slotToSubPathsMap
            .get((SlotReference) projection.toSlot());
    for (List<String> subPath : subPaths) {
        Expression pushDownExpr = child;
        // wrap root-first: the first segment is applied directly to the child
        for (String pathSegment : subPath) {
            pushDownExpr = new ElementAt(pushDownExpr, new VarcharLiteral(pathSegment));
        }
        Alias alias = new Alias(pushDownExpr);
        newProjections.add(alias);
        subPathToSlot.put(subPath, (SlotReference) alias.toSlot());
    }
    Map<Slot, Map<List<String>, SlotReference>> oriSlotToSubPathToSlot = Maps.newHashMap();
    oriSlotToSubPathToSlot.put(projection.toSlot(), subPathToSlot);
    generateElementAtMaps(context, oriSlotToSubPathToSlot);
    return newProjections;
}
/**
 * Registers the given slot -> {sub path -> slot} mapping in the context and, for every recorded
 * element_at rooted at one of those slots, stores the slot of its exact sub path and the slots
 * of all strictly longer sub paths that start with it.
 */
private void generateElementAtMaps(Context context, Map<Slot,
        Map<List<String>, SlotReference>> oriSlotToSubPathToSlot) {
    context.slotToSlotsMap.putAll(oriSlotToSubPathToSlot);
    for (Entry<ElementAt, Pair<SlotReference, List<String>>> recorded
            : context.elementAtToSubPathMap.entrySet()) {
        SlotReference rootSlot = recorded.getValue().first;
        if (!oriSlotToSubPathToSlot.containsKey(rootSlot)) {
            // this element_at is not rooted at any of the slots being registered
            continue;
        }
        ElementAt elementAt = recorded.getKey();
        List<String> requestedPath = recorded.getValue().second;
        Map<List<String>, SlotReference> slotBySubPath = oriSlotToSubPathToSlot.get(rootSlot);
        // exactly matching sub path
        context.elementAtToSlotMap.put(elementAt, slotBySubPath.get(requestedPath));
        // strictly longer sub paths that have the requested path as prefix
        int requestedDepth = requestedPath.size();
        for (Map.Entry<List<String>, SlotReference> candidate : slotBySubPath.entrySet()) {
            List<String> candidatePath = candidate.getKey();
            if (candidatePath.size() > requestedDepth
                    && candidatePath.subList(0, requestedDepth).equals(requestedPath)) {
                context.elementAtToSlotsMap.computeIfAbsent(elementAt, e -> Maps.newHashMap())
                        .put(candidatePath, candidate.getValue());
            }
        }
    }
}
}
private static class VariantSubPathCollector extends PlanVisitor<Void, Context> {
public static VariantSubPathCollector INSTANCE = new VariantSubPathCollector();
/**
 * Extracts chains of element_at from an expression tree.
 * When an element_at can be resolved to a root slot plus sub path, it is put into the
 * visitor context map and the traversal stops at that node;
 * otherwise the traversal continues into its children.
 */
private static class ExtractSlotToSubPathPairFromTree
        extends DefaultExpressionVisitor<Void, Map<ElementAt, Pair<SlotReference, List<String>>>> {

    // the visitor keeps no state of its own, so the shared instance is declared final
    public static final ExtractSlotToSubPathPairFromTree INSTANCE = new ExtractSlotToSubPathPairFromTree();

    /**
     * @param elementAt the element_at being visited
     * @param context output map: element_at -> (root slot, sub path)
     * @return always null
     */
    @Override
    public Void visitElementAt(ElementAt elementAt, Map<ElementAt, Pair<SlotReference, List<String>>> context) {
        Pair<SlotReference, List<String>> pair = extractSlotToSubPathPair(elementAt);
        if (pair != null) {
            context.put(elementAt, pair);
        } else {
            // not a plain slot/literal chain: look for extractable element_at further down
            visit(elementAt, context);
        }
        return null;
    }
}
/**
 * Visits the right child of the anchor before the left child, so whatever the right side
 * records in the context is already there when the left side is visited.
 * NOTE(review): presumably right is the consuming plan and left the producer — confirm
 * against LogicalCTEAnchor.
 */
@Override
public Void visitLogicalCTEAnchor(LogicalCTEAnchor<? extends Plan, ? extends Plan> cteAnchor,
        Context context) {
    Plan rightChild = cteAnchor.right();
    rightChild.accept(this, context);
    Plan leftChild = cteAnchor.left();
    return leftChild.accept(this, context);
}
/**
 * For every variant consumer slot: copies the sub paths requested on the consumer slot to its
 * producer slot, and re-roots every recorded element_at from the consumer slot to the producer
 * slot while remembering which consumer slot it was written against.
 */
@Override
public Void visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, Context context) {
    for (Map.Entry<Slot, Slot> mapping : cteConsumer.getConsumerToProducerOutputMap().entrySet()) {
        Slot consumerSlot = mapping.getKey();
        if (!(consumerSlot.getDataType() instanceof VariantType)) {
            continue;
        }
        Slot producerSlot = mapping.getValue();
        // sub paths requested on the consumer are requested on the producer as well
        if (context.slotToSubPathsMap.containsKey((SlotReference) consumerSlot)) {
            context.slotToSubPathsMap.computeIfAbsent((SlotReference) producerSlot, k -> Sets.newHashSet())
                    .addAll(context.slotToSubPathsMap.get(consumerSlot));
        }
        for (Entry<ElementAt, Pair<SlotReference, List<String>>> recorded
                : context.elementAtToSubPathMap.entrySet()) {
            Pair<SlotReference, List<String>> rootAndPath = recorded.getValue();
            if (!rootAndPath.first.equals(consumerSlot)) {
                continue;
            }
            // keep the first consumer slot seen for this element_at, then point it at the producer
            context.elementAtToCteConsumer.putIfAbsent(recorded.getKey(), (SlotReference) consumerSlot);
            recorded.setValue(Pair.of((SlotReference) producerSlot, rootAndPath.second));
        }
    }
    return null;
}
/**
 * For a non-DISTINCT union, copies the sub paths requested on each union output to the
 * matching output slot of every regular child, then continues with the generic visit.
 * DISTINCT unions are delegated to the parent visitor without any copying.
 */
@Override
public Void visitLogicalUnion(LogicalUnion union, Context context) {
    if (union.getQualifier() == Qualifier.DISTINCT) {
        return super.visitLogicalUnion(union, context);
    }
    for (List<SlotReference> outputsOfChild : union.getRegularChildrenOutputs()) {
        for (int column = 0; column < union.getOutputs().size(); column++) {
            Slot unionSlot = union.getOutput().get(column);
            SlotReference childSlot = outputsOfChild.get(column);
            if (!context.slotToSubPathsMap.containsKey((SlotReference) unionSlot)) {
                continue;
            }
            context.slotToSubPathsMap.computeIfAbsent(childSlot, k -> Sets.newHashSet())
                    .addAll(context.slotToSubPathsMap.get(unionSlot));
        }
    }
    this.visit(union, context);
    return null;
}
/**
 * Collects from the aliases of a project, then continues with the generic visit.
 * Non-alias projections are skipped here.
 */
@Override
public Void visitLogicalProject(LogicalProject<? extends Plan> project, Context context) {
    for (NamedExpression projection : project.getProjects()) {
        if (!(projection instanceof Alias)) {
            continue;
        }
        Alias alias = (Alias) projection;
        Expression aliased = alias.child();
        // alias of a variant slot: remember where the alias slot comes from
        if (aliased instanceof SlotReference && aliased.getDataType() instanceof VariantType) {
            context.putSlotToOriginal(alias.toSlot(), aliased);
        }
        // an expression like v['a']['b']['c'] sitting directly under the alias is handled
        // separately because it affects the upper level: the alias slot must be mapped to it
        Pair<SlotReference, List<String>> rootAndPath = aliased instanceof ElementAt
                ? extractSlotToSubPathPair((ElementAt) aliased)
                : null;
        if (rootAndPath != null) {
            context.putElementAtToSubPath((ElementAt) aliased, rootAndPath, alias.toSlot());
            context.putSlotToOriginal(alias.toSlot(), aliased);
        } else {
            // every other shape: collect the element_at chains found anywhere inside the child
            Map<ElementAt, Pair<SlotReference, List<String>>> found = Maps.newHashMap();
            aliased.accept(ExtractSlotToSubPathPairFromTree.INSTANCE, found);
            context.putAllElementAtToSubPath(found);
        }
    }
    this.visit(project, context);
    return null;
}
/**
 * Generic collection step: records every extractable element_at found in the expressions
 * of the plan node, then visits all of its children.
 */
@Override
public Void visit(Plan plan, Context context) {
    Map<ElementAt, Pair<SlotReference, List<String>>> found = Maps.newHashMap();
    for (Expression planExpression : plan.getExpressions()) {
        planExpression.accept(ExtractSlotToSubPathPairFromTree.INSTANCE, found);
    }
    context.putAllElementAtToSubPath(found);
    plan.children().forEach(childPlan -> childPlan.accept(this, context));
    return null;
}
}
/**
 * Resolves a chain of element_at calls to its root slot and sub path.
 * Walks from the given element_at down through its left children. Every level must have a
 * variant-typed left child that is either another element_at or a slot reference, and a
 * string-like literal as right child.
 *
 * @param elementAt the outermost element_at of the chain
 * @return the root slot together with the path segments ordered from the root outwards,
 *         or null when any level of the chain does not have the required shape
 */
protected static Pair<SlotReference, List<String>> extractSlotToSubPathPair(ElementAt elementAt) {
    List<String> rootFirstPath = Lists.newArrayList();
    ElementAt current = elementAt;
    while (true) {
        Expression container = current.left();
        if (!(container.getDataType() instanceof VariantType)) {
            return null;
        }
        boolean chainContinues = container instanceof ElementAt;
        if (!chainContinues && !(container instanceof SlotReference)) {
            return null;
        }
        if (!(current.right() instanceof StringLikeLiteral)) {
            return null;
        }
        // the walk goes from the outermost segment inwards, so each segment is put in front
        rootFirstPath.add(0, ((StringLikeLiteral) current.right()).getStringValue());
        if (container instanceof SlotReference) {
            // reached the root slot of the chain
            return Pair.of((SlotReference) container, rootFirstPath);
        }
        // left child is another element_at: go one level deeper
        current = (ElementAt) container;
    }
}
}