Utils.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.plans.commands.info.AliasInfo;
import org.apache.doris.nereids.trees.plans.commands.info.TableNameInfo;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.StmtExecutor;
import org.apache.doris.statistics.ResultRow;
import com.google.common.base.CaseFormat;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.StringJoiner;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Utils for Nereids.
*/
public class Utils {
/**
 * Escape every backtick in an identifier part by doubling it.
 *
 * <p>Despite the method name, no surrounding quotes are added and digits get no special
 * treatment: the only transformation is {@code `} to {@code ``}, every other character is
 * copied through unchanged.
 *
 * @param part string to be escaped
 * @return {@code part} with each backtick doubled
 */
public static String quoteIfNeeded(String part) {
    return part.replace("`", "``");
}
/**
 * Run {@code f} and rethrow anything it throws as an unchecked exception, so that callers
 * (e.g. the translator) do not have to declare or catch checked exceptions themselves.
 *
 * @param f logic to run, typically legacy optimizer code that may throw a checked exception
 * @throws RuntimeException wrapping any exception thrown by {@code f}; the original exception
 *         is kept as the cause and its message is reused
 */
public static void execWithUncheckedException(FuncWrapper f) {
    try {
        f.exec();
    } catch (Exception checked) {
        String message = checked.getMessage();
        throw new RuntimeException(message, checked);
    }
}
/**
 * Evaluate {@code f} and return its value, rethrowing anything it throws as an unchecked
 * exception so that callers do not have to declare or catch checked exceptions themselves.
 *
 * @param f value producer that may throw a checked exception
 * @return the value produced by {@code f}
 * @throws RuntimeException wrapping any exception thrown by {@code f}; the original exception
 *         is kept as the cause and its message is reused
 */
public static <R> R execWithReturnVal(Supplier<R> f) {
    try {
        return f.get();
    } catch (Exception failure) {
        throw new RuntimeException(failure.getMessage(), failure);
    }
}
/**
 * Check whether lhs and rhs share at least one element.
 *
 * @param lhs set that is probed with {@code contains}
 * @param rhs collection whose elements are looked up in {@code lhs}
 * @return true as soon as one element of {@code rhs} is found in {@code lhs}, false otherwise
 */
public static <T> boolean isIntersecting(Set<T> lhs, Collection<T> rhs) {
    return rhs.stream().anyMatch(candidate -> lhs.contains(candidate));
}
/**
 * Wrapper to a function without return value.
 *
 * <p>Unlike {@link Runnable}, {@link #exec()} is allowed to throw a checked exception. The
 * interface is meant to be implemented with a lambda or method reference.
 */
@FunctionalInterface
public interface FuncWrapper {
    /**
     * Run the wrapped logic.
     *
     * @throws Exception any failure raised by the wrapped logic
     */
    void exec() throws Exception;
}
/**
 * Wrapper to a function with return value.
 *
 * <p>Unlike {@link java.util.function.Supplier}, {@link #get()} is allowed to throw a checked
 * exception. The interface is meant to be implemented with a lambda or method reference.
 *
 * @param <R> type of the produced value
 */
@FunctionalInterface
public interface Supplier<R> {
    /**
     * Produce the value.
     *
     * @return the produced value
     * @throws Exception any failure raised while producing the value
     */
    R get() throws Exception;
}
/**
 * Fully qualified identifier name parts, i.e., the qualifier entries followed by the name,
 * as one immutable list.
 *
 * @param qualifier leading parts of the identifier
 * @param name last part of the identifier
 * @return immutable list holding all qualifier entries and then {@code name}
 */
public static List<String> qualifiedNameParts(List<String> qualifier, String name) {
    ImmutableList.Builder<String> parts = ImmutableList.builder();
    parts.addAll(qualifier);
    parts.add(name);
    return parts.build();
}
/**
 * Fully qualified identifier name: the qualifier entries and the name joined with `.` as separator.
 *
 * @param qualifier leading parts of the identifier
 * @param name last part of the identifier
 * @return the dotted name
 */
public static String qualifiedName(List<String> qualifier, String name) {
    List<String> parts = qualifiedNameParts(qualifier, name);
    return String.join(".", parts);
}
/**
 * Get the qualified name with every part wrapped in backticks.
 *
 * @param qualifiers leading parts of the identifier
 * @param name last part of the identifier
 * @return the backtick-quoted parts joined with `.`
 */
public static String qualifiedNameWithBackquote(List<String> qualifiers, String name) {
    List<String> allParts = new ArrayList<>(qualifiers.size() + 1);
    allParts.addAll(qualifiers);
    allParts.add(name);
    return qualifiedNameWithBackquote(allParts);
}

/**
 * Get the qualified name with every part wrapped in backticks.
 *
 * <p>A backtick inside a part is escaped by doubling it before the part is wrapped.
 *
 * @param qualifiers all parts of the identifier, in order
 * @return the backtick-quoted parts joined with `.`
 */
public static String qualifiedNameWithBackquote(List<String> qualifiers) {
    return qualifiers.stream()
            .map(part -> '`' + part.replace("`", "``") + '`')
            .collect(Collectors.joining("."));
}
/**
 * Get sql string for plan, in the form {@code planName ( k1=v1, k2=v2 )}.
 *
 * <p>{@code variables} is read as alternating name/value entries. A pair whose value renders
 * as the empty string is left out. The separator is only written between pairs that are
 * actually emitted, so a skipped leading pair does not leave a dangling {@code ", "}.
 *
 * @param planName name of plan, like LogicalJoin.
 * @param variables variable needed to add into sqlString, as name/value pairs.
 * @return the string of PlanNode.
 * @throws IllegalStateException if {@code variables} does not hold an even number of entries
 */
public static String toSqlString(String planName, Object... variables) {
    if (variables.length % 2 != 0) {
        throw new IllegalStateException(
                "variables must be name/value pairs, but got " + variables.length + " entries");
    }
    StringBuilder sql = new StringBuilder();
    sql.append(planName).append(" ( ");
    boolean emitted = false;
    for (int i = 0; i < variables.length - 1; i += 2) {
        String value = toStringOrNull(variables[i + 1]);
        if ("".equals(value)) {
            continue;
        }
        // decide on the separator by what was written so far, not by the pair index
        if (emitted) {
            sql.append(", ");
        }
        sql.append(toStringOrNull(variables[i])).append("=").append(value);
        emitted = true;
    }
    return sql.append(" )").toString();
}

/**
 * Render an object with {@code toString()}, or as the literal text {@code "null"} when the
 * reference is null.
 *
 * @param obj object to render, may be null
 * @return {@code obj.toString()}, or {@code "null"} for a null reference
 */
public static String toStringOrNull(Object obj) {
    return obj == null ? "null" : obj.toString();
}
/**
 * Get the unCorrelated exprs that belong to the subquery,
 * that is, the unCorrelated exprs that can be resolved within the subquery.
 *
 * <p>Every predicate must be a binary expression, or a Not whose child is a binary expression.
 * One operand has to be purely correlated: it has input slots, all of them are contained in
 * {@code correlatedSlots}, and the other operand references none of them. The other operand is
 * collected when it has at least one input slot.
 *
 * <p>eg:
 * select * from t1 where t1.a = (select sum(t2.b) from t2 where t1.c = abs(t2.d));
 * correlatedPredicates : t1.c = abs(t2.d)
 * unCorrelatedExprs : abs(t2.d)
 * return abs(t2.d)
 *
 * @param correlatedPredicates predicates to inspect
 * @param correlatedSlots slots that are treated as correlated
 * @return the collected uncorrelated operands, in predicate order
 * @throws AnalysisException if a predicate has an unsupported shape, or neither of its operands
 *         is purely correlated
 */
public static List<Expression> getUnCorrelatedExprs(List<Expression> correlatedPredicates,
        List<Expression> correlatedSlots) {
    List<Expression> unCorrelatedExprs = new ArrayList<>();
    for (Expression predicate : correlatedPredicates) {
        boolean supportedShape = predicate instanceof BinaryExpression
                || (predicate instanceof Not && predicate.child(0) instanceof BinaryExpression);
        if (!supportedShape) {
            throw new AnalysisException(
                    "Unsupported correlated subquery with correlated predicate "
                            + predicate.toString());
        }
        // look through a negation to reach the binary expression itself
        BinaryExpression comparison = predicate instanceof Not
                ? (BinaryExpression) ((Not) predicate).child()
                : (BinaryExpression) predicate;
        Expression leftOperand = comparison.left();
        Expression rightOperand = comparison.right();
        Set<Slot> leftSlots = leftOperand.getInputSlots();
        Set<Slot> rightSlots = rightOperand.getInputSlots();
        boolean onlyLeftCorrelated = !leftSlots.isEmpty()
                && leftSlots.stream().allMatch(correlatedSlots::contains)
                && rightSlots.stream().noneMatch(correlatedSlots::contains);
        boolean onlyRightCorrelated = !rightSlots.isEmpty()
                && rightSlots.stream().allMatch(correlatedSlots::contains)
                && leftSlots.stream().noneMatch(correlatedSlots::contains);
        if (!onlyLeftCorrelated && !onlyRightCorrelated) {
            throw new AnalysisException(
                    "Unsupported correlated subquery with correlated predicate " + predicate);
        } else if (onlyLeftCorrelated && !rightSlots.isEmpty()) {
            unCorrelatedExprs.add(rightOperand);
        } else if (onlyRightCorrelated && !leftSlots.isEmpty()) {
            unCorrelatedExprs.add(leftOperand);
        }
    }
    return unCorrelatedExprs;
}
/**
 * Pick the operand opposite to the correlated side of a binary expression and unwrap it.
 *
 * <p>When the left operand contains any of {@code correlatedSlots} the right operand is
 * inspected, otherwise the left one. A SlotReference is returned as is, a Cast contributes its
 * child, and any other kind of expression contributes nothing.
 *
 * @param binaryExpression expression whose operands are inspected
 * @param correlatedSlots slots that are treated as correlated
 * @return a list with at most one expression
 */
private static List<Expression> collectCorrelatedSlotsFromChildren(
        BinaryExpression binaryExpression, List<Expression> correlatedSlots) {
    Expression candidate = binaryExpression.left().anyMatch(correlatedSlots::contains)
            ? binaryExpression.right()
            : binaryExpression.left();
    List<Expression> slots = new ArrayList<>();
    if (candidate instanceof SlotReference) {
        slots.add(candidate);
    } else if (candidate instanceof Cast) {
        slots.add(((Cast) candidate).child());
    }
    return slots;
}
/**
 * Partition conjuncts by whether they contain any of the given slots.
 *
 * @param conjuncts conjuncts to partition
 * @param slots slots to look for inside each conjunct
 * @return map whose {@code true} entry holds the conjuncts containing at least one of
 *         {@code slots} and whose {@code false} entry holds the remaining ones
 */
public static Map<Boolean, List<Expression>> splitCorrelatedConjuncts(
        Set<Expression> conjuncts, List<Expression> slots) {
    Predicate<Expression> referencesSlots = conjunct -> conjunct.anyMatch(slots::contains);
    return conjuncts.stream().collect(Collectors.partitioningBy(referencesSlots));
}
/**
 * Replace every occurrence of an item in a list with another item, in place.
 *
 * <p>Elements are compared with {@link Objects#equals(Object, Object)}, so null elements in
 * the list (and a null {@code oldItem}) are handled instead of raising a NullPointerException.
 *
 * @param list list to modify; must support {@code set}
 * @param oldItem item to look for
 * @param newItem item written over each match
 * @throws IllegalStateException if {@code oldItem} does not occur in {@code list}
 */
public static <T> void replaceList(List<T> list, T oldItem, T newItem) {
    boolean replaced = false;
    for (int i = 0; i < list.size(); i++) {
        if (Objects.equals(list.get(i), oldItem)) {
            list.set(i, newItem);
            replaced = true;
        }
    }
    if (!replaced) {
        throw new IllegalStateException("item to replace not found in list");
    }
}
/**
 * Remove the first element that is the very same object as {@code item} (compared with
 * {@code ==}, not with equals) from the list.
 *
 * @param list list to modify
 * @param item object to remove
 * @throws IllegalStateException if no element of the list is identical to {@code item}
 */
public static <T> void identityRemove(List<T> list, T item) {
    int position = -1;
    for (int i = 0; i < list.size() && position < 0; i++) {
        if (list.get(i) == item) {
            position = i;
        }
    }
    Preconditions.checkState(position >= 0, "item not found in list");
    list.remove(position);
}
/**
 * Build the cartesian product of the given lists: every result entry takes one item from each
 * input list, in input order.
 *
 * <p>A single input list yields one single-item list per element; when that list itself holds
 * exactly one element the input {@code lists} is returned as is. An empty {@code lists} yields
 * an empty result.
 *
 * @param lists the candidate items for each position
 * @return all combinations
 */
public static <T> List<List<T>> allCombinations(List<List<T>> lists) {
    if (lists.size() != 1) {
        return doAllCombinations(lists);
    }
    List<T> first = lists.get(0);
    if (first.size() == 1) {
        // [[x]] already is its own product, so the input is handed back without copying
        return lists;
    }
    // one single-item combination per element, so the result is sized by the element count
    List<List<T>> result = Lists.newArrayListWithCapacity(first.size());
    for (T item : first) {
        result.add(ImmutableList.of(item));
    }
    return result;
}

/**
 * Recursive part of {@link #allCombinations(List)}: prefix every combination of the remaining
 * lists with each item of the first list.
 */
private static <T> List<List<T>> doAllCombinations(List<List<T>> lists) {
    int size = lists.size();
    if (size == 0) {
        return ImmutableList.of();
    }
    List<T> first = lists.get(0);
    if (size == 1) {
        return first
                .stream()
                .map(ImmutableList::of)
                .collect(ImmutableList.toImmutableList());
    }
    List<List<T>> rest = lists.subList(1, size);
    List<List<T>> combinationWithoutFirst = allCombinations(rest);
    return first.stream()
            .flatMap(firstValue -> combinationWithoutFirst.stream()
                    .map(restList ->
                            Stream.concat(Stream.of(firstValue), restList.stream())
                                    .collect(ImmutableList.toImmutableList())
                    )
            ).collect(ImmutableList.toImmutableList());
}
/**
 * Get all combinations of {@code itemNum} elements chosen from {@code list}, keeping the
 * relative order of the elements. Elements are chosen by position, so equal elements at
 * different positions count as different choices.
 *
 * <p>An {@code itemNum} of 0 yields a single empty combination. A negative {@code itemNum}, or
 * one larger than the list, yields no combination at all and returns immediately instead of
 * walking every subset of the list.
 *
 * @param list elements to choose from
 * @param itemNum number of elements in each combination
 * @return a new mutable list of new mutable lists, one per combination
 */
public static <T> List<List<T>> getAllCombinations(List<T> list, int itemNum) {
    List<List<T>> result = new ArrayList<>();
    if (itemNum < 0 || itemNum > list.size()) {
        return result;
    }
    generateCombinations(list, itemNum, 0, new ArrayList<>(itemNum), result);
    return result;
}

/**
 * Extend {@code current} with elements taken from position {@code start} onwards until it
 * holds {@code n} elements, recording each completed combination in {@code result}.
 */
private static <T> void generateCombinations(
        List<T> list, int n, int start, List<T> current, List<List<T>> result) {
    if (current.size() == n) {
        result.add(new ArrayList<>(current));
        return;
    }
    // starting later than this leaves too few elements to ever reach n, so those branches are skipped
    int lastUsefulStart = list.size() - (n - current.size());
    for (int i = start; i <= lastUsefulStart; i++) {
        current.add(list.get(i));
        generateCombinations(list, n, i + 1, current, result);
        current.remove(current.size() - 1);
    }
}
/**
 * Get every non-empty ordered arrangement of the elements of {@code list}, of every length
 * from 1 up to the list size (not only the full-length permutations).
 *
 * <p>Elements are chosen by position, so equal elements at different positions count as
 * different choices. The input list is not modified.
 *
 * @param list elements to arrange
 * @return a new mutable list of new mutable lists; empty when {@code list} is empty
 */
public static <T> List<List<T>> allPermutations(List<T> list) {
    List<T> items = new ArrayList<>(list);
    List<List<T>> result = new ArrayList<>();
    generatePermutations(items, new boolean[items.size()], new ArrayList<>(), result);
    return result;
}

/**
 * Record {@code current} when it is non-empty, then extend it with each element that is not
 * taken yet, in list order.
 */
private static <T> void generatePermutations(
        List<T> items, boolean[] taken, List<T> current, List<List<T>> result) {
    if (!current.isEmpty()) {
        result.add(new ArrayList<>(current));
    }
    for (int i = 0; i < items.size(); i++) {
        if (taken[i]) {
            continue;
        }
        taken[i] = true;
        current.add(items.get(i));
        generatePermutations(items, taken, current, result);
        current.remove(current.size() - 1);
        taken[i] = false;
    }
}
/**
 * Get all full-length permutations of {@code list}.
 *
 * <p>The result is built from the last element towards the first: each step inserts the next
 * element at every possible position of every permutation built so far. An empty input yields
 * a single empty permutation. The input list is not modified.
 *
 * @param list elements to permute
 * @return a new mutable list of new mutable lists, one per permutation
 */
public static <T> List<List<T>> permutations(List<T> list) {
    List<T> items = new ArrayList<>(list);
    List<List<T>> result = new ArrayList<>();
    result.add(new ArrayList<>());
    for (int pos = items.size() - 1; pos >= 0; pos--) {
        T element = items.get(pos);
        List<List<T>> extended = new ArrayList<>();
        for (List<T> shorter : result) {
            for (int index = 0; index <= shorter.size(); index++) {
                List<T> longer = new ArrayList<>(shorter);
                longer.add(index, element);
                extended.add(longer);
            }
        }
        result = extended;
    }
    return result;
}
/**
 * Copy a list that must not be null into an immutable list.
 *
 * @param list list to copy
 * @return immutable copy of {@code list}
 * @throws NullPointerException with message "non-null list is required" if {@code list} is null
 */
public static <T> List<T> copyRequiredList(List<T> list) {
    List<T> required = Objects.requireNonNull(list, "non-null list is required");
    return ImmutableList.copyOf(required);
}
/**
 * Copy a list that must not be null into a new mutable list.
 *
 * @param list list to copy
 * @return mutable copy of {@code list}
 * @throws NullPointerException with message "non-null list is required" if {@code list} is null
 */
public static <T> List<T> copyRequiredMutableList(List<T> list) {
    List<T> required = Objects.requireNonNull(list, "non-null list is required");
    return Lists.newArrayList(required);
}
/**
 * Normalize the name to lower underscore style, return default name if the name is empty.
 *
 * <p>Every {@code $} in the name is turned into {@code _} before the name is converted from
 * upper camel case to lower underscore case.
 *
 * @param name name to normalize, may be null or empty
 * @param defaultName value returned when {@code name} is null or empty
 * @return the normalized name, or {@code defaultName}
 */
public static String normalizeName(String name, String defaultName) {
    if (StringUtils.isEmpty(name)) {
        return defaultName;
    }
    String withoutDollar = name.replace('$', '_');
    return CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_UNDERSCORE, withoutDollar);
}
/**
 * Check whether the text contains at least one character of the Han script.
 *
 * <p>The text is inspected code point by code point rather than char by char, so Han
 * characters outside the Basic Multilingual Plane, which are stored as surrogate pairs, are
 * detected as well.
 *
 * @param text text to inspect
 * @return true if any code point of {@code text} belongs to the Han script, false otherwise
 */
public static boolean containChinese(String text) {
    return text.codePoints()
            .anyMatch(codePoint -> Character.UnicodeScript.of(codePoint) == Character.UnicodeScript.HAN);
}
/**
 * Map every element of a list into a new mutable list.
 *
 * <p>The new list is created with capacity {@code list.size() + additionSize}, which leaves
 * room for elements the caller appends afterwards.
 *
 * @param list elements to transform
 * @param additionSize extra capacity to reserve beyond the size of {@code list}
 * @param transformer function applied to each element, in list order
 * @return new list holding the transformed elements
 */
public static <I, O> List<O> fastMapList(List<I> list, int additionSize, Function<I, O> transformer) {
    final int capacity = list.size() + additionSize;
    List<O> mapped = Lists.newArrayListWithCapacity(capacity);
    for (I element : list) {
        O transformed = transformer.apply(element);
        mapped.add(transformed);
    }
    return mapped;
}
/**
 * Copy an array into an ImmutableList, with shortcuts for arrays of length 0 and 1.
 *
 * @param array elements to copy
 * @return immutable list holding the array elements in order
 */
public static <E> ImmutableList<E> fastToImmutableList(E[] array) {
    final int length = array.length;
    if (length == 0) {
        return ImmutableList.of();
    }
    if (length == 1) {
        return ImmutableList.of(array[0]);
    }
    // NOTE: the elements go through a presized builder instead of ImmutableList.copyOf(array),
    // which, per the original author's note, makes an additional clone of the array
    Builder<E> elements = ImmutableList.builderWithExpectedSize(length);
    for (int i = 0; i < length; i++) {
        elements.add(array[i]);
    }
    return elements.build();
}
/**
 * Convert a collection into an ImmutableList, with shortcuts for collections that already are
 * an ImmutableList and for collections of size 0 and 1.
 *
 * @param collection elements to copy
 * @return {@code collection} itself when it already is an ImmutableList, otherwise an
 *         immutable list holding its elements in iteration order
 */
public static <E> ImmutableList<E> fastToImmutableList(Collection<? extends E> collection) {
    if (collection instanceof ImmutableList) {
        return (ImmutableList<E>) collection;
    }
    int size = collection.size();
    if (size == 0) {
        return ImmutableList.of();
    }
    if (size == 1) {
        E only = collection instanceof List
                ? ((List<E>) collection).get(0)
                : collection.iterator().next();
        return ImmutableList.of(only);
    }
    // NOTE: the elements go through a presized builder instead of ImmutableList.copyOf(list),
    // which, per the original author's note, makes an additional clone of the list
    Builder<E> elements = ImmutableList.builderWithExpectedSize(size);
    elements.addAll(collection);
    return elements.build();
}
/**
 * Convert a collection into an ImmutableSet, with shortcuts for collections that already are
 * an ImmutableSet and for collections of size 0 and 1.
 *
 * @param collection elements to copy
 * @return {@code collection} itself when it already is an ImmutableSet, otherwise an
 *         immutable set built from its elements
 */
public static <E> ImmutableSet<E> fastToImmutableSet(Collection<? extends E> collection) {
    if (collection instanceof ImmutableSet) {
        return (ImmutableSet<E>) collection;
    }
    int size = collection.size();
    if (size == 0) {
        return ImmutableSet.of();
    }
    if (size == 1) {
        E only = collection instanceof List
                ? ((List<E>) collection).get(0)
                : collection.iterator().next();
        return ImmutableSet.of(only);
    }
    // the elements go through a builder presized to the collection size
    ImmutableSet.Builder<E> elements = ImmutableSet.builderWithExpectedSize(size);
    elements.addAll(collection);
    return elements.build();
}
/**
 * Build an ImmutableList holding the elements of a list in reverse order.
 *
 * @param list elements to reverse; the list itself is not modified
 * @return immutable list with the last element of {@code list} first
 */
public static <E> ImmutableList<E> reverseImmutableList(List<? extends E> list) {
    final int size = list.size();
    Builder<E> reversed = ImmutableList.builderWithExpectedSize(size);
    for (int offset = 1; offset <= size; offset++) {
        reversed.add(list.get(size - offset));
    }
    return reversed.build();
}
/**
 * Build an ImmutableList with the elements of a list that pass a filter, keeping their order.
 *
 * @param list elements to filter
 * @param filter test an element has to pass to be kept
 * @return immutable list of the elements for which {@code filter} returned true
 */
public static <E> ImmutableList<E> filterImmutableList(List<? extends E> list, Predicate<E> filter) {
    Builder<E> kept = ImmutableList.builderWithExpectedSize(list.size());
    for (E candidate : list) {
        if (!filter.test(candidate)) {
            continue;
        }
        kept.add(candidate);
    }
    return kept.build();
}
/**
 * Build an immutable set holding the elements of both collections.
 *
 * @param left elements added first
 * @param right elements added second
 * @return immutable set built from {@code left} followed by {@code right}
 */
public static <E> Set<E> concatToSet(Collection<? extends E> left, Collection<? extends E> right) {
    int expectedSize = left.size() + right.size();
    return ImmutableSet.<E>builderWithExpectedSize(expectedSize)
            .addAll(left)
            .addAll(right)
            .build();
}
/**
 * Fold a list from left to right, using its first element as the initial value.
 *
 * @param list elements to fold
 * @param reduceOp combines the value accumulated so far with the next element
 * @return empty when {@code list} is empty, otherwise the folded value
 */
public static <M, T extends M> Optional<M> fastReduce(List<T> list, BiFunction<M, T, M> reduceOp) {
    if (list.isEmpty()) {
        return Optional.empty();
    }
    M accumulated = list.get(0);
    for (T next : list.subList(1, list.size())) {
        accumulated = reduceOp.apply(accumulated, next);
    }
    return Optional.of(accumulated);
}
/**
 * If the first character of the string is uppercase, replace the first character with lowercase.
 *
 * @param input string to convert, may be null or empty
 * @return {@code input} itself when it is null, empty, or does not start with an uppercase
 *         character; otherwise a copy whose first character is lowercased
 */
public static String convertFirstChar(String input) {
    if (input == null || input.isEmpty() || !Character.isUpperCase(input.charAt(0))) {
        return input;
    }
    return Character.toLowerCase(input.charAt(0)) + input.substring(1);
}
/**
 * Put a prefix in front of every line of a string.
 *
 * <p>The string is split on {@code "\n"} with {@link String#split(String)}, so trailing empty
 * lines are dropped; the remaining lines are joined with {@code "\n"} again and the result
 * does not end with a line break.
 *
 * @param str text to prefix
 * @param prefix text put in front of each line
 * @return the prefixed lines
 */
public static String addLinePrefix(String str, String prefix) {
    StringJoiner prefixed = new StringJoiner("\n");
    for (String line : str.split("\n")) {
        prefixed.add(prefix + line);
    }
    return prefixed.toString();
}
/**
 * Builds a logical plan for a SQL query.
 *
 * <p>The query text is assembled as {@code SELECT <columns> FROM <table> <whereClause>} and
 * handed to a new NereidsParser. {@code whereClause} is appended as given, so it has to carry
 * its own keyword when one is needed.
 *
 * @param selectList the list of columns and their aliases to be selected
 * @param tableName the name of the table from which to select
 * @param whereClause the where clause to filter the results
 * @return the logical plan representing the SQL query
 */
public static LogicalPlan buildLogicalPlan(List<AliasInfo> selectList, TableNameInfo tableName,
        String whereClause) {
    // prefix and suffix of the joiner supply the keywords around the column list
    StringJoiner selectPart = new StringJoiner(", ", "SELECT ", " FROM ");
    for (AliasInfo column : selectList) {
        selectPart.add(column.toString());
    }
    String sql = selectPart.toString() + tableName.toFullyQualified() + " " + whereClause;
    return new NereidsParser().parseSingle(sql);
}
/**
 * Execute a logical plan and return the results.
 *
 * <p>The plan is wrapped in a LogicalPlanAdapter together with the statement context of
 * {@code ctx}, set as the parsed statement of {@code executor}, and run as an internal query.
 *
 * @param ctx the context in which to execute the plan
 * @param executor the executor to use to execute the plan
 * @param plan the plan to execute
 * @return the values of every result row, in result order
 */
public static List<List<String>> executePlan(ConnectContext ctx, StmtExecutor executor, LogicalPlan plan) {
    executor.setParsedStmt(new LogicalPlanAdapter(plan, ctx.getStatementContext()));
    List<List<String>> values = new ArrayList<>();
    for (ResultRow row : executor.executeInternalQuery()) {
        values.add(row.getValues());
    }
    return values;
}
}