JsonFunctionRewrite.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonArray;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonArrayIgnoreNull;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonInsert;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonObject;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonReplace;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonSet;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonUnQuote;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbExtract;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbExtractBigint;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbExtractBool;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbExtractDouble;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbExtractInt;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbExtractLargeint;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonbExtractString;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToJson;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.JsonType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.StringType;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
/**
* JsonFunctionRewrite
* `JsonArray(col1, col2, col3)` => `JsonArray(ToJson(col1), ToJson(col2), ToJson(col3))`
*/
public class JsonFunctionRewrite implements ExpressionPatternRuleFactory {
public static JsonFunctionRewrite INSTANCE = new JsonFunctionRewrite();
@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesType(JsonArray.class).then(JsonFunctionRewrite::rewriteJsonArrayArguments)
.toRule(ExpressionRuleType.JSON_FUNCTION_REWRITE_JSON_ARRAY),
matchesType(JsonArrayIgnoreNull.class).then(JsonFunctionRewrite::rewriteJsonArrayArguments)
.toRule(ExpressionRuleType.JSON_FUNCTION_REWRITE_JSON_ARRAY_IGNORE_NULL),
matchesType(JsonObject.class).then(JsonFunctionRewrite::rewriteJsonObjectArguments)
.toRule(ExpressionRuleType.JSON_FUNCTION_REWRITE_JSON_OBJECT),
matchesType(JsonInsert.class).then(JsonFunctionRewrite::rewriteJsonModifyArguments)
.toRule(ExpressionRuleType.JSON_FUNCTION_REWRITE_JSON_INSERT),
matchesType(JsonSet.class).then(JsonFunctionRewrite::rewriteJsonModifyArguments)
.toRule(ExpressionRuleType.JSON_FUNCTION_REWRITE_JSON_SET),
matchesType(JsonReplace.class).then(JsonFunctionRewrite::rewriteJsonModifyArguments)
.toRule(ExpressionRuleType.JSON_FUNCTION_REWRITE_JSON_REPLACE),
matchesType(JsonbExtractInt.class).then(JsonFunctionRewrite::rewriteJsonExtractFunctions)
.toRule(ExpressionRuleType.JSON_FUNCTION_REWRITE_JSON_EXTRACT_INT),
matchesType(JsonbExtractBigint.class).then(JsonFunctionRewrite::rewriteJsonExtractFunctions)
.toRule(ExpressionRuleType.JSON_FUNCTION_REWRITE_JSON_EXTRACT_BIGINT),
matchesType(JsonbExtractLargeint.class).then(JsonFunctionRewrite::rewriteJsonExtractFunctions)
.toRule(ExpressionRuleType.JSON_FUNCTION_REWRITE_JSON_EXTRACT_LARGEINT),
matchesType(JsonbExtractBool.class).then(JsonFunctionRewrite::rewriteJsonExtractFunctions)
.toRule(ExpressionRuleType.JSON_FUNCTION_REWRITE_JSON_EXTRACT_BOOLEAN),
matchesType(JsonbExtractDouble.class).then(JsonFunctionRewrite::rewriteJsonExtractFunctions)
.toRule(ExpressionRuleType.JSON_FUNCTION_REWRITE_JSON_EXTRACT_DOUBLE),
matchesType(JsonbExtractString.class).then(JsonFunctionRewrite::rewriteJsonExtractFunctions)
.toRule(ExpressionRuleType.JSON_FUNCTION_REWRITE_JSON_EXTRACT_STRING)
);
}
private static <T extends ScalarFunction> Expression rewriteJsonArrayArguments(T function) {
List<Expression> convectedChildren = new ArrayList<Expression>();
function.children().forEach(child -> {
if (child.getDataType() instanceof JsonType) {
convectedChildren.add(child);
} else {
convectedChildren.add(new ToJson(child));
}
});
return function.withChildren(convectedChildren);
}
private static Expression rewriteJsonObjectArguments(JsonObject function) {
List<Expression> convectedChildren = new ArrayList<Expression>();
List<Expression> children = function.children();
for (int i = 0; i < children.size(); i++) {
Expression child = children.get(i);
if (i % 2 == 0) {
convectedChildren.add(child);
} else if (child.getDataType() instanceof JsonType) {
convectedChildren.add(child);
} else {
convectedChildren.add(new ToJson(child));
}
}
return function.withChildren(convectedChildren);
}
private static <T extends ScalarFunction> Expression rewriteJsonModifyArguments(T function) {
List<Expression> convectedChildren = new ArrayList<Expression>();
List<Expression> children = function.children();
convectedChildren.add(children.get(0));
for (int i = 1; i < children.size(); i++) {
Expression child = children.get(i);
if (i % 2 == 1) {
convectedChildren.add(child);
} else if (child.getDataType() instanceof JsonType) {
convectedChildren.add(child);
} else {
convectedChildren.add(new ToJson(child));
}
}
return function.withChildren(convectedChildren);
}
private static <T extends ScalarFunction> Expression rewriteJsonExtractFunctions(T function) {
JsonbExtract jsonExtract = new JsonbExtract(function.children().get(0), function.children().get(1));
if (function instanceof JsonbExtractInt) {
return new Cast(jsonExtract, IntegerType.INSTANCE, false);
} else if (function instanceof JsonbExtractBigint) {
return new Cast(jsonExtract, BigIntType.INSTANCE, false);
} else if (function instanceof JsonbExtractLargeint) {
return new Cast(jsonExtract, LargeIntType.INSTANCE, false);
} else if (function instanceof JsonbExtractBool) {
return new Cast(jsonExtract, BooleanType.INSTANCE, false);
} else if (function instanceof JsonbExtractDouble) {
return new Cast(jsonExtract, DoubleType.INSTANCE, false);
} else if (function instanceof JsonbExtractString) {
return new JsonUnQuote(new Cast(jsonExtract, StringType.INSTANCE, false));
} else {
return function;
}
}
}