AccessPathExpressionCollector.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.rules.rewrite.AccessPathExpressionCollector.CollectorContext;
import org.apache.doris.nereids.rules.rewrite.NestedColumnPruning.DataTypeAccessTree;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference;
import org.apache.doris.nereids.trees.expressions.ArrayItemReference.ArrayItemSlot;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayCount;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExists;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFilter;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFirst;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFirstIndex;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayLast;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayLastIndex;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMatchAll;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMatchAny;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayReverseSplit;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySortBy;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySplit;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsEntry;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsKey;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsValue;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MapKeys;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MapValues;
import org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.NestedColumnPrunable;
import org.apache.doris.nereids.types.StructField;
import org.apache.doris.nereids.types.StructType;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.thrift.TAccessPathType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
import com.google.common.collect.Multimap;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Stack;
/**
 * Collects the access paths of nested-type slots referenced by an expression, for example:
 * {@code select struct_element(s, 'data')} has access path {@code ['s', 'data']}.
 *
 * <p>Paths are built from the outermost accessor inwards: each visitor prepends its own component
 * (a struct field name, {@code "*"} for array/map elements, {@code "KEYS"} / {@code "VALUES"} for map
 * keys/values) to the path of the current {@link CollectorContext} and continues into the accessed
 * child. When a {@link SlotReference} of a {@link NestedColumnPrunable} type is reached, its name is
 * prepended and the finished path is put into {@code slotToAccessPaths} under the slot's expression id.
 * Sub-expressions that are not narrowed by their parent are visited with a fresh, empty path, so a
 * slot found there is recorded with a path consisting of its name only.
 */
public class AccessPathExpressionCollector extends DefaultExpressionVisitor<Void, CollectorContext> {
    private final StatementContext statementContext;
    private final boolean bottomPredicate;
    // output: slot expression id -> access paths collected for that slot
    private final Multimap<Integer, CollectAccessPathResult> slotToAccessPaths;
    // one scope per lambda being collected (innermost on top): lambda variable name -> the array it iterates
    private final Stack<Map<String, Expression>> nameToLambdaArguments = new Stack<>();

    /**
     * Creates a collector.
     *
     * @param statementContext statement context handed to every {@link CollectorContext}
     * @param slotToAccessPaths receives the collected paths, keyed by slot expression id
     * @param bottomPredicate stored as {@link CollectAccessPathResult#isPredicate()} of the collected paths
     */
    public AccessPathExpressionCollector(
            StatementContext statementContext, Multimap<Integer, CollectAccessPathResult> slotToAccessPaths,
            boolean bottomPredicate) {
        this.statementContext = statementContext;
        this.slotToAccessPaths = slotToAccessPaths;
        this.bottomPredicate = bottomPredicate;
    }

    /** Collects the access paths of every nested-type slot referenced by {@code expression}. */
    public void collect(Expression expression) {
        expression.accept(this, new CollectorContext(statementContext, bottomPredicate));
    }

    /** Continues into {@code expr} keeping the path built so far in {@code context}. */
    private Void continueCollectAccessPath(Expression expr, CollectorContext context) {
        return expr.accept(this, context);
    }

    /** Returns a context with an empty path that keeps {@code context}'s statement and bottom-filter flag. */
    private static CollectorContext freshContext(CollectorContext context) {
        return new CollectorContext(context.statementContext, context.bottomFilter);
    }

    /**
     * Visits {@code arguments[from..]} themselves (not only their children), each with a fresh path,
     * so a nested slot or lambda variable used directly as an argument is collected too.
     */
    private Void collectArgumentsFrom(List<Expression> arguments, int from, CollectorContext context) {
        for (int i = from; i < arguments.size(); i++) {
            arguments.get(i).accept(this, freshContext(context));
        }
        return null;
    }

    /**
     * Default handling: every child is visited with a fresh, empty path, i.e. whatever the parent
     * narrowed is not propagated through an unrecognized expression.
     */
    @Override
    public Void visit(Expression expr, CollectorContext context) {
        for (Expression child : expr.children()) {
            child.accept(this, freshContext(context));
        }
        return null;
    }

    /** Records the finished path for slots of a nested-prunable type; other slots are ignored. */
    @Override
    public Void visitSlotReference(SlotReference slotReference, CollectorContext context) {
        DataType dataType = slotReference.getDataType();
        if (dataType instanceof NestedColumnPrunable) {
            context.accessPathBuilder.addPrefix(slotReference.getName().toLowerCase());
            ImmutableList<String> path = Utils.fastToImmutableList(context.accessPathBuilder.accessPath);
            int slotId = slotReference.getExprId().asInt();
            slotToAccessPaths.put(slotId, new CollectAccessPathResult(path, context.bottomFilter, context.type));
        }
        return null;
    }

    /**
     * A lambda variable stands for one element of the array it iterates: prepend {@code "*"} and
     * continue into that array. The variable is resolved from the innermost lambda outwards, so
     * variables of enclosing lambdas are found as well; unresolvable variables are ignored.
     */
    @Override
    public Void visitArrayItemSlot(ArrayItemSlot arrayItemSlot, CollectorContext context) {
        if (nameToLambdaArguments.isEmpty()) {
            return null;
        }
        context.accessPathBuilder.addPrefix("*");
        String name = arrayItemSlot.getName();
        for (int depth = nameToLambdaArguments.size() - 1; depth >= 0; depth--) {
            Expression array = nameToLambdaArguments.get(depth).get(name);
            if (array != null) {
                return collectLambdaArray(array, depth, context);
            }
        }
        return null;
    }

    /**
     * Collects the array bound to a variable of the lambda whose scope sits at {@code depth}.
     *
     * <p>That array is an argument of the function enclosing the lambda, not part of the lambda body,
     * so it can only refer to variables of outer lambdas. The scope at {@code depth} and every scope
     * nested inside it are therefore hidden while it is visited (and restored afterwards). Without this,
     * a shadowed variable name would resolve to itself and recurse forever.
     */
    private Void collectLambdaArray(Expression array, int depth, CollectorContext context) {
        List<Map<String, Expression>> innerScopes =
                nameToLambdaArguments.subList(depth, nameToLambdaArguments.size());
        List<Map<String, Expression>> hidden = new ArrayList<>(innerScopes);
        innerScopes.clear();
        try {
            return continueCollectAccessPath(array, context);
        } finally {
            nameToLambdaArguments.addAll(hidden);
        }
    }

    @Override
    public Void visitAlias(Alias alias, CollectorContext context) {
        return alias.child(0).accept(this, context);
    }

    /**
     * When the parent already narrowed the path and both sides of the cast are nested-prunable types,
     * the path is rewritten via {@link DataTypeAccessTree#replacePathByAnotherTree} (presumably mapping
     * the cast type's components onto the child's type) and collection continues into the child.
     * Otherwise the child is collected with a fresh path.
     */
    @Override
    public Void visitCast(Cast cast, CollectorContext context) {
        if (!context.accessPathBuilder.isEmpty()
                && cast.getDataType() instanceof NestedColumnPrunable
                && cast.child().getDataType() instanceof NestedColumnPrunable) {
            DataTypeAccessTree castTree = DataTypeAccessTree.of(cast.getDataType(), TAccessPathType.DATA);
            DataTypeAccessTree originTree = DataTypeAccessTree.of(cast.child().getDataType(), TAccessPathType.DATA);
            List<String> replacePath = new ArrayList<>(context.accessPathBuilder.getPathList());
            if (originTree.replacePathByAnotherTree(castTree, replacePath, 0)) {
                CollectorContext castContext = freshContext(context);
                castContext.accessPathBuilder.accessPath.addAll(replacePath);
                return continueCollectAccessPath(cast.child(), castContext);
            }
        }
        return cast.child(0).accept(this, freshContext(context));
    }

    /**
     * {@code element_at(array, index)} / {@code element_at(map, key)}: the container is accessed at
     * {@code "*"}, the index/key arguments are read in full. Any other first argument falls back to
     * {@link #visit}.
     */
    @Override
    public Void visitElementAt(ElementAt elementAt, CollectorContext context) {
        List<Expression> arguments = elementAt.getArguments();
        Expression first = arguments.get(0);
        if (!first.getDataType().isArrayType() && !first.getDataType().isMapType()) {
            return visit(elementAt, context);
        }
        context.accessPathBuilder.addPrefix("*");
        continueCollectAccessPath(first, context);
        return collectArgumentsFrom(arguments, 1, context);
    }

    /**
     * {@code struct_element(struct, field)}: when the field is a literal name or 1-based index, its
     * (lowercased) name is prepended and collection continues into the struct. An out-of-range index
     * falls back to its string value. Any other field argument makes every argument be collected in full.
     */
    @Override
    public Void visitStructElement(StructElement structElement, CollectorContext context) {
        List<Expression> arguments = structElement.getArguments();
        Expression struct = arguments.get(0);
        Expression fieldName = arguments.get(1);
        DataType fieldType = fieldName.getDataType();
        if (!fieldName.isLiteral() || !(fieldType.isIntegerLikeType() || fieldType.isStringLikeType())) {
            // the accessed field is unknown: visit the arguments themselves, not only their children
            return collectArgumentsFrom(arguments, 0, context);
        }
        if (fieldType.isIntegerLikeType()) {
            int fieldIndex = ((Number) ((Literal) fieldName).getValue()).intValue();
            List<StructField> fields = ((StructType) struct.getDataType()).getFields();
            if (fieldIndex >= 1 && fieldIndex <= fields.size()) {
                // lowercase like the by-name branch below and slot names
                context.accessPathBuilder.addPrefix(fields.get(fieldIndex - 1).getName().toLowerCase());
                return continueCollectAccessPath(struct, context);
            }
        }
        context.accessPathBuilder.addPrefix(((Literal) fieldName).getStringValue().toLowerCase());
        return continueCollectAccessPath(struct, context);
    }

    /** {@code map_keys(map)}: the parent's path is dropped and the map is accessed at {@code "KEYS"}. */
    @Override
    public Void visitMapKeys(MapKeys mapKeys, CollectorContext context) {
        CollectorContext keysContext = freshContext(context);
        keysContext.accessPathBuilder.addPrefix("KEYS");
        return continueCollectAccessPath(mapKeys.getArgument(0), keysContext);
    }

    /**
     * {@code map_values(map)}: the map is accessed at {@code "VALUES"}. A leading {@code "*"} (an
     * element of the returned values array) is replaced, keeping the rest of the parent's path.
     */
    @Override
    public Void visitMapValues(MapValues mapValues, CollectorContext context) {
        LinkedList<String> suffixPath = context.accessPathBuilder.accessPath;
        if (!suffixPath.isEmpty() && "*".equals(suffixPath.getFirst())) {
            CollectorContext removeStarContext = freshContext(context);
            removeStarContext.accessPathBuilder.accessPath.addAll(suffixPath.subList(1, suffixPath.size()));
            removeStarContext.accessPathBuilder.addPrefix("VALUES");
            return continueCollectAccessPath(mapValues.getArgument(0), removeStarContext);
        }
        context.accessPathBuilder.addPrefix("VALUES");
        return continueCollectAccessPath(mapValues.getArgument(0), context);
    }

    /** {@code map_contains_key(map, key)}: the map at {@code "KEYS"}, the searched key in full. */
    @Override
    public Void visitMapContainsKey(MapContainsKey mapContainsKey, CollectorContext context) {
        context.accessPathBuilder.addPrefix("KEYS");
        continueCollectAccessPath(mapContainsKey.getArgument(0), context);
        return collectArgumentsFrom(mapContainsKey.getArguments(), 1, context);
    }

    /** {@code map_contains_value(map, value)}: the map at {@code "VALUES"}, the searched value in full. */
    @Override
    public Void visitMapContainsValue(MapContainsValue mapContainsValue, CollectorContext context) {
        context.accessPathBuilder.addPrefix("VALUES");
        continueCollectAccessPath(mapContainsValue.getArgument(0), context);
        return collectArgumentsFrom(mapContainsValue.getArguments(), 1, context);
    }

    /** {@code map_contains_entry(map, key, value)}: the map at {@code "*"}, the searched entry in full. */
    @Override
    public Void visitMapContainsEntry(MapContainsEntry mapContainsEntry, CollectorContext context) {
        context.accessPathBuilder.addPrefix("*");
        continueCollectAccessPath(mapContainsEntry.getArgument(0), context);
        return collectArgumentsFrom(mapContainsEntry.getArguments(), 1, context);
    }

    @Override
    public Void visitArrayMap(ArrayMap arrayMap, CollectorContext context) {
        // ARRAY_MAP(lambda, <arr> [ , <arr> ... ] )
        return collectLambdaOrVisit(arrayMap, arrayMap.getArgument(0), context);
    }

    @Override
    public Void visitArrayCount(ArrayCount arrayCount, CollectorContext context) {
        // ARRAY_COUNT(<lambda>, <arr>[, ... ])
        return collectLambdaOrVisit(arrayCount, arrayCount.getArgument(0), context);
    }

    @Override
    public Void visitArrayExists(ArrayExists arrayExists, CollectorContext context) {
        // ARRAY_EXISTS([ <lambda>, ] <arr1> [, <arr2> , ...] )
        return collectLambdaOrVisit(arrayExists, arrayExists.getArgument(0), context);
    }

    @Override
    public Void visitArrayFilter(ArrayFilter arrayFilter, CollectorContext context) {
        // ARRAY_FILTER(<lambda>, <arr>)
        return collectLambdaThenVisit(arrayFilter, arrayFilter.getArgument(0), context);
    }

    @Override
    public Void visitArrayFirst(ArrayFirst arrayFirst, CollectorContext context) {
        // ARRAY_FIRST(<lambda>, <arr>)
        return collectLambdaThenVisit(arrayFirst, arrayFirst.getArgument(0), context);
    }

    @Override
    public Void visitArrayFirstIndex(ArrayFirstIndex arrayFirstIndex, CollectorContext context) {
        // ARRAY_FIRST_INDEX(<lambda>, <arr> [, ...])
        return collectLambdaThenVisit(arrayFirstIndex, arrayFirstIndex.getArgument(0), context);
    }

    @Override
    public Void visitArrayLast(ArrayLast arrayLast, CollectorContext context) {
        // ARRAY_LAST(<lambda>, <arr>)
        return collectLambdaThenVisit(arrayLast, arrayLast.getArgument(0), context);
    }

    @Override
    public Void visitArrayLastIndex(ArrayLastIndex arrayLastIndex, CollectorContext context) {
        // ARRAY_LAST_INDEX(<lambda>, <arr> [, ...])
        return collectLambdaThenVisit(arrayLastIndex, arrayLastIndex.getArgument(0), context);
    }

    @Override
    public Void visitArrayMatchAny(ArrayMatchAny arrayMatchAny, CollectorContext context) {
        // array_match_any(lambda, <arr> [, <arr> ...])
        return collectLambdaThenVisit(arrayMatchAny, arrayMatchAny.getArgument(0), context);
    }

    @Override
    public Void visitArrayMatchAll(ArrayMatchAll arrayMatchAll, CollectorContext context) {
        // array_match_all(lambda, <arr> [, <arr> ...])
        return collectLambdaThenVisit(arrayMatchAll, arrayMatchAll.getArgument(0), context);
    }

    @Override
    public Void visitArrayReverseSplit(ArrayReverseSplit arrayReverseSplit, CollectorContext context) {
        // ARRAY_REVERSE_SPLIT(<lambda>, <arr> [, ...])
        return collectLambdaThenVisit(arrayReverseSplit, arrayReverseSplit.getArgument(0), context);
    }

    @Override
    public Void visitArraySplit(ArraySplit arraySplit, CollectorContext context) {
        // ARRAY_SPLIT(<lambda>, arr [, ...])
        return collectLambdaThenVisit(arraySplit, arraySplit.getArgument(0), context);
    }

    @Override
    public Void visitArraySortBy(ArraySortBy arraySortBy, CollectorContext context) {
        // ARRAY_SORTBY(<lambda>, <arr> [, ...])
        return collectLambdaThenVisit(arraySortBy, arraySortBy.getArgument(0), context);
    }

    // Disabled: when enabled, `<nested expr> IS NULL` with no narrowing parent would be recorded as a
    // META access instead of a DATA access.
    // @Override
    // public Void visitIsNull(IsNull isNull, CollectorContext context) {
    //     if (context.accessPathBuilder.isEmpty()) {
    //         context.setType(TAccessPathType.META);
    //         return continueCollectAccessPath(isNull.child(), context);
    //     }
    //     return visit(isNull, context);
    // }

    /**
     * Collects only through the lambda when {@code firstArgument} is one (the arrays are reached via
     * the lambda variables used in its body); otherwise falls back to {@link #visit}.
     */
    private Void collectLambdaOrVisit(Expression function, Expression firstArgument, CollectorContext context) {
        if (firstArgument instanceof Lambda) {
            return collectArrayPathInLambda((Lambda) firstArgument, context);
        }
        return visit(function, context);
    }

    /**
     * Collects through the lambda when {@code firstArgument} is one, and then additionally visits all
     * children of {@code function} with fresh, empty paths.
     */
    private Void collectLambdaThenVisit(Expression function, Expression firstArgument, CollectorContext context) {
        if (firstArgument instanceof Lambda) {
            collectArrayPathInLambda((Lambda) firstArgument, context);
        }
        return visit(function, context);
    }

    /**
     * Opens a scope binding each lambda variable to its array, drops a leading {@code "*"} from the
     * parent's path, and collects the lambda body (its first argument) with that path.
     */
    private Void collectArrayPathInLambda(Lambda lambda, CollectorContext context) {
        List<Expression> arguments = lambda.getArguments();
        Map<String, Expression> nameToArray = Maps.newLinkedHashMap();
        for (Expression argument : arguments) {
            if (argument instanceof ArrayItemReference) {
                nameToArray.put(((ArrayItemReference) argument).getName(), argument.child(0));
            }
        }
        List<String> path = context.accessPathBuilder.getPathList();
        if (!path.isEmpty() && path.get(0).equals("*")) {
            context.accessPathBuilder.removePrefix();
        }
        nameToLambdaArguments.push(nameToArray);
        try {
            continueCollectAccessPath(arguments.get(0), context);
        } finally {
            nameToLambdaArguments.pop();
        }
        return null;
    }

    /** Per-visit state: the path built so far, the bottom-filter flag and the access type to record. */
    public static class CollectorContext {
        private final StatementContext statementContext;
        private final AccessPathBuilder accessPathBuilder;
        private final boolean bottomFilter;
        private TAccessPathType type;

        /** Creates a context with an empty path and access type {@link TAccessPathType#DATA}. */
        public CollectorContext(StatementContext statementContext, boolean bottomFilter) {
            this.statementContext = statementContext;
            this.accessPathBuilder = new AccessPathBuilder();
            this.bottomFilter = bottomFilter;
            this.type = TAccessPathType.DATA;
        }

        public TAccessPathType getType() {
            return type;
        }

        public void setType(TAccessPathType type) {
            this.type = type;
        }
    }

    /** Mutable path under construction; components are prepended as visiting descends. */
    private static class AccessPathBuilder {
        private final LinkedList<String> accessPath;

        public AccessPathBuilder() {
            accessPath = new LinkedList<>();
        }

        public AccessPathBuilder addPrefix(String prefix) {
            accessPath.addFirst(prefix);
            return this;
        }

        public AccessPathBuilder removePrefix() {
            accessPath.removeFirst();
            return this;
        }

        /** Returns the live backing list, not a copy. */
        public List<String> getPathList() {
            return accessPath;
        }

        public boolean isEmpty() {
            return accessPath.isEmpty();
        }

        @Override
        public String toString() {
            return String.join(".", accessPath);
        }
    }

    /** One collected access path of a slot, with its bottom-predicate flag and access type. */
    public static class CollectAccessPathResult {
        private final List<String> path;
        private final boolean isPredicate;
        private final TAccessPathType type;

        public CollectAccessPathResult(List<String> path, boolean isPredicate, TAccessPathType type) {
            this.path = path;
            this.isPredicate = isPredicate;
            this.type = type;
        }

        public TAccessPathType getType() {
            return type;
        }

        public List<String> getPath() {
            return path;
        }

        public boolean isPredicate() {
            return isPredicate;
        }

        @Override
        public String toString() {
            return String.join(".", path) + ", " + isPredicate;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            CollectAccessPathResult that = (CollectAccessPathResult) o;
            // type is part of the result: a DATA and a META access of the same path are distinct
            return isPredicate == that.isPredicate && type == that.type && Objects.equals(path, that.path);
        }

        @Override
        public int hashCode() {
            // only the path is hashed; results that are equal always share it
            return path.hashCode();
        }
    }
}