CastExpr.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/apache/impala/blob/branch-2.9.0/fe/src/main/java/org/apache/impala/CastExpr.java
// and modified by Doris
package org.apache.doris.analysis;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Function;
import org.apache.doris.catalog.Function.NullableMode;
import org.apache.doris.catalog.FunctionSet;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarFunction;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.catalog.TypeUtils;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.FormatOptions;
import org.apache.doris.common.Pair;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.TExpr;
import org.apache.doris.thrift.TExprNode;
import org.apache.doris.thrift.TExprNodeType;
import org.apache.doris.thrift.TExprOpcode;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.gson.annotations.SerializedName;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.DataInput;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
public class CastExpr extends Expr {
// Class-wide log4j logger.
private static final Logger LOG = LogManager.getLogger(CastExpr.class);
// Target type as written for an explicit cast. The CastExpr(Type, Expr) constructor leaves it
// null; the Nereids constructor also leaves it null except for a no-op cast to a scalar type,
// where it wraps the target type. readFields() always populates it.
@SerializedName("ttd")
private TypeDef targetTypeDef;
// True if this is a "pre-analyzed" implicit cast.
@SerializedName("ii")
private boolean isImplicit;
// True if this cast does not change the type (result of Type.matchExactType on child and
// target). When set, treeToThriftHelper() serializes only the child.
private boolean noOp = false;
// Plain flag that this class only touches through setNotFold()/isNotFold().
// NOTE(review): presumably asks constant folding to leave this cast alone -- confirm with callers.
private boolean notFold = false;
// Nullable mode of the builtin cast function for every (source, target) pair of supported
// scalar types; pairs involving the NULL type are not registered.
private static final Map<Pair<Type, Type>, Function.NullableMode> TYPE_NULLABLE_MODE = Maps.newHashMap();

static {
    for (ScalarType source : Type.getSupportedTypes()) {
        if (!source.isNull()) {
            for (ScalarType target : Type.getSupportedTypes()) {
                if (!target.isNull()) {
                    // string -> non-string and non-date -> date are registered as always
                    // nullable; every other pair follows the nullability of its argument.
                    final boolean alwaysNullable = (source.isStringType() && !target.isStringType())
                            || (!source.isDateType() && target.isDateType());
                    TYPE_NULLABLE_MODE.put(Pair.of(source, target), alwaysNullable
                            ? Function.NullableMode.ALWAYS_NULLABLE
                            : Function.NullableMode.DEPEND_ON_ARGUMENT);
                }
            }
        }
    }
}
// Only used when restoring an instance through read(DataInput) / readFields(DataInput).
private CastExpr() {
}
/**
 * Builds an implicit cast of {@code e} to {@code targetType} and analyzes it right away.
 *
 * @param targetType type to cast to; must report {@code isValid()}
 * @param e expression being cast; must not be null
 * @throws IllegalArgumentException if {@code targetType} is not valid
 * @throws NullPointerException if {@code e} is null
 * @throws IllegalStateException if analysis of the implicit cast fails; the triggering
 *         {@link AnalysisException} is attached as the cause
 */
public CastExpr(Type targetType, Expr e) {
    super();
    Preconditions.checkArgument(targetType.isValid());
    Preconditions.checkNotNull(e);
    type = targetType;
    targetTypeDef = null;
    isImplicit = true;
    children.add(e);
    try {
        analyze();
    } catch (AnalysisException ex) {
        LOG.warn("Implicit casts fail", ex);
        // Same exception type and message as Preconditions.checkState(false, ...) produced,
        // but the original failure is kept as the cause instead of being dropped.
        throw new IllegalStateException("Implicit casts should never throw analysis exception.", ex);
    }
    analysisDone();
}
/**
 * Implicit-cast constructor reserved for Nereids; the unused {@code Void} parameter only
 * distinguishes this overload. analyze() is not invoked here (per the original note it is
 * left to finalizeImplForNereids); the cast function is resolved inline instead.
 *
 * @param targetType type to cast to; must report {@code isValid()}
 * @param e expression being cast; must not be null
 * @param v ignored
 */
public CastExpr(Type targetType, Expr e, Void v) {
    Preconditions.checkArgument(targetType.isValid());
    Preconditions.checkNotNull(e, "cast child is null");
    opcode = TExprOpcode.CAST;
    type = targetType;
    targetTypeDef = null;
    isImplicit = true;
    children.add(e);
    noOp = Type.matchExactType(e.type, type, true);

    if (!noOp) {
        // A NULL-typed child is left without a cast function.
        if (!e.type.isNull()) {
            if (type.isScalarType()) {
                // new function
                final Type sourceType = getActualArgTypes(collectChildReturnTypes())[0];
                final Type resolvedTarget = getActualType(type);
                Function.NullableMode mode = TYPE_NULLABLE_MODE.get(Pair.of(sourceType, resolvedTarget));
                // a complex type cast to jsonb is always nullable
                if (sourceType.isComplexType() && type.isJsonbType()) {
                    mode = Function.NullableMode.ALWAYS_NULLABLE;
                }
                Preconditions.checkState(mode != null,
                        "cannot find nullable node for cast from " + sourceType + " to " + resolvedTarget);
                fn = new Function(new FunctionName(getFnName(type)), Lists.newArrayList(e.type), type,
                        false, true, mode);
            } else {
                createComplexTypeCastFunction();
            }
        }
        analysisDone();
        return;
    }

    // The cast does not change the type.
    // For decimalv2 no actual cast is performed between different precision/scale; the child
    // simply adopts the target type.
    if (type.isDecimalV2() && e.type.isDecimalV2()) {
        getChild(0).setType(type);
    }
    // The target struct type carries the field names; with the default names the struct would
    // read col1, col2, col3..., and the field name is important in BE.
    if (type.isStructType() && e.type.isStructType()) {
        getChild(0).setType(type);
    }
    if (type.isScalarType()) {
        targetTypeDef = new TypeDef(type);
    }
    analysisDone();
}
/**
 * Creates an explicit cast of {@code e} to the type described by {@code targetTypeDef}.
 * Nothing is analyzed here; the result type is assigned later in analyzeImpl().
 *
 * @param targetTypeDef declared target type; must not be null
 * @param e expression being cast; must not be null
 */
public CastExpr(TypeDef targetTypeDef, Expr e) {
    // checkNotNull hands back its argument, so validation and assignment share a line.
    this.targetTypeDef = Preconditions.checkNotNull(targetTypeDef);
    children.add(Preconditions.checkNotNull(e));
    isImplicit = false;
}
/**
 * Copy constructor used by {@link #clone()}. The {@code targetTypeDef} reference is shared
 * with {@code other}, not duplicated.
 *
 * @param other cast expression to copy
 */
protected CastExpr(CastExpr other) {
    super(other);
    targetTypeDef = other.targetTypeDef;
    isImplicit = other.isImplicit;
    noOp = other.noOp;
    // Carry the flag over as well; otherwise a clone silently reverts to notFold == false.
    notFold = other.notFold;
    nullableFromNereids = other.nullableFromNereids;
}
/**
 * Returns the cast function name for a target type: "castTo" followed by the string form of
 * its primitive type.
 */
private static String getFnName(Type targetType) {
    final PrimitiveType primitive = targetType.getPrimitiveType();
    return "castTo" + primitive.toString();
}
/**
 * @return the stored target type definition; may be null (the implicit-cast constructors
 *         start by clearing it)
 */
public TypeDef getTargetTypeDef() {
    return this.targetTypeDef;
}
/**
 * Registers one builtin cast function per (source, target) pair of trivial types, skipping
 * NULL as a source type.
 *
 * NOTE(review): NULL is not skipped as a target type, while TYPE_NULLABLE_MODE holds no
 * entries for NULL targets, so the looked-up mode would be null for such a pair -- confirm
 * whether getTrivialTypes() can contain the NULL type.
 *
 * @param functionSet function set receiving the builtins
 */
public static void initBuiltins(FunctionSet functionSet) {
    for (Type source : Type.getTrivialTypes()) {
        if (!source.isNull()) {
            for (Type target : Type.getTrivialTypes()) {
                final String castName = getFnName(target);
                final Function.NullableMode mode = TYPE_NULLABLE_MODE.get(Pair.of(source, target));
                functionSet.addBuiltinBothScalaAndVectorized(ScalarFunction.createBuiltin(
                        castName, target, mode, Lists.newArrayList(source), false,
                        null, null, null, true));
            }
        }
    }
}
/**
 * Returns a copy of this cast built with the CastExpr(CastExpr) copy constructor.
 */
@Override
public Expr clone() {
return new CastExpr(this);
}
/**
 * Renders this cast as SQL. When {@code needExternalSql} is set only the child is rendered;
 * otherwise the result is {@code CAST(<child> AS <type>)}, where the type text comes from the
 * analyzed type, or before analysis from the implicit type or the declared type definition.
 */
@Override
public String toSqlImpl() {
    final String childSql = getChild(0).toSql();
    if (needExternalSql) {
        return childSql;
    }
    final String targetSql;
    if (isAnalyzed) {
        targetSql = type.toSql();
    } else if (isImplicit) {
        targetSql = type.toString();
    } else {
        targetSql = targetTypeDef.toSql();
    }
    return "CAST(" + childSql + " AS " + targetSql + ")";
}
/**
 * Renders the digest form of this cast. An implicit cast is transparent (only the child's
 * digest is returned) unless the current statement is a verbose EXPLAIN; otherwise the result
 * is {@code CAST(<child digest> AS <type>)}.
 */
@Override
public String toDigestImpl() {
    final ConnectContext ctx = ConnectContext.get();
    final boolean isVerbose = ctx != null
            && ctx.getExecutor() != null
            && ctx.getExecutor().getParsedStmt() != null
            && ctx.getExecutor().getParsedStmt().getExplainOptions() != null
            && ctx.getExecutor().getParsedStmt().getExplainOptions().isVerbose();
    final String childDigest = getChild(0).toDigest();
    if (isImplicit && !isVerbose) {
        return childDigest;
    }
    // targetTypeDef is null for implicit casts, so an unanalyzed implicit cast falls back to
    // the result type, as toSqlImpl() already does, instead of dereferencing null.
    final String target = (isAnalyzed || targetTypeDef == null)
            ? type.toString()
            : targetTypeDef.toString();
    return "CAST(" + childDigest + " AS " + target + ")";
}
/**
 * Serializes this subtree into {@code container}. A no-op cast is skipped: only its child is
 * serialized in its place.
 */
@Override
protected void treeToThriftHelper(TExpr container) {
    if (!noOp) {
        super.treeToThriftHelper(container);
    } else {
        getChild(0).treeToThriftHelper(container);
    }
}
/**
 * Fills the thrift node for this cast: node type CAST_EXPR, the opcode, and the child's
 * primitive type when both the target and the child type are native types.
 */
@Override
protected void toThrift(TExprNode msg) {
    msg.node_type = TExprNodeType.CAST_EXPR;
    msg.setOpcode(opcode);
    if (type.isNativeType()) {
        final Type sourceType = getChild(0).getType();
        if (sourceType.isNativeType()) {
            msg.setChildType(sourceType.getPrimitiveType().toThrift());
        }
    }
}
/**
 * @return the implicit-cast flag
 */
public boolean isImplicit() {
    return this.isImplicit;
}
/**
 * Overrides the implicit-cast flag.
 *
 * @param implicit new value of the flag
 */
public void setImplicit(boolean implicit) {
    this.isImplicit = implicit;
}
/**
 * Sets {@code fn} to the always-nullable builtin cast function for an ARRAY, MAP or STRUCT
 * target. Array and map casts take the child's actual argument type; the struct cast is
 * declared with a VARCHAR argument. Any other target type leaves {@code fn} untouched.
 */
private void createComplexTypeCastFunction() {
    final Type family;
    final Type argType;
    final String symbol;
    if (type.isArrayType()) {
        family = Type.ARRAY;
        argType = getActualArgTypes(collectChildReturnTypes())[0];
        symbol = "doris::CastFunctions::cast_to_array_val";
    } else if (type.isMapType()) {
        family = Type.MAP;
        argType = getActualArgTypes(collectChildReturnTypes())[0];
        symbol = "doris::CastFunctions::cast_to_map_val";
    } else if (type.isStructType()) {
        family = Type.STRUCT;
        argType = Type.VARCHAR;
        symbol = "doris::CastFunctions::cast_to_struct_val";
    } else {
        return;
    }
    fn = ScalarFunction.createBuiltin(getFnName(family),
            type, Function.NullableMode.ALWAYS_NULLABLE,
            Lists.newArrayList(argType), false,
            symbol, null, null, true);
}
/**
 * Resolves the cast function for this expression and validates the cast.
 *
 * Nothing happens for a cast to Type.ALL. A cast whose child type already matches the target
 * exactly is flagged as a no-op and no function is resolved. Otherwise the opcode is set and
 * {@code fn} is looked up (scalar targets) or created (complex targets).
 *
 * @throws AnalysisException if the cast is not allowed
 */
public void analyze() throws AnalysisException {
    // do not analyze ALL cast
    if (type == Type.ALL) {
        return;
    }

    final Type sourceType = getChild(0).getType();
    // this cast may result in loss of precision, but the user requested it
    noOp = Type.matchExactType(sourceType, type, true);
    if (noOp) {
        // For decimalv2 no actual cast is performed between different precision/scale; the
        // child simply adopts the target type.
        if (type.isDecimalV2() && sourceType.isDecimalV2()) {
            getChild(0).setType(type);
        }
        return;
    }

    // A select stmt makes BE coredump when its castExpr is like cast(int as array<>), so
    // castability has to be checked before creating fn. char type fails in canCastTo, so for
    // compatibility only casts involving an array type are checked here.
    if ((type.isArrayType() || sourceType.isArrayType()) && !Type.canCastTo(sourceType, type)) {
        throw invalidCast(sourceType);
    }

    this.opcode = TExprOpcode.CAST;
    final Function lookupKey = new Function(new FunctionName(getFnName(type)),
            Arrays.asList(getActualArgTypes(collectChildReturnTypes())), Type.INVALID, false);
    if (type.isScalarType()) {
        fn = Env.getCurrentEnv().getFunction(lookupKey, Function.CompareMode.IS_IDENTICAL);
    } else {
        createComplexTypeCastFunction();
    }

    if (fn == null) {
        // Casts that are accepted even though no function was found:
        //   string -> map/struct (TODO(xy): check map type),
        //   complex -> jsonb (same as Type.canCastTo()),
        //   a castable NULL-typed child.
        final boolean accepted = ((type.isMapType() || type.isStructType()) && sourceType.isStringType())
                || (sourceType.isComplexType() && type.isJsonbType())
                || (sourceType.isNull() && Type.canCastTo(sourceType, type));
        if (accepted) {
            return;
        }
        throw invalidCast(sourceType);
    }

    final Type returnType = fn.getReturnType();
    if (PrimitiveType.typeWithPrecision.contains(type.getPrimitiveType())) {
        // NOTE(review): with '||' this passes as soon as either equality holds -- confirm
        // whether '&&' was intended.
        Preconditions.checkState(type.isDecimalV3() == returnType.isDecimalV3()
                        || type.isDatetimeV2() == returnType.isDatetimeV2(),
                type + " != " + returnType);
    } else {
        Preconditions.checkState(type.matchesType(returnType), type + " != " + returnType);
    }
}

// Builds the exception reported for a cast from sourceType to the target type that is not allowed.
private AnalysisException invalidCast(Type sourceType) {
    return new AnalysisException("Invalid type cast of " + getChild(0).toSql()
            + " from " + sourceType + " to " + type);
}
/**
 * Analyzes an explicit cast: analyzes the declared target type, adopts it as the result type
 * and then runs analyze(). Implicit casts return immediately.
 *
 * @param analyzer analyzer handed to the target type definition
 * @throws AnalysisException if the target type or the cast itself is invalid
 */
@Override
public void analyzeImpl(Analyzer analyzer) throws AnalysisException {
    if (isImplicit) {
        return;
    }
    // When the cast target is a string type whose length is the default -1, the result length
    // is decided by the child, so the type definition is not analyzed in that case.
    final Type declared = targetTypeDef.getType();
    boolean stringWithoutLength = false;
    if (declared.isScalarType()) {
        final ScalarType scalar = (ScalarType) declared;
        stringWithoutLength = scalar.getPrimitiveType().isStringType() && !scalar.isLengthSet();
    }
    if (!stringWithoutLength) {
        targetTypeDef.analyze(analyzer);
    }
    type = targetTypeDef.getType();
    analyze();
}
/**
 * Uses the superclass hash unchanged; equals() in this class additionally compares the opcode.
 */
@Override
public int hashCode() {
return super.hashCode();
}
/**
 * Two casts are equal when the superclass considers them equal and their opcodes match.
 */
@Override
public boolean equals(Object obj) {
    // The downcast is only reached after super.equals() has accepted obj.
    return super.equals(obj) && this.opcode == ((CastExpr) obj).opcode;
}
/**
 * Strips one implicit cast.
 *
 * @return the child if this is an implicit cast, otherwise this expression itself
 * @throws IllegalStateException if the child of an implicit cast is itself an implicit cast
 */
@Override
public Expr ignoreImplicitCast() {
    if (!isImplicit) {
        return this;
    }
    final Expr inner = getChild(0);
    // we don't expect to see two consecutive implicit casts
    Preconditions.checkState(!(inner instanceof CastExpr && ((CastExpr) inner).isImplicit()));
    return inner;
}
/**
 * @return true when the target type and the child type are both fixed-point types, or both
 *         date types
 */
public boolean canHashPartition() {
    return (type.isFixedPointType() && getChild(0).getType().isFixedPointType())
            || (type.isDateType() && getChild(0).getType().isDateType());
}
/**
 * Tries to evaluate this cast when its child is a literal.
 *
 * @param forPushDownPredicatesToView forwarded to recursiveResetChildrenResult
 * @return this expression when the child is not a literal or the conversion raises an
 *         AnalysisException; a NullLiteral when the conversion raises a
 *         NumberFormatException; otherwise the expression produced by castTo(), with its type
 *         set to the declared target type (or the result type when none is declared)
 */
@Override
public Expr getResultValue(boolean forPushDownPredicatesToView) throws AnalysisException {
    recursiveResetChildrenResult(forPushDownPredicatesToView);
    final Expr operand = children.get(0);
    if (!(operand instanceof LiteralExpr)) {
        return this;
    }
    try {
        final Expr folded = castTo((LiteralExpr) operand);
        folded.setType(targetTypeDef != null ? targetTypeDef.getType() : type);
        return folded;
    } catch (AnalysisException ae) {
        // Reset the connection state, if there is one, and keep the cast unevaluated.
        if (ConnectContext.get() != null) {
            ConnectContext.get().getState().reset();
        }
        return this;
    } catch (NumberFormatException nfe) {
        return new NullLiteral();
    }
}
/**
 * Converts a literal to the target type of this cast.
 *
 * @param value literal child to convert
 * @return a literal of the target type, or this expression when the target type is none of
 *         the kinds handled below
 * @throws AnalysisException propagated from the literal constructors and conversions
 */
private Expr castTo(LiteralExpr value) throws AnalysisException {
    if (value instanceof NullLiteral) {
        return NullLiteral.create(targetTypeDef != null ? targetTypeDef.getType() : type);
    }
    if (type.isIntegerType()) {
        return new IntLiteral(value.getLongValue(), type);
    }
    if (type.isLargeIntType()) {
        return new LargeIntLiteral(value.getStringValue());
    }
    if (type.isDecimalV2() || type.isDecimalV3()) {
        if (targetTypeDef == null) {
            return new DecimalLiteral(value.getStringValue());
        }
        // With a declared target type the literal is built with the declared scale and then
        // checked against the declared precision and scale.
        final DecimalLiteral decimal = new DecimalLiteral(value.getStringValue(),
                ((ScalarType) targetTypeDef.getType()).getScalarScale());
        decimal.checkPrecisionAndScale(targetTypeDef.getType().getPrecision(),
                ((ScalarType) targetTypeDef.getType()).getScalarScale());
        return decimal;
    }
    if (type.isFloatingPointType()) {
        return new FloatLiteral(value.getDoubleValue(), type);
    }
    if (type.isStringType()) {
        return new StringLiteral(value.getStringValue());
    }
    if (type.isDateType()) {
        return new StringLiteral(value.getStringValue()).convertToDate(type);
    }
    if (type.isBoolean()) {
        return new BoolLiteral(value.getStringValue());
    }
    return this;
}
/**
 * Creates an empty cast and populates it from {@code input} through readFields().
 *
 * @param input stream to read from
 * @return the restored cast expression
 * @throws IOException if reading fails
 */
public static CastExpr read(DataInput input) throws IOException {
    final CastExpr restored = new CastExpr();
    restored.readFields(input);
    return restored;
}
/**
 * Restores this cast from {@code in}: the implicit flag, a scalar target type (wrapped in a
 * TypeDef), a child count, and then that many child expressions.
 *
 * @param in stream to read from
 * @throws IOException if reading fails
 */
@Override
public void readFields(DataInput in) throws IOException {
    isImplicit = in.readBoolean();
    targetTypeDef = new TypeDef(TypeUtils.readScalaType(in));
    for (int remaining = in.readInt(); remaining > 0; remaining--) {
        children.add(Expr.readIn(in));
    }
}
/**
 * Substitutes named parameters in this cast and returns the rewritten cast.
 *
 * A SlotRef child whose column name appears in {@code parameters} is replaced by the matching
 * entry of {@code inputParamsExprs}; a CastExpr child is rewritten recursively. For decimal and
 * char/varchar targets the precision/scale/length are resolved through getDigital() when the
 * declared type does not carry concrete values.
 *
 * Assumes an explicit cast, i.e. {@code targetTypeDef} is non-null.
 *
 * @param parameters parameter names
 * @param inputParamsExprs expressions bound to the parameters, index-aligned with {@code parameters}
 * @return a new explicit cast when both the child and the target type could be resolved,
 *         otherwise this expression unchanged
 * @throws AnalysisException propagated from rewriting a nested cast
 */
public CastExpr rewriteExpr(List<String> parameters, List<Expr> inputParamsExprs) throws AnalysisException {
    // child
    Expr child = this.getChild(0);
    Expr newChild = null;
    if (child instanceof SlotRef) {
        int index = parameters.indexOf(((SlotRef) child).getColumnName());
        if (index != -1) {
            newChild = inputParamsExprs.get(index);
        }
    }
    // rewrite cast expr in children
    if (child instanceof CastExpr) {
        newChild = ((CastExpr) child).rewriteExpr(parameters, inputParamsExprs);
    }

    // type def
    Type newTargetType = resolveTargetType(parameters, inputParamsExprs);
    if (newTargetType != null && newChild != null) {
        return new CastExpr(new TypeDef(newTargetType), newChild);
    }
    return this;
}

// Resolves the declared target type against the bound parameters; returns null when a
// parameterized decimal precision cannot be resolved.
private Type resolveTargetType(List<String> parameters, List<Expr> inputParamsExprs) {
    Type declaredType = targetTypeDef.getType();
    if (!declaredType.isScalarType()) {
        // ARRAY/MAP/STRUCT targets are kept as declared; casting them to ScalarType
        // unconditionally used to fail with a ClassCastException.
        return declaredType;
    }
    ScalarType targetType = (ScalarType) declaredType;
    PrimitiveType primitiveType = targetType.getPrimitiveType();
    switch (primitiveType) {
        case DECIMALV2:
        case DECIMAL32:
        case DECIMAL64:
        case DECIMAL128:
        case DECIMAL256: {
            // normal decimal
            if (targetType.getPrecision() != 0) {
                return targetType;
            }
            int precision = getDigital(targetType.getScalarPrecisionStr(), parameters, inputParamsExprs);
            int scale = getDigital(targetType.getScalarScaleStr(), parameters, inputParamsExprs);
            if (precision == -1) {
                return null;
            }
            if (scale != -1) {
                return ScalarType.createType(primitiveType, 0, precision, scale);
            }
            return ScalarType.createType(primitiveType, 0, precision, ScalarType.DEFAULT_SCALE);
        }
        case CHAR:
        case VARCHAR: {
            // normal char/varchar
            if (targetType.getLength() != -1) {
                return targetType;
            }
            int len = getDigital(targetType.getLenStr(), parameters, inputParamsExprs);
            if (len != -1) {
                return ScalarType.createType(primitiveType, len, 0, 0);
            }
            // default char/varchar, whose length stays -1
            return targetType;
        }
        default:
            return targetType;
    }
}
/**
 * Looks up the integer value bound to the parameter named {@code desc}.
 *
 * @param desc parameter name to look for
 * @param parameters parameter names
 * @param inputParamsExprs expressions bound to the parameters, index-aligned with {@code parameters}
 * @return the value narrowed to int, or -1 when {@code desc} is not a parameter or its bound
 *         expression is not an integer-typed IntLiteral
 */
private int getDigital(String desc, List<String> parameters, List<Expr> inputParamsExprs) {
    int index = parameters.indexOf(desc);
    if (index == -1) {
        return -1;
    }
    Expr expr = inputParamsExprs.get(index);
    // An integer-typed expression is not necessarily a literal; without the instanceof guard
    // the cast below throws ClassCastException for such an expression.
    if (expr instanceof IntLiteral && expr.getType().isIntegerType()) {
        return ((Long) ((IntLiteral) expr).getRealValue()).intValue();
    }
    return -1;
}
/**
 * @return true when the child is nullable, or when the cast is string -> non-string or
 *         non-date -> date
 */
@Override
public boolean isNullable() {
    final Expr operand = children.get(0);
    if (operand.isNullable()) {
        return true;
    }
    final Type sourceType = operand.getType();
    if (sourceType.isStringType() && !getType().isStringType()) {
        return true;
    }
    return !sourceType.isDateType() && getType().isDateType();
}
/**
 * Delegates to the child: the cast adds nothing to the stream-load string value.
 */
@Override
public String getStringValueForStreamLoad(FormatOptions options) {
    final Expr operand = children.get(0);
    return operand.getStringValueForStreamLoad(options);
}
/**
 * Delegates to the child: the cast adds nothing to the complex-type query string value.
 */
@Override
protected String getStringValueInComplexTypeForQuery(FormatOptions options) {
    final Expr operand = children.get(0);
    return operand.getStringValueInComplexTypeForQuery(options);
}
/**
 * Sets the notFold flag.
 *
 * @param notFold new value of the flag
 */
public void setNotFold(boolean notFold) {
this.notFold = notFold;
}
/**
 * @return the notFold flag
 */
public boolean isNotFold() {
    return notFold;
}
/**
 * Intentionally a no-op for casts: the requested type is ignored.
 */
@Override
protected void compactForLiteral(Type type) {
// do nothing
}
}