FlightSqlSchemaHelper.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.service.arrowflight;

import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.datasource.CatalogIf;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.service.ExecuteEnv;
import org.apache.doris.service.FrontendServiceImpl;
import org.apache.doris.thrift.TColumnDef;
import org.apache.doris.thrift.TColumnDesc;
import org.apache.doris.thrift.TDescribeTablesParams;
import org.apache.doris.thrift.TDescribeTablesResult;
import org.apache.doris.thrift.TGetDbsParams;
import org.apache.doris.thrift.TGetDbsResult;
import org.apache.doris.thrift.TGetTablesParams;
import org.apache.doris.thrift.TListTableStatusResult;
import org.apache.doris.thrift.TTableStatus;

import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils;
import org.apache.arrow.flight.sql.FlightSqlColumnMetadata;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ZeroVector;
import org.apache.arrow.vector.complex.BaseRepeatedValueVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.Text;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.thrift.TException;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class FlightSqlSchemaHelper {
    private static final Logger LOG = LogManager.getLogger(FlightSqlSchemaHelper.class);
    // Connection context; its current user identity is attached to every FrontendServiceImpl request.
    private final ConnectContext ctx;
    // Frontend service used to list databases, list tables and describe table columns.
    private final FrontendServiceImpl impl;
    // Copied from CommandGetTables; when true, getTables also fills the "table_schema" column.
    private boolean includeSchema;
    // Catalog forwarded to getDbNames; the setParameterFor* methods fall back to "internal".
    private String catalogFilterPattern = null;
    // Database (schema) name pattern forwarded to getDbNames; null means no pattern is sent.
    private String dbSchemaFilterPattern = null;
    // Table name pattern forwarded to listTableStatus; null means no pattern is sent.
    private String tableNameFilterPattern = null;
    // Requested table types; null when the command listed none. Only the first entry is used.
    private List<String> tableTypesList = null;

    /**
     * Creates a helper bound to the given connection context.
     * <p>
     * A dedicated FrontendServiceImpl is created on top of the ExecuteEnv singleton and is used for
     * all metadata lookups issued by this helper.
     *
     * @param context connection context whose current user identity is sent with each metadata request
     */
    public FlightSqlSchemaHelper(ConnectContext context) {
        this.ctx = context;
        this.impl = new FrontendServiceImpl(ExecuteEnv.getInstance());
    }

    // Serialized form of a schema with no fields; getSerializedSchema returns a copy of it for null input.
    // The initializer itself calls getSerializedSchema, so that method can observe this constant as null
    // while class initialization is still in progress.
    private static final byte[] EMPTY_SERIALIZED_SCHEMA = getSerializedSchema(Collections.emptyList());

    /**
     * Convert Doris data type to an arrowType.
     * <p>
     * Ref: `convert_to_arrow_type` in be/src/util/arrow/row_batch.cpp.
     * which is consistent with the type of Arrow data returned by Doris Arrow Flight Sql query.
     *
     * @param primitiveType Doris primitive type of the column
     * @param precision decimal precision; only read for DECIMAL32/64/128/256, where it must be non-null
     * @param scale decimal scale or DATETIMEV2 fractional-second digits; must be non-null for
     *         DECIMAL32/64/128/256, while a null scale on DATETIMEV2 is treated as 0
     * @param timeZone time zone id attached to Timestamp types produced for DATETIMEV2
     * @return the matching Arrow type, or {@code ArrowType.Null} for types without an explicit mapping
     */
    private static ArrowType getArrowType(PrimitiveType primitiveType, Integer precision, Integer scale,
            String timeZone) {
        switch (primitiveType) {
            case BOOLEAN:
                return new ArrowType.Bool();
            case TINYINT:
                return new ArrowType.Int(8, true);
            case SMALLINT:
                return new ArrowType.Int(16, true);
            case INT:
            case IPV4:
                return new ArrowType.Int(32, true);
            case BIGINT:
                return new ArrowType.Int(64, true);
            case FLOAT:
                return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);
            case DOUBLE:
                return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
            case LARGEINT:
            case VARCHAR:
            case STRING:
            case CHAR:
            case DATETIME:
            case DATE:
            case JSONB:
            case IPV6:
            case VARIANT:
                return new ArrowType.Utf8();
            case DATEV2:
                return new ArrowType.Date(DateUnit.MILLISECOND);
            case DATETIMEV2: {
                // The scale is optional in the column description; comparing a null Integer would
                // throw on auto-unboxing, so a missing scale is handled as 0 fractional digits.
                final int fractionalDigits = scale == null ? 0 : scale;
                if (fractionalDigits > 3) {
                    return new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZone);
                } else if (fractionalDigits > 0) {
                    return new ArrowType.Timestamp(TimeUnit.MILLISECOND, timeZone);
                } else {
                    return new ArrowType.Timestamp(TimeUnit.SECOND, timeZone);
                }
            }
            case DECIMAL32:
            case DECIMAL64:
            case DECIMAL128:
                return new ArrowType.Decimal(precision, scale, 128);
            case DECIMAL256:
                return new ArrowType.Decimal(precision, scale, 256);
            case DECIMALV2:
                return new ArrowType.Decimal(27, 9, 128);
            case HLL:
            case BITMAP:
            case QUANTILE_STATE:
                return new ArrowType.Binary();
            case MAP:
                return new ArrowType.Map(false);
            case ARRAY:
                return new ArrowType.List();
            case STRUCT:
                return new ArrowType.Struct();
            default:
                return new ArrowType.Null();
        }
    }

    /**
     * Derives the Arrow type of a column from its thrift description.
     * <p>
     * Precision and scale are forwarded only when they are set on the description; otherwise null is
     * passed on to {@code getArrowType}.
     *
     * @param desc thrift column description
     * @return the Arrow type chosen by {@code getArrowType}
     */
    private static ArrowType columnDescToArrowType(final TColumnDesc desc) {
        Integer precision = null;
        if (desc.isSetColumnPrecision()) {
            precision = desc.getColumnPrecision();
        }
        Integer scale = null;
        if (desc.isSetColumnScale()) {
            scale = desc.getColumnScale();
        }
        // TODO TColumnDesc carries no time zone, so the zone of the calendar returned by
        // JdbcToArrowUtils.getUtcCalendar() is used (UTC, going by the helper's name).
        final String zoneId = JdbcToArrowUtils.getUtcCalendar().getTimeZone().getID();
        return getArrowType(PrimitiveType.fromThrift(desc.getColumnType()), precision, scale, zoneId);
    }

    /**
     * Builds the Flight SQL column metadata map attached to a column's Arrow field.
     * <p>
     * Every column is reported as not auto-increment, not case sensitive, read-only and searchable.
     * Precision and scale are added only when they are set on the column description.
     *
     * @param dbName name reported as the column's schema name
     * @param tableName name reported as the column's table name
     * @param desc thrift column description supplying the type name, precision and scale
     * @return the metadata map produced by the FlightSqlColumnMetadata builder
     */
    private static Map<String, String> createFlightSqlColumnMetadata(final String dbName, final String tableName,
            final TColumnDesc desc) {
        final String typeName = PrimitiveType.fromThrift(desc.getColumnType()).toString();
        final FlightSqlColumnMetadata.Builder builder = new FlightSqlColumnMetadata.Builder()
                .schemaName(dbName)
                .tableName(tableName)
                .typeName(typeName)
                .isAutoIncrement(false)
                .isCaseSensitive(false)
                .isReadOnly(true)
                .isSearchable(true);

        if (desc.isSetColumnPrecision()) {
            builder.precision(desc.getColumnPrecision());
        }
        if (desc.isSetColumnScale()) {
            builder.scale(desc.getColumnScale());
        }

        final FlightSqlColumnMetadata metadata = builder.build();
        return metadata.getMetadataMap();
    }

    /**
     * Serializes an Arrow schema made of the given fields into a byte array.
     * <p>
     * A null field list yields a fresh copy of the cached empty-schema bytes. If that cache is not yet
     * initialized, the null list is serialized as an empty schema instead.
     *
     * @param fields fields of the schema, may be null
     * @return the serialized schema message
     * @throws RuntimeException wrapping the IOException if serialization fails
     */
    protected static byte[] getSerializedSchema(List<Field> fields) {
        if (fields == null) {
            if (EMPTY_SERIALIZED_SCHEMA != null) {
                // Hand out a copy so callers cannot modify the shared cached array.
                return EMPTY_SERIALIZED_SCHEMA.clone();
            }
            fields = Collections.emptyList();
        }

        final Schema schema = new Schema(fields);
        final ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        final WriteChannel out = new WriteChannel(Channels.newChannel(buffer));
        try {
            MessageSerializer.serialize(out, schema);
        } catch (final IOException e) {
            throw new RuntimeException("IO Error when serializing schema '" + schema + "'.", e);
        }
        return buffer.toByteArray();
    }

    /**
     * Stores the filters of a CommandGetTables request for the following {@code getTables} call.
     * <p>
     * The catalog falls back to "internal" when the command has none. The schema and table name
     * patterns, and the table type list, are stored as null when absent or empty.
     *
     * @param command the Flight SQL GetTables command sent by the client
     */
    public void setParameterForGetTables(CommandGetTables command) {
        includeSchema = command.getIncludeSchema();

        if (command.hasCatalog()) {
            catalogFilterPattern = command.getCatalog();
        } else {
            catalogFilterPattern = "internal";
        }

        if (command.hasDbSchemaFilterPattern()) {
            dbSchemaFilterPattern = command.getDbSchemaFilterPattern();
        } else {
            dbSchemaFilterPattern = null;
        }

        if (command.hasTableNameFilterPattern()) {
            tableNameFilterPattern = command.getTableNameFilterPattern();
        } else {
            tableNameFilterPattern = null;
        }

        final List<String> requestedTypes = command.getTableTypesList();
        tableTypesList = requestedTypes.isEmpty() ? null : requestedTypes;
    }

    /**
     * Stores the filters of a CommandGetDbSchemas request for the following {@code getSchemas} call.
     * <p>
     * The catalog falls back to "internal" when the command has none; the schema pattern is stored
     * as null when absent.
     *
     * @param command the Flight SQL GetDbSchemas command sent by the client
     */
    public void setParameterForGetDbSchemas(CommandGetDbSchemas command) {
        if (command.hasCatalog()) {
            catalogFilterPattern = command.getCatalog();
        } else {
            catalogFilterPattern = "internal";
        }

        if (command.hasDbSchemaFilterPattern()) {
            dbSchemaFilterPattern = command.getDbSchemaFilterPattern();
        } else {
            dbSchemaFilterPattern = null;
        }
    }

    /**
     * Call FrontendServiceImpl->getDbNames with the stored catalog and database filters.
     * <p>
     * A filter is put on the request only when it is non-null; the current user identity of the
     * connection context is always attached.
     *
     * @return the result returned by the frontend service
     * @throws TException if the frontend service call fails
     */
    private TGetDbsResult getDbNames() throws TException {
        final TGetDbsParams request = new TGetDbsParams();
        if (null != catalogFilterPattern) {
            request.setCatalog(catalogFilterPattern);
        }
        if (null != dbSchemaFilterPattern) {
            request.setPattern(dbSchemaFilterPattern);
        }
        request.setCurrentUserIdent(ctx.getCurrentUserIdentity().toThrift());

        final TGetDbsResult response = impl.getDbNames(request);
        return response;
    }

    /**
     * Call FrontendServiceImpl->listTableStatus for one database.
     * <p>
     * The catalog is sent only when non-empty. The stored table name pattern and table type are
     * sent only when present.
     *
     * @param dbName database whose tables are listed
     * @param catalogName catalog of the database; an empty string leaves the catalog unset
     * @return the result returned by the frontend service
     * @throws TException if the frontend service call fails
     */
    private TListTableStatusResult listTableStatus(String dbName, String catalogName) throws TException {
        final TGetTablesParams request = new TGetTablesParams();
        request.setDb(dbName);
        if (!catalogName.isEmpty()) {
            request.setCatalog(catalogName);
        }
        if (null != tableNameFilterPattern) {
            request.setPattern(tableNameFilterPattern);
        }
        if (null != tableTypesList) {
            // currently only one type is supported.
            final String firstType = tableTypesList.get(0);
            request.setType(firstType);
        }
        request.setCurrentUserIdent(ctx.getCurrentUserIdentity().toThrift());

        final TListTableStatusResult response = impl.listTableStatus(request);
        return response;
    }

    /**
     * Call FrontendServiceImpl->describeTables for a set of tables in one database.
     *
     * @param dbName database containing the tables
     * @param catalogName catalog of the database; an empty string leaves the catalog unset
     * @param tablesName names of the tables to describe
     * @return the result returned by the frontend service
     * @throws TException if the frontend service call fails
     */
    private TDescribeTablesResult describeTables(String dbName, String catalogName, List<String> tablesName)
            throws TException {
        final TDescribeTablesParams request = new TDescribeTablesParams();
        request.setDb(dbName);
        if (!catalogName.isEmpty()) {
            request.setCatalog(catalogName);
        }
        request.setTablesName(tablesName);
        request.setCurrentUserIdent(ctx.getCurrentUserIdentity().toThrift());

        final TDescribeTablesResult response = impl.describeTables(request);
        return response;
    }

    /**
     * Construct <tableName, List<Field>> from a describeTables result.
     * <p>
     * The result's column list is flat. Entry {@code i} of its tables-offset list is used as the
     * exclusive end index of the columns belonging to {@code tablesName.get(i)}, and columns are
     * consumed in order from one table to the next.
     *
     * @param dbName database name recorded in each column's Flight SQL metadata
     * @param describeTablesResult describeTables result holding the columns and table offsets
     * @param tablesName table names, positionally matching the table offsets
     * @return map from table name to the Arrow fields of its columns
     */
    private Map<String, List<Field>> buildTableToFields(String dbName, TDescribeTablesResult describeTablesResult,
            List<String> tablesName) {
        final Map<String, List<Field>> fieldsByTable = new HashMap<>();
        final int tableCount = describeTablesResult.getTablesOffsetSize();
        int nextColumn = 0;
        for (int t = 0; t < tableCount; t++) {
            final String tableName = tablesName.get(t);
            final List<Field> tableFields = new ArrayList<>();
            final int endColumn = describeTablesResult.getTablesOffset().get(t);
            while (nextColumn < endColumn) {
                final TColumnDesc desc = describeTablesResult.getColumns().get(nextColumn).getColumnDesc();
                final ArrowType arrowType = columnDescToArrowType(desc);
                final List<Field> children = placeholderChildren(arrowType);
                final FieldType fieldType = new FieldType(desc.isIsAllowNull(), arrowType, null,
                        createFlightSqlColumnMetadata(dbName, tableName, desc));
                tableFields.add(new Field(desc.getColumnName(), fieldType, children));
                nextColumn++;
            }
            fieldsByTable.put(tableName, tableFields);
        }
        return fieldsByTable;
    }

    /**
     * Returns the child fields attached to a column field of the given Arrow type.
     * <p>
     * Arrow complex types may require children fields for parsing the schema on C++.
     *
     * @param arrowType Arrow type of the column
     * @return one child for list-like and map types, an empty list for structs, otherwise null
     */
    private static List<Field> placeholderChildren(ArrowType arrowType) {
        switch (arrowType.getTypeID()) {
            case List:
            case LargeList:
            case FixedSizeList:
                return Collections.singletonList(
                        Field.notNullable(BaseRepeatedValueVector.DATA_VECTOR_NAME,
                                ZeroVector.INSTANCE.getField().getType()));
            case Map:
                return Collections.singletonList(
                        Field.notNullable(MapVector.DATA_VECTOR_NAME, new ArrowType.List()));
            case Struct:
                return Collections.emptyList();
            default:
                return null;
        }
    }

    /**
     * for FlightSqlProducer Schemas.GET_CATALOGS_SCHEMA
     * <p>
     * Writes one row per distinct catalog name into the "catalog_name" vector. "internal" is always
     * the first row, followed by the catalogs listed by the catalog manager in iteration order.
     *
     * @param vectorSchemaRoot root holding the "catalog_name" vector; its row count is set on return
     */
    public void getCatalogs(VectorSchemaRoot vectorSchemaRoot) throws TException {
        final VarCharVector nameColumn = (VarCharVector) vectorSchemaRoot.getVector("catalog_name");

        // Insertion-ordered set: keeps "internal" first and drops a repeated "internal" entry.
        final Set<String> catalogNames = new LinkedHashSet<>();
        catalogNames.add("internal");
        for (CatalogIf catalog : Env.getCurrentEnv().getCatalogMgr().listCatalogs()) {
            catalogNames.add(catalog.getName());
        }

        int row = 0;
        for (String name : catalogNames) {
            nameColumn.setSafe(row++, new Text(name));
        }
        vectorSchemaRoot.setRowCount(row);
    }

    /**
     * for FlightSqlProducer Schemas.GET_SCHEMAS_SCHEMA
     * <p>
     * Writes one row per database returned by {@code getDbNames} into the "catalog_name" and
     * "db_schema_name" vectors. The catalog name is an empty string when the result has no catalogs.
     *
     * @param vectorSchemaRoot root holding the two vectors; its row count is set on return
     * @throws TException if the frontend service call fails
     */
    public void getSchemas(VectorSchemaRoot vectorSchemaRoot) throws TException {
        final VarCharVector catalogColumn = (VarCharVector) vectorSchemaRoot.getVector("catalog_name");
        final VarCharVector schemaColumn = (VarCharVector) vectorSchemaRoot.getVector("db_schema_name");

        final TGetDbsResult dbsResult = getDbNames();
        final List<String> dbNames = dbsResult.getDbs();
        int row = 0;
        while (row < dbNames.size()) {
            String catalogName = "";
            if (dbsResult.isSetCatalogs()) {
                catalogName = dbsResult.getCatalogs().get(row);
            }
            catalogColumn.setSafe(row, new Text(catalogName));
            schemaColumn.setSafe(row, new Text(dbNames.get(row)));
            row++;
        }
        vectorSchemaRoot.setRowCount(dbNames.size());
    }

    /**
     * for FlightSqlProducer Schemas.GET_TABLES_SCHEMA_NO_SCHEMA and Schemas.GET_TABLES_SCHEMA
     * <p>
     * For every database returned by {@code getDbNames}, lists its tables and writes one row per
     * table (catalog, schema, table name, table type). When {@code includeSchema} is set, the tables
     * are also described and each table's serialized Arrow schema goes into "table_schema".
     *
     * @param vectorSchemaRoot root holding the output vectors; its row count is set on return
     * @throws TException if a frontend service call fails
     */
    public void getTables(VectorSchemaRoot vectorSchemaRoot) throws TException {
        final VarCharVector catalogColumn = (VarCharVector) vectorSchemaRoot.getVector("catalog_name");
        final VarCharVector schemaColumn = (VarCharVector) vectorSchemaRoot.getVector("db_schema_name");
        final VarCharVector tableColumn = (VarCharVector) vectorSchemaRoot.getVector("table_name");
        final VarCharVector typeColumn = (VarCharVector) vectorSchemaRoot.getVector("table_type");
        final VarBinaryVector serializedSchemaColumn =
                (VarBinaryVector) vectorSchemaRoot.getVector("table_schema");

        final TGetDbsResult dbsResult = getDbNames();
        final List<String> dbNames = dbsResult.getDbs();
        int row = 0;
        for (int dbPos = 0; dbPos < dbNames.size(); dbPos++) {
            final String dbName = dbNames.get(dbPos);
            String catalogName = "";
            if (dbsResult.isSetCatalogs()) {
                catalogName = dbsResult.getCatalogs().get(dbPos);
            }
            final List<TTableStatus> tables = listTableStatus(dbName, catalogName).getTables();

            // Column information is fetched once per database, and only when the client asked for it.
            Map<String, List<Field>> fieldsByTable = null;
            if (includeSchema) {
                final List<String> tableNames = new ArrayList<>();
                for (TTableStatus table : tables) {
                    tableNames.add(table.getName());
                }
                final TDescribeTablesResult described = describeTables(dbName, catalogName, tableNames);
                fieldsByTable = buildTableToFields(dbName, described, tableNames);
            }

            for (TTableStatus table : tables) {
                final String tableName = table.getName();
                catalogColumn.setSafe(row, new Text(catalogName));
                schemaColumn.setSafe(row, new Text(dbName));
                tableColumn.setSafe(row, new Text(tableName));
                typeColumn.setSafe(row, new Text(table.getType()));
                if (includeSchema) {
                    serializedSchemaColumn.setSafe(row, getSerializedSchema(fieldsByTable.get(tableName)));
                }
                row++;
            }
        }
        vectorSchemaRoot.setRowCount(row);
    }
}