DorisFlightSqlProducer.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
// This file is copied from
// https://github.com/apache/arrow/blob/main/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java
// and modified by Doris

package org.apache.doris.service.arrowflight;

import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.common.util.Util;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.QueryState.MysqlStateType;
import org.apache.doris.service.arrowflight.results.FlightSqlEndpointsLocation;
import org.apache.doris.service.arrowflight.results.FlightSqlResultCacheEntry;
import org.apache.doris.service.arrowflight.sessions.FlightSessionsManager;
import org.apache.doris.thrift.TUniqueId;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.Message;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.CloseSessionRequest;
import org.apache.arrow.flight.CloseSessionResult;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.PutResult;
import org.apache.arrow.flight.Result;
import org.apache.arrow.flight.SchemaResult;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.sql.FlightSqlProducer;
import org.apache.arrow.flight.sql.SqlInfoBuilder;
import org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest;
import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest;
import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetPrimaryKeys;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSqlInfo;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetXdbcTypeInfo;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery;
import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate;
import org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult;
import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity;
import org.apache.arrow.flight.sql.impl.FlightSql.TicketStatementQuery;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Implementation of Arrow Flight SQL service
 * <p>
 * All methods must catch all possible Exceptions, print and throw CallStatus,
 * otherwise error message will be discarded.
 */
public class DorisFlightSqlProducer implements FlightSqlProducer, AutoCloseable {
    private static final Logger LOG = LogManager.getLogger(DorisFlightSqlProducer.class);
    private final Location location;
    private final BufferAllocator rootAllocator = new RootAllocator();
    private final SqlInfoBuilder sqlInfoBuilder;
    private final FlightSessionsManager flightSessionsManager;
    private final ExecutorService executorService = Executors.newFixedThreadPool(100);

    public DorisFlightSqlProducer(final Location location, FlightSessionsManager flightSessionsManager) {
        this.location = location;
        this.flightSessionsManager = flightSessionsManager;
        sqlInfoBuilder = new SqlInfoBuilder();
        sqlInfoBuilder.withFlightSqlServerName("DorisFE").withFlightSqlServerVersion("1.0")
                .withFlightSqlServerArrowVersion("18.2.0").withFlightSqlServerReadOnly(false)
                .withSqlIdentifierQuoteChar("`").withSqlDdlCatalog(true).withSqlDdlSchema(false).withSqlDdlTable(false)
                .withSqlIdentifierCase(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE)
                .withSqlQuotedIdentifierCase(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE);
    }

    private static ByteBuffer serializeMetadata(final Schema schema) {
        final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
        try {
            MessageSerializer.serialize(new WriteChannel(Channels.newChannel(outputStream)), schema);

            return ByteBuffer.wrap(outputStream.toByteArray());
        } catch (final IOException e) {
            throw new RuntimeException("Failed to serialize arrow flight sql schema", e);
        }
    }

    private void getStreamStatementResult(String handle, ServerStreamListener listener) {
        String[] handleParts = handle.split(":");
        String executedPeerIdentity = handleParts[0];
        String queryId = handleParts[1];
        // The tokens used for authentication between getStreamStatement and getFlightInfoStatement are different.
        ConnectContext connectContext = flightSessionsManager.getConnectContext(executedPeerIdentity);
        try {
            final FlightSqlResultCacheEntry flightSqlResultCacheEntry = Objects.requireNonNull(
                    connectContext.getFlightSqlChannel().getResult(queryId));
            final VectorSchemaRoot vectorSchemaRoot = flightSqlResultCacheEntry.getVectorSchemaRoot();
            listener.start(vectorSchemaRoot);
            listener.putNext();
        } catch (Throwable e) {
            String errMsg = "get stream statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e)
                    + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
                    + connectContext.getState().getErrorMessage();
            handleStreamException(e, errMsg, listener);
        } finally {
            listener.completed();
            // The result has been sent or sent failed, delete it.
            connectContext.getFlightSqlChannel().invalidate(queryId);
        }
    }

    @Override
    public void getStreamPreparedStatement(final CommandPreparedStatementQuery command, final CallContext context,
            final ServerStreamListener listener) {
        getStreamStatementResult(command.getPreparedStatementHandle().toStringUtf8(), listener);
    }

    @Override
    public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, final CallContext context,
            final ServerStreamListener listener) {
        getStreamStatementResult(ticketStatementQuery.getStatementHandle().toStringUtf8(), listener);
    }

    @Override
    public void closePreparedStatement(final ActionClosePreparedStatementRequest request, final CallContext context,
            final StreamListener<Result> listener) {
        executorService.submit(() -> {
            try {
                String[] handleParts = request.getPreparedStatementHandle().toStringUtf8().split(":");
                String executedPeerIdentity = handleParts[0];
                String preparedStatementId = handleParts[1];
                flightSessionsManager.getConnectContext(executedPeerIdentity).removePreparedQuery(preparedStatementId);
            } catch (final Throwable e) {
                listener.onError(e);
                return;
            }
            listener.onCompleted();
        });
    }

    private FlightInfo executeQueryStatement(String peerIdentity, ConnectContext connectContext, String query,
            final FlightDescriptor descriptor) {
        try {
            Preconditions.checkState(null != connectContext);
            Preconditions.checkState(!query.isEmpty());
            // After the previous query was executed, there was no getStreamStatement to take away the result.
            connectContext.getFlightSqlChannel().reset();
            connectContext.clearFlightSqlEndpointsLocations();
            try (FlightSqlConnectProcessor flightSQLConnectProcessor = new FlightSqlConnectProcessor(connectContext)) {
                flightSQLConnectProcessor.handleQuery(query);
                if (connectContext.getState().getStateType() == MysqlStateType.ERR) {
                    throw new RuntimeException("after executeQueryStatement handleQuery");
                }

                if (connectContext.isReturnResultFromLocal()) {
                    // set/use etc. stmt returns an OK result by default.
                    if (connectContext.getFlightSqlChannel().resultNum() == 0) {
                        // a random query id and add empty results
                        String queryId = UUID.randomUUID().toString();
                        connectContext.getFlightSqlChannel().addOKResult(queryId, query);

                        final ByteString handle = ByteString.copyFromUtf8(peerIdentity + ":" + queryId);
                        TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder()
                                .setStatementHandle(handle).build();
                        return getFlightInfoForSchema(ticketStatement, descriptor,
                                connectContext.getFlightSqlChannel().getResult(queryId).getVectorSchemaRoot()
                                        .getSchema());
                    } else {
                        // A Flight Sql request can only contain one statement that returns result,
                        // otherwise expected thrown exception during execution.
                        Preconditions.checkState(connectContext.getFlightSqlChannel().resultNum() == 1);

                        // The tokens used for authentication between getStreamStatement and getFlightInfoStatement
                        // are different. So put the peerIdentity into the ticket and then getStreamStatement is used to
                        // find the correct ConnectContext.
                        // queryId is used to find query results.
                        final ByteString handle = ByteString.copyFromUtf8(
                                peerIdentity + ":" + DebugUtil.printId(connectContext.queryId()));
                        TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder()
                                .setStatementHandle(handle).build();
                        return getFlightInfoForSchema(ticketStatement, descriptor, connectContext.getFlightSqlChannel()
                                .getResult(DebugUtil.printId(connectContext.queryId())).getVectorSchemaRoot()
                                .getSchema());
                    }
                } else {
                    // Now only query stmt will pull results from BE.
                    flightSQLConnectProcessor.fetchArrowFlightSchema(5000);
                    if (flightSQLConnectProcessor.getArrowSchema() == null) {
                        throw CallStatus.INTERNAL.withDescription("fetch arrow flight schema is null")
                                .toRuntimeException();
                    }

                    List<FlightEndpoint> endpoints = Lists.newArrayList();
                    for (FlightSqlEndpointsLocation endpointLoc : connectContext.getFlightSqlEndpointsLocations()) {
                        TUniqueId tid = endpointLoc.getFinstId();
                        // Ticket contains the IP and Brpc Port of the Doris BE node where the query result is located.
                        final ByteString handle = ByteString.copyFromUtf8(
                                DebugUtil.printId(tid) + "&" + endpointLoc.getResultInternalServiceAddr().hostname + "&"
                                        + endpointLoc.getResultInternalServiceAddr().port + "&" + query);
                        TicketStatementQuery ticketStatement = TicketStatementQuery.newBuilder()
                                .setStatementHandle(handle).build();
                        Ticket ticket = new Ticket(Any.pack(ticketStatement).toByteArray());
                        Location location;
                        if (endpointLoc.getResultPublicAccessAddr().isSetHostname()) {
                            // In a production environment, it is often inconvenient to expose Doris BE nodes
                            // to the external network.
                            // However, a reverse proxy (such as nginx) can be added to all Doris BE nodes,
                            // and the external client will be randomly routed to a Doris BE node when connecting
                            // to nginx.
                            // The query results of Arrow Flight SQL will be randomly saved on a Doris BE node.
                            // If it is different from the Doris BE node randomly routed by nginx,
                            // data forwarding needs to be done inside the Doris BE node.
                            if (endpointLoc.getResultPublicAccessAddr().isSetPort()) {
                                location = Location.forGrpcInsecure(endpointLoc.getResultPublicAccessAddr().hostname,
                                        endpointLoc.getResultPublicAccessAddr().port);
                            } else {
                                location = Location.forGrpcInsecure(endpointLoc.getResultPublicAccessAddr().hostname,
                                        endpointLoc.getResultFlightServerAddr().port);
                            }
                        } else {
                            location = Location.forGrpcInsecure(endpointLoc.getResultFlightServerAddr().hostname,
                                    endpointLoc.getResultFlightServerAddr().port);
                        }
                        // By default, the query results of all BE nodes will be aggregated to one BE node.
                        // ADBC Client will only receive one endpoint and pull data from the BE node
                        // corresponding to this endpoint.
                        // `set global enable_parallel_result_sink=true;` to allow each BE to return query results
                        // separately. ADBC Client will receive multiple endpoints and pull data from each endpoint.
                        endpoints.add(new FlightEndpoint(ticket, location));
                    }
                    // TODO Set in BE callback after query end, Client will not callback.
                    return new FlightInfo(flightSQLConnectProcessor.getArrowSchema(), descriptor, endpoints, -1, -1);
                }
            }
        } catch (Throwable e) {
            String errMsg = "get flight info statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(e)
                    + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
                    + connectContext.getState().getErrorMessage();
            LOG.error(errMsg, e);
            throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
        } finally {
            connectContext.setCommand(MysqlCommand.COM_SLEEP);
        }
    }

    @Override
    public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, final CallContext context,
            final FlightDescriptor descriptor) {
        try {
            ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
            return executeQueryStatement(context.peerIdentity(), connectContext, request.getQuery(), descriptor);
        } catch (Throwable e) {
            String errMsg = "get flight info statement failed, " + e.getMessage();
            LOG.error(errMsg, e);
            throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
        }
    }

    @Override
    public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQuery command,
            final CallContext context, final FlightDescriptor descriptor) {
        String[] handleParts = command.getPreparedStatementHandle().toStringUtf8().split(":");
        String executedPeerIdentity = handleParts[0];
        String preparedStatementId = handleParts[1];
        ConnectContext connectContext = flightSessionsManager.getConnectContext(executedPeerIdentity);
        return executeQueryStatement(executedPeerIdentity, connectContext,
                connectContext.getPreparedQuery(preparedStatementId), descriptor);
    }

    @Override
    public SchemaResult getSchemaStatement(final CommandStatementQuery command, final CallContext context,
            final FlightDescriptor descriptor) {
        throw CallStatus.UNIMPLEMENTED.withDescription("getSchemaStatement unimplemented").toRuntimeException();
    }

    @Override
    public void close() throws Exception {
        AutoCloseables.close(rootAllocator);
    }

    @Override
    public void listFlights(CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) {
        throw CallStatus.UNIMPLEMENTED.withDescription("listFlights unimplemented").toRuntimeException();
    }

    private ActionCreatePreparedStatementResult buildCreatePreparedStatementResult(ByteString handle,
            Schema parameterSchema, Schema metaData) {
        Preconditions.checkState(!Objects.isNull(metaData));
        final ByteString bytes = Objects.isNull(parameterSchema) ? ByteString.EMPTY
                : ByteString.copyFrom(serializeMetadata(parameterSchema));
        return ActionCreatePreparedStatementResult.newBuilder()
                .setDatasetSchema(ByteString.copyFrom(serializeMetadata(metaData))).setParameterSchema(bytes)
                .setPreparedStatementHandle(handle).build();
    }

    @Override
    public void createPreparedStatement(final ActionCreatePreparedStatementRequest request, final CallContext context,
            final StreamListener<Result> listener) {
        // TODO can only execute complete SQL, not support SQL parameters.
        // For Python: the Python code will try to create a prepared statement (this is to fit DBAPI, IIRC) and
        // if the server raises any error except for NotImplemented it will fail. (If it gets NotImplemented,
        // it will ignore and execute without a prepared statement.) see: https://github.com/apache/arrow/issues/38786
        executorService.submit(() -> {
            ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
            try {
                connectContext.setCommand(MysqlCommand.COM_QUERY);
                final String query = request.getQuery();
                String preparedStatementId = UUID.randomUUID().toString();
                final ByteString handle = ByteString.copyFromUtf8(context.peerIdentity() + ":" + preparedStatementId);
                connectContext.addPreparedQuery(preparedStatementId, query);

                VectorSchemaRoot emptyVectorSchemaRoot = new VectorSchemaRoot(new ArrayList<>(), new ArrayList<>());
                final Schema parameterSchema = emptyVectorSchemaRoot.getSchema();
                // TODO FE does not have the ability to convert root fragment output expr into arrow schema.
                // However, the metaData schema returned by createPreparedStatement is usually not used by the client,
                // but it cannot be empty, otherwise it will be mistaken by the client as an updata statement.
                // see: https://github.com/apache/arrow/issues/38911
                Schema metaData = connectContext.getFlightSqlChannel()
                        .createOneOneSchemaRoot("ResultMeta", "UNIMPLEMENTED").getSchema();
                listener.onNext(new Result(
                        Any.pack(buildCreatePreparedStatementResult(handle, parameterSchema, metaData)).toByteArray()));
            } catch (Exception e) {
                String errMsg = "create prepared statement failed, " + e.getMessage() + ", " + Util.getRootCauseMessage(
                        e) + ", error code: " + connectContext.getState().getErrorCode() + ", error msg: "
                        + connectContext.getState().getErrorMessage();
                LOG.error(errMsg, e);
                listener.onError(CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException());
                return;
            } catch (final Throwable t) {
                listener.onError(CallStatus.INTERNAL.withDescription("Unknown error: " + t).toRuntimeException());
                return;
            } finally {
                connectContext.setCommand(MysqlCommand.COM_SLEEP);
            }
            listener.onCompleted();
        });
    }

    @Override
    public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) {
        throw CallStatus.UNIMPLEMENTED.withDescription("doExchange unimplemented").toRuntimeException();
    }

    @Override
    public Runnable acceptPutStatement(CommandStatementUpdate command, CallContext context, FlightStream flightStream,
            StreamListener<PutResult> ackStream) {
        throw CallStatus.UNIMPLEMENTED.withDescription("acceptPutStatement unimplemented").toRuntimeException();
    }

    @Override
    public Runnable acceptPutPreparedStatementUpdate(CommandPreparedStatementUpdate command, CallContext context,
            FlightStream flightStream, StreamListener<PutResult> ackStream) {
        return () -> {
            try {
                while (flightStream.next()) {
                    final VectorSchemaRoot root = flightStream.getRoot();
                    final int rowCount = root.getRowCount();
                    // TODO support update
                    Preconditions.checkState(rowCount == 0);

                    final int recordCount = -1;
                    final DoPutUpdateResult build = DoPutUpdateResult.newBuilder().setRecordCount(recordCount).build();
                    try (final ArrowBuf buffer = rootAllocator.buffer(build.getSerializedSize())) {
                        buffer.writeBytes(build.toByteArray());
                        ackStream.onNext(PutResult.metadata(buffer));
                    }
                }
                ackStream.onCompleted();
            } catch (Throwable e) {
                String errMsg = "acceptPutPreparedStatementUpdate failed, " + e.getMessage() + ", "
                        + Util.getRootCauseMessage(e);
                LOG.error(errMsg, e);
                throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
            }
        };
    }

    @Override
    public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery command, CallContext context,
            FlightStream flightStream, StreamListener<PutResult> ackStream) {
        throw CallStatus.UNIMPLEMENTED.withDescription("acceptPutPreparedStatementQuery unimplemented")
                .toRuntimeException();
    }

    @Override
    public FlightInfo getFlightInfoSqlInfo(final CommandGetSqlInfo request, final CallContext context,
            final FlightDescriptor descriptor) {
        return getFlightInfoForSchema(request, descriptor, Schemas.GET_SQL_INFO_SCHEMA);
    }

    @Override
    public void getStreamSqlInfo(final CommandGetSqlInfo command, final CallContext context,
            final ServerStreamListener listener) {
        this.sqlInfoBuilder.send(command.getInfoList(), listener);
    }

    @Override
    public FlightInfo getFlightInfoTypeInfo(CommandGetXdbcTypeInfo request, CallContext context,
            FlightDescriptor descriptor) {
        return getFlightInfoForSchema(request, descriptor, Schemas.GET_TYPE_INFO_SCHEMA);
    }

    @Override
    public void getStreamTypeInfo(CommandGetXdbcTypeInfo request, CallContext context, ServerStreamListener listener) {
        throw CallStatus.UNIMPLEMENTED.withDescription("getStreamTypeInfo unimplemented").toRuntimeException();
    }

    @Override
    public FlightInfo getFlightInfoCatalogs(final CommandGetCatalogs request, final CallContext context,
            final FlightDescriptor descriptor) {
        return getFlightInfoForSchema(request, descriptor, Schemas.GET_CATALOGS_SCHEMA);
    }

    @Override
    public void getStreamCatalogs(final CallContext context, final ServerStreamListener listener) {
        try {
            ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
            FlightSqlSchemaHelper flightSqlSchemaHelper = new FlightSqlSchemaHelper(connectContext);
            final Schema schema = Schemas.GET_CATALOGS_SCHEMA;

            try (final VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) {
                listener.start(vectorSchemaRoot);
                vectorSchemaRoot.allocateNew();
                flightSqlSchemaHelper.getCatalogs(vectorSchemaRoot);
                listener.putNext();
                listener.completed();
            }
        } catch (final Throwable e) {
            handleStreamException(e, "", listener);
        }
    }

    @Override
    public FlightInfo getFlightInfoSchemas(final CommandGetDbSchemas request, final CallContext context,
            final FlightDescriptor descriptor) {
        return getFlightInfoForSchema(request, descriptor, Schemas.GET_SCHEMAS_SCHEMA);
    }

    @Override
    public void getStreamSchemas(final CommandGetDbSchemas command, final CallContext context,
            final ServerStreamListener listener) {
        try {
            ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
            FlightSqlSchemaHelper flightSqlSchemaHelper = new FlightSqlSchemaHelper(connectContext);
            flightSqlSchemaHelper.setParameterForGetDbSchemas(command);
            final Schema schema = Schemas.GET_SCHEMAS_SCHEMA;

            try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) {
                listener.start(vectorSchemaRoot);
                vectorSchemaRoot.allocateNew();
                flightSqlSchemaHelper.getSchemas(vectorSchemaRoot);
                listener.putNext();
                listener.completed();
            }
        } catch (final Throwable e) {
            handleStreamException(e, "", listener);
        }
    }

    @Override
    public FlightInfo getFlightInfoTables(final CommandGetTables request, final CallContext context,
            final FlightDescriptor descriptor) {
        Schema schemaToUse = Schemas.GET_TABLES_SCHEMA;
        if (!request.getIncludeSchema()) {
            schemaToUse = Schemas.GET_TABLES_SCHEMA_NO_SCHEMA;
        }
        return getFlightInfoForSchema(request, descriptor, schemaToUse);
    }

    @Override
    public void getStreamTables(final CommandGetTables command, final CallContext context,
            final ServerStreamListener listener) {
        try {
            ConnectContext connectContext = flightSessionsManager.getConnectContext(context.peerIdentity());
            FlightSqlSchemaHelper flightSqlSchemaHelper = new FlightSqlSchemaHelper(connectContext);
            flightSqlSchemaHelper.setParameterForGetTables(command);
            final Schema schema = command.getIncludeSchema() ? Schemas.GET_TABLES_SCHEMA
                    : Schemas.GET_TABLES_SCHEMA_NO_SCHEMA;

            try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) {
                listener.start(vectorSchemaRoot);
                vectorSchemaRoot.allocateNew();
                flightSqlSchemaHelper.getTables(vectorSchemaRoot);
                listener.putNext();
                listener.completed();
            }
        } catch (final Throwable e) {
            handleStreamException(e, "", listener);
        }
    }

    @Override
    public FlightInfo getFlightInfoTableTypes(final CommandGetTableTypes request, final CallContext context,
            final FlightDescriptor descriptor) {
        return getFlightInfoForSchema(request, descriptor, Schemas.GET_TABLE_TYPES_SCHEMA);
    }

    @Override
    public void getStreamTableTypes(final CallContext context, final ServerStreamListener listener) {
        throw CallStatus.UNIMPLEMENTED.withDescription("getStreamTableTypes unimplemented").toRuntimeException();
    }

    @Override
    public FlightInfo getFlightInfoPrimaryKeys(final CommandGetPrimaryKeys request, final CallContext context,
            final FlightDescriptor descriptor) {
        return getFlightInfoForSchema(request, descriptor, Schemas.GET_PRIMARY_KEYS_SCHEMA);
    }

    @Override
    public void getStreamPrimaryKeys(final CommandGetPrimaryKeys command, final CallContext context,
            final ServerStreamListener listener) {
        throw CallStatus.UNIMPLEMENTED.withDescription("getStreamPrimaryKeys unimplemented").toRuntimeException();
    }

    @Override
    public FlightInfo getFlightInfoExportedKeys(final CommandGetExportedKeys request, final CallContext context,
            final FlightDescriptor descriptor) {
        return getFlightInfoForSchema(request, descriptor, Schemas.GET_EXPORTED_KEYS_SCHEMA);
    }

    @Override
    public void getStreamExportedKeys(final CommandGetExportedKeys command, final CallContext context,
            final ServerStreamListener listener) {
        throw CallStatus.UNIMPLEMENTED.withDescription("getStreamExportedKeys unimplemented").toRuntimeException();
    }

    @Override
    public FlightInfo getFlightInfoImportedKeys(final CommandGetImportedKeys request, final CallContext context,
            final FlightDescriptor descriptor) {
        return getFlightInfoForSchema(request, descriptor, Schemas.GET_IMPORTED_KEYS_SCHEMA);
    }

    @Override
    public void getStreamImportedKeys(final CommandGetImportedKeys command, final CallContext context,
            final ServerStreamListener listener) {
        throw CallStatus.UNIMPLEMENTED.withDescription("getStreamImportedKeys unimplemented").toRuntimeException();
    }

    @Override
    public FlightInfo getFlightInfoCrossReference(CommandGetCrossReference request, CallContext context,
            FlightDescriptor descriptor) {
        return getFlightInfoForSchema(request, descriptor, Schemas.GET_CROSS_REFERENCE_SCHEMA);
    }

    @Override
    public void getStreamCrossReference(CommandGetCrossReference command, CallContext context,
            ServerStreamListener listener) {
        throw CallStatus.UNIMPLEMENTED.withDescription("getStreamCrossReference unimplemented").toRuntimeException();
    }

    @Override
    public void closeSession(CloseSessionRequest request, final CallContext context,
            final StreamListener<CloseSessionResult> listener) {
        // https://github.com/apache/arrow-adbc/issues/2821
        // currently FlightSqlConnection does not provide a separate interface for external calls to
        // FlightSqlClient::closeSession(), nor will it automatically call closeSession
        // when FlightSqlConnection::close(). Python flight sql Cursor.close() will call closeSession().
        // Neither C++ nor Java seem to have similar behavior.
        try {
            flightSessionsManager.closeConnectContext(context.peerIdentity());
        } catch (final Throwable e) {
            LOG.error("closeSession failed", e);
            listener.onError(
                    CallStatus.INTERNAL.withDescription("closeSession failed").withCause(e).toRuntimeException());
        }
        listener.onNext(new CloseSessionResult(CloseSessionResult.Status.CLOSED));
        listener.onCompleted();
    }

    private <T extends Message> FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor,
            final Schema schema) {
        final Ticket ticket = new Ticket(Any.pack(request).toByteArray());
        final List<FlightEndpoint> endpoints = Collections.singletonList(new FlightEndpoint(ticket, location));

        return new FlightInfo(schema, descriptor, endpoints, -1, -1);
    }

    private static void handleStreamException(Throwable e, String errMsg, ServerStreamListener listener) {
        LOG.error(errMsg, e);
        listener.error(CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException());
        throw CallStatus.INTERNAL.withDescription(errMsg).withCause(e).toRuntimeException();
    }
}