WindowFunctionChecker.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.OrderExpression;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.WindowFrame;
import org.apache.doris.nereids.trees.expressions.WindowFrame.FrameBoundType;
import org.apache.doris.nereids.trees.expressions.WindowFrame.FrameBoundary;
import org.apache.doris.nereids.trees.expressions.WindowFrame.FrameUnitsType;
import org.apache.doris.nereids.trees.expressions.functions.window.CumeDist;
import org.apache.doris.nereids.trees.expressions.functions.window.DenseRank;
import org.apache.doris.nereids.trees.expressions.functions.window.FirstOrLastValue;
import org.apache.doris.nereids.trees.expressions.functions.window.FirstValue;
import org.apache.doris.nereids.trees.expressions.functions.window.Lag;
import org.apache.doris.nereids.trees.expressions.functions.window.LastValue;
import org.apache.doris.nereids.trees.expressions.functions.window.Lead;
import org.apache.doris.nereids.trees.expressions.functions.window.NthValue;
import org.apache.doris.nereids.trees.expressions.functions.window.Ntile;
import org.apache.doris.nereids.trees.expressions.functions.window.PercentRank;
import org.apache.doris.nereids.trees.expressions.functions.window.Rank;
import org.apache.doris.nereids.trees.expressions.functions.window.RowNumber;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
/**
 * Check and standardize Window expression:
 *
 * step 1: checkWindowBeforeFunc():
 * general checking for WindowFrame, including checking the OrderKey list, setting a default right boundary,
 * checking boundary offsets if they exist, and checking the correctness of the boundary types.
 * step 2: checkWindowFunction():
 * check the window function; different functions have different checking rules.
 * Some functions also install their own default window frame when no frame exists.
 * step 3: checkWindowAfterFunc():
 * if a frame exists, reverse the window when necessary (order keys and frame are reversed; the function itself
 * is only swapped for first_value() and last_value()); if no frame exists, add a general default
 * window frame (RANGE between UNBOUNDED PRECEDING and CURRENT ROW).
 */
public class WindowFunctionChecker extends DefaultExpressionVisitor<Expression, Void> {

    // Rewritten in place by each checking step; read the final result via getWindow().
    private WindowExpression windowExpression;

    public WindowFunctionChecker(WindowExpression window) {
        this.windowExpression = window;
    }

    /** Returns the (possibly rewritten) window expression after the checking steps. */
    public WindowExpression getWindow() {
        return windowExpression;
    }

    /**
     * step 1: check windowFrame in window;
     */
    public void checkWindowBeforeFunc() {
        windowExpression.getWindowFrame().ifPresent(this::checkWindowFrameBeforeFunc);
    }

    /**
     * step 2: check windowFunction in window
     */
    public Expression checkWindowFunction() {
        // in checkWindowFrameBeforeFunc() we have confirmed that both left and right boundary are set as long as
        // windowFrame exists, therefore in all following visitXXX functions we don't need to check whether the right
        // boundary is null.
        return windowExpression.accept(this, null);
    }

    /**
     * step 3: check window
     */
    public void checkWindowAfterFunc() {
        Optional<WindowFrame> windowFrame = windowExpression.getWindowFrame();
        if (windowFrame.isPresent()) {
            // reverse windowFrame
            checkWindowFrameAfterFunc(windowFrame.get());
        } else {
            setDefaultWindowFrameAfterFunc();
        }
    }

    /* ********************************************************************************************
     * methods for step 1
     * ******************************************************************************************** */

    /**
     *
     * if WindowFrame doesn't have right boundary, we will set it a default one(current row);
     * but if WindowFrame itself doesn't exist, we will keep it null still.
     *
     * Basic exception cases:
     * 0. WindowFrame != null, but OrderKeyList is empty
     *
     * WindowFrame EXCEPTION cases:
     * 1. (unbounded following, xxx) || (offset following, !following)
     * 2. (xxx, unbounded preceding) || (!preceding, offset preceding)
     * 3. RANGE && ( (offset preceding, xxx) || (xxx, offset following) || (current row, current row) )
     *
     * WindowFrame boundOffset check:
     * 4. check value of boundOffset: Literal; Positive; Integer (for ROWS) or Numeric (for RANGE)
     * 5. check that boundOffset of left <= boundOffset of right (both FOLLOWING),
     *    or boundOffset of left >= boundOffset of right (both PRECEDING)
     */
    private void checkWindowFrameBeforeFunc(WindowFrame windowFrame) {
        // case 0
        if (windowExpression.getOrderKeys().isEmpty()) {
            throw new AnalysisException("WindowFrame clause requires OrderBy clause");
        }

        // set default rightBoundary
        if (windowFrame.getRightBoundary().isNull()) {
            windowFrame = windowFrame.withRightBoundary(FrameBoundary.newCurrentRowBoundary());
        }
        FrameBoundary left = windowFrame.getLeftBoundary();
        FrameBoundary right = windowFrame.getRightBoundary();

        // case 1
        if (left.getFrameBoundType() == FrameBoundType.UNBOUNDED_FOLLOWING) {
            throw new AnalysisException("WindowFrame in any window function cannot use "
                    + "UNBOUNDED FOLLOWING as left boundary");
        }
        if (left.getFrameBoundType() == FrameBoundType.FOLLOWING && !right.asFollowing()) {
            throw new AnalysisException("WindowFrame with FOLLOWING left boundary requires "
                    + "UNBOUNDED FOLLOWING or FOLLOWING right boundary");
        }

        // case 2
        if (right.getFrameBoundType() == FrameBoundType.UNBOUNDED_PRECEDING) {
            throw new AnalysisException("WindowFrame in any window function cannot use "
                    + "UNBOUNDED PRECEDING as right boundary");
        }
        if (right.getFrameBoundType() == FrameBoundType.PRECEDING && !left.asPreceding()) {
            throw new AnalysisException("WindowFrame with PRECEDING right boundary requires "
                    + "UNBOUNDED PRECEDING or PRECEDING left boundary");
        }

        // case 3
        // this case will be removed when RANGE with offset boundaries is supported
        if (windowFrame.getFrameUnits() == FrameUnitsType.RANGE) {
            if (left.hasOffset() || right.hasOffset()
                    || (left.getFrameBoundType() == FrameBoundType.CURRENT_ROW
                    && right.getFrameBoundType() == FrameBoundType.CURRENT_ROW)) {
                throw new AnalysisException("WindowFrame with RANGE must use both UNBOUNDED boundary or "
                        + "one UNBOUNDED boundary and one CURRENT ROW");
            }
        }

        // case 4
        // pass the units of the frame being validated explicitly: windowExpression is only updated
        // with this (possibly modified) frame at the end of this method
        FrameUnitsType frameUnits = windowFrame.getFrameUnits();
        if (left.hasOffset()) {
            checkFrameBoundOffset(left, frameUnits);
        }
        if (right.hasOffset()) {
            checkFrameBoundOffset(right, frameUnits);
        }

        // case 5
        // check correctness of left boundary and right boundary
        if (left.hasOffset() && right.hasOffset()) {
            double leftOffsetValue = ((Literal) left.getBoundOffset().get()).getDouble();
            double rightOffsetValue = ((Literal) right.getBoundOffset().get()).getDouble();
            if (left.asPreceding() && right.asPreceding()) {
                Preconditions.checkArgument(leftOffsetValue >= rightOffsetValue, "WindowFrame with "
                        + "PRECEDING boundary requires that leftBoundOffset >= rightBoundOffset");
            } else if (left.asFollowing() && right.asFollowing()) {
                // the message must state the same relation the check enforces (<=)
                Preconditions.checkArgument(leftOffsetValue <= rightOffsetValue, "WindowFrame with "
                        + "FOLLOWING boundary requires that leftBoundOffset <= rightBoundOffset");
            }
        }

        windowExpression = windowExpression.withWindowFrame(windowFrame);
    }

    /**
     * check boundOffset of FrameBoundary if it exists:
     * 1 boundOffset should be Literal, but this restriction can be removed after completing FoldConstant
     * 2 boundOffset should be positive
     * 3 boundOffset should be a positive INTEGER if FrameUnitsType == ROWS
     * 4 boundOffset should be a positive INTEGER or DECIMAL if FrameUnitsType == RANGE
     *
     * @param frameBoundary boundary whose offset is checked; its offset must be present
     * @param frameUnits units (ROWS or RANGE) of the frame that owns this boundary
     * @throws IllegalArgumentException if any of the rules above is violated
     */
    private void checkFrameBoundOffset(FrameBoundary frameBoundary, FrameUnitsType frameUnits) {
        Expression offset = frameBoundary.getBoundOffset().get();

        // case 1
        Preconditions.checkArgument(offset.isLiteral(), "BoundOffset of WindowFrame must be Literal");

        // case 2
        boolean isPositive = ((Literal) offset).getDouble() > 0;
        Preconditions.checkArgument(isPositive, "BoundOffset of WindowFrame must be positive");

        // case 3
        if (frameUnits == FrameUnitsType.ROWS) {
            Preconditions.checkArgument(offset.getDataType().isIntegralType(),
                    "BoundOffset of ROWS WindowFrame must be an Integer");
        }

        // case 4
        if (frameUnits == FrameUnitsType.RANGE) {
            Preconditions.checkArgument(offset.getDataType().isNumericType(),
                    "BoundOffset of RANGE WindowFrame must be an Integer or Decimal");
        }
    }

    /* ********************************************************************************************
     * methods for step 2
     * ******************************************************************************************** */

    /**
     * required WindowFrame: (UNBOUNDED PRECEDING, offset PRECEDING)
     * but in Spark, it is (offset PRECEDING, offset PRECEDING)
     */
    @Override
    public Lag visitLag(Lag lag, Void ctx) {
        // check and complete window frame
        windowExpression.getWindowFrame().ifPresent(wf -> {
            throw new AnalysisException("WindowFrame for LAG() must be null");
        });
        if (lag.children().size() != 3) {
            throw new AnalysisException("Lag must have three parameters");
        }

        Expression column = lag.child(0);
        Expression offset = lag.getOffset();
        Expression defaultValue = lag.getDefaultValue();
        WindowFrame requiredFrame = new WindowFrame(FrameUnitsType.ROWS,
                FrameBoundary.newPrecedingBoundary(), FrameBoundary.newPrecedingBoundary(offset));
        windowExpression = windowExpression.withWindowFrame(requiredFrame);

        // check if the class of lag's column matches defaultValue, and cast it
        if (!TypeCoercionUtils.implicitCast(column.getDataType(), defaultValue.getDataType()).isPresent()) {
            throw new AnalysisException("DefaultValue's Datatype of LAG() cannot match its relevant column. The column "
                    + "type is " + column.getDataType() + ", but the defaultValue type is " + defaultValue.getDataType());
        }
        return lag.withChildren(ImmutableList.of(column, offset,
                TypeCoercionUtils.castIfNotMatchType(defaultValue, column.getDataType())));
    }

    /**
     * required WindowFrame: (UNBOUNDED PRECEDING, offset FOLLOWING)
     * but in Spark, it is (offset FOLLOWING, offset FOLLOWING)
     */
    @Override
    public Lead visitLead(Lead lead, Void ctx) {
        windowExpression.getWindowFrame().ifPresent(wf -> {
            throw new AnalysisException("WindowFrame for LEAD() must be null");
        });
        if (lead.children().size() != 3) {
            throw new AnalysisException("Lead must have three parameters");
        }

        Expression column = lead.child(0);
        Expression offset = lead.getOffset();
        Expression defaultValue = lead.getDefaultValue();
        WindowFrame requiredFrame = new WindowFrame(FrameUnitsType.ROWS,
                FrameBoundary.newPrecedingBoundary(), FrameBoundary.newFollowingBoundary(offset));
        windowExpression = windowExpression.withWindowFrame(requiredFrame);

        // check if the class of lead's column matches defaultValue, and cast it
        if (!TypeCoercionUtils.implicitCast(column.getDataType(), defaultValue.getDataType()).isPresent()) {
            throw new AnalysisException("DefaultValue's Datatype of LEAD() can't match its relevant column. The column "
                    + "type is " + column.getDataType() + ", but the defaultValue type is " + defaultValue.getDataType());
        }
        return lead.withChildren(ImmutableList.of(column, offset,
                TypeCoercionUtils.castIfNotMatchType(defaultValue, column.getDataType())));
    }

    /**
     * [Copied from class AnalyticExpr.standardize()]:
     *
     * FIRST_VALUE without UNBOUNDED PRECEDING gets rewritten to use a different window
     * and change the function to return the last value. We either set the fn to be
     * 'last_value' or 'first_value_rewrite', which simply wraps the 'last_value'
     * implementation but allows us to handle the first rows in a partition in a special
     * way in the backend. There are a few cases:
     *  a) Start bound is X FOLLOWING or CURRENT ROW (X=0):
     *     Use 'last_value' with a window where both bounds are X FOLLOWING (or
     *     CURRENT ROW). Setting the start bound to X following is necessary because the
     *     X rows at the end of a partition have no rows in their window. Note that X
     *     FOLLOWING could be rewritten as lead(X) but that would not work for CURRENT
     *     ROW.
     *  b) Start bound is X PRECEDING and end bound is CURRENT ROW or FOLLOWING:
     *     Use 'first_value_rewrite' and a window with an end bound X PRECEDING. An
     *     extra parameter '-1' is added to indicate to the backend that NULLs should
     *     not be added for the first X rows.
     *  c) Start bound is X PRECEDING and end bound is Y PRECEDING:
     *     Use 'first_value_rewrite' and a window with an end bound X PRECEDING. The
     *     first Y rows in a partition have empty windows and should be NULL. An extra
     *     parameter with the integer constant Y is added to indicate to the backend
     *     that NULLs should be added for the first Y rows.
     *
     * Note: this method only performs rewrite (a) (as last_value() over a ROWS frame whose right bound
     * equals its left bound). It does not produce 'first_value_rewrite' for cases (b)/(c). In addition:
     * - if a second argument equal to TRUE is present, the function is returned unchanged; a second argument
     *   equal to FALSE is dropped (presumably an ignore-nulls flag — confirm in FirstValue);
     * - with an UNBOUNDED PRECEDING left bound and a non-PRECEDING right bound, the right bound is set to
     *   CURRENT ROW;
     * - without a frame, (ROWS, UNBOUNDED PRECEDING, CURRENT ROW) is installed.
     */
    @Override
    public FirstOrLastValue visitFirstValue(FirstValue firstValue, Void ctx) {
        FirstOrLastValue.checkSecondParameter(firstValue);
        if (2 == firstValue.arity()) {
            if (firstValue.child(1).equals(BooleanLiteral.TRUE)) {
                return firstValue;
            } else {
                firstValue = (FirstValue) firstValue.withChildren(firstValue.child(0));
                windowExpression = windowExpression.withFunction(firstValue);
            }
        }
        Optional<WindowFrame> windowFrame = windowExpression.getWindowFrame();
        if (windowFrame.isPresent()) {
            WindowFrame wf = windowFrame.get();
            if (wf.getLeftBoundary().isNot(FrameBoundType.UNBOUNDED_PRECEDING)
                    && wf.getLeftBoundary().isNot(FrameBoundType.PRECEDING)) {
                // case (a): left bound is CURRENT ROW or X FOLLOWING
                windowExpression = windowExpression.withWindowFrame(
                        wf.withFrameUnits(FrameUnitsType.ROWS).withRightBoundary(wf.getLeftBoundary()));
                LastValue lastValue = new LastValue(firstValue.children());
                windowExpression = windowExpression.withFunction(lastValue);
                return lastValue;
            }

            if (wf.getLeftBoundary().is(FrameBoundType.UNBOUNDED_PRECEDING)
                    && wf.getRightBoundary().isNot(FrameBoundType.PRECEDING)) {
                windowExpression = windowExpression.withWindowFrame(
                        wf.withRightBoundary(FrameBoundary.newCurrentRowBoundary()));
            }
        } else {
            windowExpression = windowExpression.withWindowFrame(new WindowFrame(FrameUnitsType.ROWS,
                    FrameBoundary.newPrecedingBoundary(), FrameBoundary.newCurrentRowBoundary()));
        }
        return firstValue;
    }

    /**
     * Validates the optional second argument of last_value(); a second argument equal to FALSE is dropped
     * and the function inside the window expression is replaced accordingly. The window frame is not changed.
     */
    @Override
    public FirstOrLastValue visitLastValue(LastValue lastValue, Void ctx) {
        FirstOrLastValue.checkSecondParameter(lastValue);
        if (2 == lastValue.arity() && lastValue.child(1).equals(BooleanLiteral.FALSE)) {
            lastValue = (LastValue) lastValue.withChildren(lastValue.child(0));
            windowExpression = windowExpression.withFunction(lastValue);
        }
        return lastValue;
    }

    /**
     * required WindowFrame: (RANGE, UNBOUNDED PRECEDING, CURRENT ROW)
     */
    @Override
    public Rank visitRank(Rank rank, Void ctx) {
        WindowFrame requiredFrame = new WindowFrame(FrameUnitsType.RANGE,
                FrameBoundary.newPrecedingBoundary(), FrameBoundary.newCurrentRowBoundary());

        checkAndCompleteWindowFrame(requiredFrame, rank.getName());
        return rank;
    }

    /**
     * required WindowFrame: (RANGE, UNBOUNDED PRECEDING, CURRENT ROW)
     */
    @Override
    public DenseRank visitDenseRank(DenseRank denseRank, Void ctx) {
        WindowFrame requiredFrame = new WindowFrame(FrameUnitsType.RANGE,
                FrameBoundary.newPrecedingBoundary(), FrameBoundary.newCurrentRowBoundary());

        checkAndCompleteWindowFrame(requiredFrame, denseRank.getName());
        return denseRank;
    }

    /**
     * required WindowFrame: (RANGE, UNBOUNDED PRECEDING, CURRENT ROW)
     */
    @Override
    public PercentRank visitPercentRank(PercentRank percentRank, Void ctx) {
        WindowFrame requiredFrame = new WindowFrame(FrameUnitsType.RANGE,
                FrameBoundary.newPrecedingBoundary(), FrameBoundary.newCurrentRowBoundary());

        checkAndCompleteWindowFrame(requiredFrame, percentRank.getName());
        return percentRank;
    }

    /**
     * required WindowFrame: (ROWS, UNBOUNDED PRECEDING, CURRENT ROW)
     */
    @Override
    public RowNumber visitRowNumber(RowNumber rowNumber, Void ctx) {
        // check and complete window frame
        WindowFrame requiredFrame = new WindowFrame(FrameUnitsType.ROWS,
                FrameBoundary.newPrecedingBoundary(), FrameBoundary.newCurrentRowBoundary());

        checkAndCompleteWindowFrame(requiredFrame, rowNumber.getName());
        return rowNumber;
    }

    /**
     * required WindowFrame: (RANGE, UNBOUNDED PRECEDING, CURRENT ROW)
     */
    @Override
    public CumeDist visitCumeDist(CumeDist cumeDist, Void ctx) {
        WindowFrame requiredFrame = new WindowFrame(FrameUnitsType.RANGE,
                FrameBoundary.newPrecedingBoundary(), FrameBoundary.newCurrentRowBoundary());

        checkAndCompleteWindowFrame(requiredFrame, cumeDist.getName());
        return cumeDist;
    }

    /**
     * required WindowFrame: (ROWS, UNBOUNDED PRECEDING, CURRENT ROW)
     */
    @Override
    public Ntile visitNtile(Ntile ntile, Void ctx) {
        WindowFrame requiredFrame = new WindowFrame(FrameUnitsType.ROWS,
                FrameBoundary.newPrecedingBoundary(), FrameBoundary.newCurrentRowBoundary());

        checkAndCompleteWindowFrame(requiredFrame, ntile.getName());
        return ntile;
    }

    /**
     * Validates the second argument of nth_value() via NthValue.checkSecondParameter; the window frame
     * is left unchanged.
     */
    @Override
    public NthValue visitNthValue(NthValue nthValue, Void ctx) {
        NthValue.checkSecondParameter(nthValue);
        return nthValue;
    }

    /**
     * check if the current WindowFrame equals the required WindowFrame; if the current WindowFrame is absent,
     * the requiredFrame is used as the default frame.
     *
     * @throws AnalysisException if a frame is present and differs from requiredFrame
     */
    private void checkAndCompleteWindowFrame(WindowFrame requiredFrame, String functionName) {
        windowExpression.getWindowFrame().ifPresent(wf -> {
            if (!wf.equals(requiredFrame)) {
                throw new AnalysisException("WindowFrame for " + functionName + "() must be null "
                        + "or match with " + requiredFrame);
            }
        });
        windowExpression = windowExpression.withWindowFrame(requiredFrame);
    }

    /* ********************************************************************************************
     * methods for step 3
     * ******************************************************************************************** */

    /**
     * If the frame ends at UNBOUNDED FOLLOWING but does not start at UNBOUNDED PRECEDING, reverse the
     * order keys (asc and nulls-first flags), reverse the frame, and swap first_value()/last_value().
     */
    private void checkWindowFrameAfterFunc(WindowFrame wf) {
        if (wf.getRightBoundary().is(FrameBoundType.UNBOUNDED_FOLLOWING)
                && wf.getLeftBoundary().isNot(FrameBoundType.UNBOUNDED_PRECEDING)) {
            // reverse OrderKey's asc and isNullFirst;
            // in checkWindowFrameBeforeFunc(), we have confirmed that orderKeyLists must exist
            List<OrderExpression> newOKList = windowExpression.getOrderKeys().stream()
                    .map(orderExpression -> {
                        OrderKey orderKey = orderExpression.getOrderKey();
                        return new OrderExpression(
                                new OrderKey(orderKey.getExpr(), !orderKey.isAsc(), !orderKey.isNullFirst()));
                    })
                    .collect(Collectors.toList());
            windowExpression = windowExpression.withOrderKeys(newOKList);

            // reverse WindowFrame
            // e.g. (3 preceding, unbounded following) -> (unbounded preceding, 3 following)
            windowExpression = windowExpression.withWindowFrame(wf.reverseWindow());

            // reverse WindowFunction, which is used only for first_value() and last_value()
            Expression windowFunction = windowExpression.getFunction();
            if (windowFunction instanceof FirstOrLastValue) {
                windowExpression = windowExpression.withFunction(((FirstOrLastValue) windowFunction).reverse());
            }
        }
    }

    private void setDefaultWindowFrameAfterFunc() {
        // this is equal to DEFAULT_WINDOW in class AnalyticWindow
        windowExpression = windowExpression.withWindowFrame(new WindowFrame(FrameUnitsType.RANGE,
                FrameBoundary.newPrecedingBoundary(), FrameBoundary.newCurrentRowBoundary()));
    }
}