PredicatePushDown.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.planner;
import org.apache.doris.analysis.Analyzer;
import org.apache.doris.analysis.BinaryPredicate;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.InPredicate;
import org.apache.doris.analysis.JoinOperator;
import org.apache.doris.analysis.Predicate;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.TupleId;
import org.apache.doris.common.Pair;
import com.google.common.base.Joiner;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.util.List;
/**
* Due to the current architecture, predicate derivation during expression rewrite cannot cover all cases:
* the rewrite is applied to the ON clause first and then to the WHERE clause, so when subqueries are
* involved some predicates cannot be derived there. The predicate pushdown logic is therefore kept here.
*
* <p>
* e.g.:
* origin: select * from t1 left join t2 on t1.id = t2.id where t1.id = 1;
* after: the predicate t2.id = 1 is derived and pushed down to the scan of t2.
* </p>
*
*/
public class PredicatePushDown {
    private static final Logger LOG = LogManager.getLogger(PredicatePushDown.class);

    /**
     * Desc: Predicate pushdown for inner and left join.
     *
     * <p>Derives additional conjuncts for {@code scanNode} from existing conjuncts and equi-join
     * predicates. Only {@code INNER_JOIN} and {@code LEFT_OUTER_JOIN} are handled; for any other
     * join operator the scan node is returned unchanged.</p>
     *
     * @param scanNode ScanNode to be judged; derived conjuncts are added to it in place
     * @param joinOp join Operator
     * @param analyzer global context
     * @return the same {@code scanNode} instance that was passed in
     */
    public static PlanNode visitScanNode(ScanNode scanNode, JoinOperator joinOp, Analyzer analyzer) {
        switch (joinOp) {
            case INNER_JOIN:
            case LEFT_OUTER_JOIN:
                predicateFromLeftSidePropagatesToRightSide(scanNode, analyzer);
                break;
            // TODO: support other join operators
            default:
                break;
        }
        return scanNode;
    }

    /**
     * Derives conjuncts for {@code scanNode} through equi-join predicates.
     *
     * <p>For every pushable conjunct whose slot is one side of an equi-join predicate (with both
     * sides' slot types matching), the same conjunct is rewritten onto the other side of that
     * equi-join, provided the other side is bound by the scan node's single tuple. Example:
     * {@code t1.id = t2.id and t1.id = 1 => t2.id = 1}.</p>
     *
     * <p>Scan nodes with more than one tuple are skipped.</p>
     */
    private static void predicateFromLeftSidePropagatesToRightSide(ScanNode scanNode, Analyzer analyzer) {
        List<TupleId> tupleIdList = scanNode.getTupleIds();
        if (tupleIdList.size() != 1) {
            LOG.info("The predicate pushdown is not reflected "
                    + "because the scan node involves more then one tuple:{}",
                    Joiner.on(",").join(tupleIdList));
            return;
        }
        TupleId rightSideTuple = tupleIdList.get(0);
        List<Expr> unassignedRightSideConjuncts = analyzer.getUnassignedConjuncts(scanNode);
        List<Expr> eqJoinPredicates = analyzer.getEqJoinConjuncts(rightSideTuple);
        if (eqJoinPredicates == null) {
            return;
        }
        // NOTE(review): removeAll mutates this list; assumes getConjuncts returns a fresh list — confirm.
        List<Expr> allConjuncts = analyzer.getConjuncts(analyzer.getAllTupleIds());
        // Skip conjuncts that are still unassigned for this scan node.
        allConjuncts.removeAll(unassignedRightSideConjuncts);
        for (Expr conjunct : allConjuncts) {
            if (!Predicate.canPushDownPredicate(conjunct)) {
                continue;
            }
            // The slot is expected to be child 0 because NormalizeBinaryPredicatesRule places it on
            // the left. It depends only on the conjunct, so resolve it once, and never dereference null.
            SlotRef otherSlot = conjunct.getChild(0).unwrapSlotRef();
            if (otherSlot == null) {
                continue;
            }
            for (Expr eqJoinPredicate : eqJoinPredicates) {
                // Only equi-join predicates whose two children both unwrap to SlotRefs are usable.
                SlotRef leftSlot = eqJoinPredicate.getChild(0).unwrapSlotRef();
                SlotRef rightSlot = eqJoinPredicate.getChild(1).unwrapSlotRef();
                if (leftSlot == null || rightSlot == null) {
                    continue;
                }
                // Deriving across slots of different types could change the predicate's meaning.
                if (!leftSlot.getDesc().getType().matchesType(rightSlot.getDesc().getType())) {
                    continue;
                }
                // The equi-join may be written either way round; push onto whichever side belongs
                // to this scan node's tuple.
                if (otherSlot.isBound(leftSlot.getSlotId()) && rightSlot.isBound(rightSideTuple)) {
                    addDerivedConjunct(scanNode, analyzer, conjunct, rightSlot);
                } else if (otherSlot.isBound(rightSlot.getSlotId()) && leftSlot.isBound(rightSideTuple)) {
                    addDerivedConjunct(scanNode, analyzer, conjunct, leftSlot);
                }
            }
        }
    }

    /**
     * Rewrites {@code conjunct} onto {@code targetSlot} and adds the result to {@code scanNode},
     * unless an equal predicate is already present in the analyzer's global IN / slot-to-literal
     * de-duplication sets.
     */
    private static void addDerivedConjunct(ScanNode scanNode, Analyzer analyzer, Expr conjunct,
            SlotRef targetSlot) {
        Expr pushDownConjunct = rewritePredicate(analyzer, conjunct, targetSlot);
        if (pushDownConjunct == conjunct) {
            // rewritePredicate does not support this predicate type and returned the original, which
            // references the other join side; it must never be attached to this scan node.
            return;
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("pushDownConjunct: {}", pushDownConjunct);
        }
        if (!analyzer.getGlobalInDeDuplication().contains(pushDownConjunct)
                && !analyzer.getGlobalSlotToLiteralDeDuplication()
                        .contains(Pair.of(pushDownConjunct.getChild(0), pushDownConjunct.getChild(1)))) {
            scanNode.addConjunct(pushDownConjunct);
        }
    }

    // TODO: (minghong) here is a bug. For example, this is a left join, we cannot infer "t2.id = 1"
    // by "t1.id=1" and "t1.id=t2.id".
    // we should not do inference work here. it should be done in some rule like InferFilterRule.
    /**
     * Rewrites {@code oldPredicate} with {@code leftChild} as its new left-hand side.
     *
     * <p>For example, if {@code oldPredicate} is {@code t1.id = 1} and {@code leftChild} is
     * {@code t2.id}, the result is {@code t2.id = 1}. Only {@link BinaryPredicate} and
     * {@link InPredicate} are rewritten; any other expression is returned unchanged (the same
     * instance).</p>
     *
     * <p>NOTE(review): the new predicate reuses {@code leftChild} and the right-hand/list children
     * of {@code oldPredicate} without cloning them. Confirm that sharing these nodes between
     * expressions is safe.</p>
     */
    private static Expr rewritePredicate(Analyzer analyzer, Expr oldPredicate, Expr leftChild) {
        if (oldPredicate instanceof BinaryPredicate) {
            BinaryPredicate oldBP = (BinaryPredicate) oldPredicate;
            BinaryPredicate bp = new BinaryPredicate(oldBP.getOp(), leftChild, oldBP.getChild(1));
            bp.analyzeNoThrow(analyzer);
            return bp;
        }
        if (oldPredicate instanceof InPredicate) {
            InPredicate oldIP = (InPredicate) oldPredicate;
            InPredicate ip = new InPredicate(leftChild, oldIP.getListChildren(), oldIP.isNotIn());
            ip.analyzeNoThrow(analyzer);
            return ip;
        }
        return oldPredicate;
    }
}