// CaseWhen.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions;

import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;

import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;

/**
 * The internal representation of
 * CASE [expr] WHEN expr THEN expr [WHEN expr THEN expr ...] [ELSE expr] END
 *
 * <p>Children layout: every WHEN/THEN pair is held as one {@link WhenClause} child, in the order
 * given to the constructor. If an else expr is given then it is the last child.
 *
 * <p>This node keeps no separate case operand. NOTE(review): per the original design note, when a
 * case expr is given it is expected to be converted to equalTo(caseExpr, whenExpr) and used as the
 * whenExpr before this node is built; that conversion is not performed in this class.
 */
public class CaseWhen extends Expression {

    /** WHEN/THEN branches, as an immutable copy of the list passed to the constructor. */
    private final List<WhenClause> whenClauses;
    /** The ELSE expression, or empty when the CASE has no ELSE branch. */
    private final Optional<Expression> defaultValue;
    /** Memoizing supplier of the branch data types; evaluated on first use, not at construction. */
    private final Supplier<List<DataType>> dataTypesForCoercion;

    /**
     * Creates a CASE expression without an ELSE branch.
     *
     * @param whenClauses the WHEN/THEN branches; must not be null
     */
    public CaseWhen(List<WhenClause> whenClauses) {
        // Type-safe widening of List<WhenClause> to List<Expression> (no raw-type cast needed).
        super(ImmutableList.<Expression>copyOf(whenClauses));
        this.whenClauses = ImmutableList.copyOf(Objects.requireNonNull(whenClauses));
        this.defaultValue = Optional.empty();
        this.dataTypesForCoercion = computeDataTypesForCoercion();
    }

    /**
     * Creates a CASE expression with an ELSE branch.
     *
     * @param whenClauses the WHEN/THEN branches; must not be null
     * @param defaultValue the ELSE expression, stored as the last child; must not be null
     */
    public CaseWhen(List<WhenClause> whenClauses, Expression defaultValue) {
        super(ImmutableList.<Expression>builderWithExpectedSize(whenClauses.size() + 1)
                .addAll(whenClauses)
                .add(defaultValue)
                .build());
        this.whenClauses = ImmutableList.copyOf(Objects.requireNonNull(whenClauses));
        this.defaultValue = Optional.of(Objects.requireNonNull(defaultValue));
        this.dataTypesForCoercion = computeDataTypesForCoercion();
    }

    /** Returns the immutable list of WHEN/THEN branches. */
    public List<WhenClause> getWhenClauses() {
        return whenClauses;
    }

    /** Returns the ELSE expression, or {@link Optional#empty()} if there is none. */
    public Optional<Expression> getDefaultValue() {
        return defaultValue;
    }

    /**
     * Returns the data type of every WHEN clause, in order, followed by the data type of the
     * ELSE expression when one is present. The list is computed on the first call and cached.
     */
    public List<DataType> dataTypesForCoercion() {
        return this.dataTypesForCoercion.get();
    }

    /** Dispatches to {@code visitor.visitCaseWhen(this, context)}. */
    @Override
    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
        return visitor.visitCaseWhen(this, context);
    }

    /**
     * Returns the data type reported by the first child.
     * NOTE(review): this presumably relies on all branches having been coerced to a common
     * type beforehand; confirm against the type-coercion code.
     */
    @Override
    public DataType getDataType() {
        return child(0).getDataType();
    }

    /**
     * Returns true if any WHEN clause reports itself nullable. Otherwise returns the nullability
     * of the ELSE expression, or true when there is no ELSE branch.
     */
    @Override
    public boolean nullable() {
        for (WhenClause whenClause : whenClauses) {
            if (whenClause.nullable()) {
                return true;
            }
        }
        // No ELSE branch: treated as nullable.
        return defaultValue.map(Expression::nullable).orElse(true);
    }

    /**
     * Renders "CASE", then each child's {@code toString()} (a non-WhenClause child is prefixed
     * with " ELSE "), then " END".
     */
    @Override
    public String toString() {
        StringBuilder output = new StringBuilder("CASE");
        for (Expression child : children()) {
            if (!(child instanceof WhenClause)) {
                output.append(" ELSE ");
            }
            output.append(child.toString());
        }
        return output.append(" END").toString();
    }

    /**
     * Renders the same layout as {@link #toString()} but using each child's {@code toSql()}.
     */
    @Override
    public String computeToSql() throws UnboundException {
        StringBuilder output = new StringBuilder("CASE");
        for (Expression child : children()) {
            if (!(child instanceof WhenClause)) {
                output.append(" ELSE ");
            }
            output.append(child.toSql());
        }
        return output.append(" END").toString();
    }

    /**
     * Builds a new CaseWhen from the given children, which must be laid out as
     * [WhenClause+, DefaultValue?].
     *
     * @param children the new children; must not be empty
     * @return a new CaseWhen holding the given WHEN clauses and, if present, the ELSE expression
     * @throws IllegalArgumentException if {@code children} is empty
     * @throws AnalysisException if a child that is not a WhenClause appears anywhere other than
     *         the last position
     */
    @Override
    public CaseWhen withChildren(List<Expression> children) {
        Preconditions.checkArgument(!children.isEmpty(), "case when should has at least 1 child");
        List<WhenClause> whenClauseList = new ArrayList<>();
        Expression elseValue = null;
        int lastIndex = children.size() - 1;
        for (int i = 0; i <= lastIndex; i++) {
            Expression child = children.get(i);
            if (child instanceof WhenClause) {
                whenClauseList.add((WhenClause) child);
            } else if (i == lastIndex) {
                // Only the final child may be a non-WhenClause: it is the ELSE expression.
                // NOTE(review): a single non-WhenClause child is accepted here and yields a
                // CaseWhen with no WHEN clauses.
                elseValue = child;
            } else {
                throw new AnalysisException("The children format needs to be [WhenClause+, DefaultValue?]");
            }
        }
        if (elseValue == null) {
            return new CaseWhen(whenClauseList);
        }
        return new CaseWhen(whenClauseList, elseValue);
    }

    /**
     * Creates the memoizing supplier behind {@link #dataTypesForCoercion()}. The supplier reads
     * {@code whenClauses} and {@code defaultValue} only when first invoked, so it may be created
     * in the constructor once those two fields are assigned.
     */
    private Supplier<List<DataType>> computeDataTypesForCoercion() {
        return Suppliers.memoize(() -> {
            Builder<DataType> dataTypes = ImmutableList.builderWithExpectedSize(
                    whenClauses.size() + (defaultValue.isPresent() ? 1 : 0));
            for (WhenClause whenClause : whenClauses) {
                dataTypes.add(whenClause.getDataType());
            }
            defaultValue.ifPresent(expression -> dataTypes.add(expression.getDataType()));
            return dataTypes.build();
        });
    }
}