ArrayContainToArrayOverlap.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayContains;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraysOverlap;
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimaps;
import com.google.common.collect.SetMultimap;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
/**
* array_contains ( c_array, '1' )
* OR array_contains ( c_array, '2' )
* =========================================>
* array_overlap(c_array, ['1', '2'])
*/
public class ArrayContainToArrayOverlap implements ExpressionPatternRuleFactory {
public static final ArrayContainToArrayOverlap INSTANCE = new ArrayContainToArrayOverlap();
private static final int REWRITE_PREDICATE_THRESHOLD = 2;
@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesTopType(Or.class).then(ArrayContainToArrayOverlap::rewrite)
.toRule(ExpressionRuleType.ARRAY_CONTAIN_TO_ARRAY_OVERLAP)
);
}
private static Expression rewrite(Or or) {
List<Expression> disjuncts = ExpressionUtils.extractDisjunction(or);
List<Expression> contains = Lists.newArrayList();
List<Expression> others = Lists.newArrayList();
for (Expression expr : disjuncts) {
if (ArrayContainToArrayOverlap.isValidArrayContains(expr)) {
contains.add(expr);
} else {
others.add(expr);
}
}
if (contains.size() <= 1) {
return or;
}
SetMultimap<Expression, Literal> containLiteralSet = Multimaps.newSetMultimap(
new LinkedHashMap<>(), LinkedHashSet::new
);
for (Expression contain : contains) {
containLiteralSet.put(contain.child(0), (Literal) contain.child(1));
}
Builder<Expression> newDisjunctsBuilder = new ImmutableList.Builder<>();
for (Entry<Expression, Collection<Literal>> kv : containLiteralSet.asMap().entrySet()) {
Expression left = kv.getKey();
Collection<Literal> literalSet = kv.getValue();
if (literalSet.size() > REWRITE_PREDICATE_THRESHOLD) {
newDisjunctsBuilder.add(
new ArraysOverlap(left, new ArrayLiteral(Utils.fastToImmutableList(literalSet)))
);
}
}
for (Expression contain : contains) {
if (!canCovertToArrayOverlap(contain, containLiteralSet)) {
newDisjunctsBuilder.add(contain);
}
}
newDisjunctsBuilder.addAll(others);
return ExpressionUtils.or(newDisjunctsBuilder.build());
}
private static boolean isValidArrayContains(Expression expression) {
return expression instanceof ArrayContains && expression.child(1) instanceof Literal;
}
private static boolean canCovertToArrayOverlap(
Expression expression, SetMultimap<Expression, Literal> containLiteralSet) {
if (!(expression instanceof ArrayContains)) {
return false;
}
Set<Literal> containLiteral = containLiteralSet.get(expression.child(0));
return containLiteral.size() > REWRITE_PREDICATE_THRESHOLD;
}
}