// FlightSqlConnectProcessor.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.service.arrowflight;
import org.apache.doris.analysis.Expr;
import org.apache.doris.common.ConnectionException;
import org.apache.doris.common.ErrorCode;
import org.apache.doris.common.ErrorReport;
import org.apache.doris.common.Status;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.proto.InternalService;
import org.apache.doris.proto.Types;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.ConnectProcessor;
import org.apache.doris.qe.StmtExecutor;
import org.apache.doris.rpc.BackendServiceProxy;
import org.apache.doris.rpc.RpcException;
import org.apache.doris.service.arrowflight.results.FlightSqlEndpointsLocation;
import org.apache.doris.thrift.TNetworkAddress;
import org.apache.doris.thrift.TStatusCode;
import org.apache.doris.thrift.TUniqueId;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.ByteArrayInputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
/**
 * Processes one Arrow Flight SQL connection.
 * <p>
 * Must be used with try-with-resources: {@link #close()} finalizes the statement executors and
 * calls {@link ConnectContext#remove()}.
 */
public class FlightSqlConnectProcessor extends ConnectProcessor implements AutoCloseable {
    private static final Logger LOG = LogManager.getLogger(FlightSqlConnectProcessor.class);

    // Schema shared by every result endpoint; set by the first successful fetch in fetchArrowFlightSchema.
    private Schema arrowSchema;

    /**
     * Creates a processor for the given context, marks the connection as Arrow Flight SQL and
     * configures the context through {@code setThreadLocalInfo()} and {@code setReturnResultFromLocal(true)}.
     *
     * @param context connect context this processor operates on
     */
    public FlightSqlConnectProcessor(ConnectContext context) {
        super(context);
        connectType = ConnectType.ARROW_FLIGHT_SQL;
        context.setThreadLocalInfo();
        context.setReturnResultFromLocal(true);
    }

    /**
     * @return the schema recorded by {@link #fetchArrowFlightSchema(int)}, or {@code null} if none was fetched yet
     */
    public Schema getArrowSchema() {
        return arrowSchema;
    }

    /**
     * Resets the query state and the current executor, then records the command and start time
     * on the context. A {@code null} command is reported as an unknown-command error and nothing
     * else is recorded.
     *
     * @param command command about to be handled; may be {@code null}
     */
    public void prepare(MysqlCommand command) {
        // set status of query to OK.
        ctx.getState().reset();
        executor = null;
        if (command == null) {
            // String concatenation renders null as "null"; calling command.toString() here would throw
            // a NullPointerException instead of reporting the error.
            String message = "Unknown command(" + command + ")";
            ErrorReport.report(ErrorCode.ERR_UNKNOWN_COM_ERROR);
            ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR, message);
            LOG.warn(message);
            return;
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("arrow flight sql handle command {}", command);
        }
        ctx.setCommand(command);
        ctx.setStartTime();
    }

    /**
     * Prepares the context for a {@code COM_QUERY} command, records the running query and
     * delegates execution to {@link ConnectProcessor#handleQuery(String)}.
     *
     * @param query SQL text to execute
     * @throws ConnectionException propagated from the parent implementation
     */
    public void handleQuery(String query) throws ConnectionException {
        MysqlCommand command = MysqlCommand.COM_QUERY;
        prepare(command);
        ctx.setRunningQuery(query);
        super.handleQuery(query);
    }

    // TODO
    // private void handleInitDb() {
    //     handleInitDb(fullDbName);
    // }

    // TODO
    // private void handleFieldList() {
    //     handleFieldList(tableName);
    // }

    /**
     * Fetches the Arrow schema from every result endpoint registered on the context.
     * <p>
     * For each endpoint this sends a fetch-schema RPC to its internal service address, stores the
     * Arrow Flight address reported by the backend on the endpoint, and checks that the returned
     * schema has as many fields as the endpoint has output expressions and equals the schema of
     * the endpoints processed before it.
     *
     * @param timeoutMs maximum time to wait for each backend response, in milliseconds
     * @throws RuntimeException if there is no endpoint, or an RPC fails, times out, is interrupted,
     *         or returns a non-OK status, an empty schema or a schema that fails validation
     */
    public void fetchArrowFlightSchema(int timeoutMs) {
        if (ctx.getFlightSqlEndpointsLocations().isEmpty()) {
            throw new RuntimeException("fetch arrow flight schema failed, no FlightSqlEndpointsLocations.");
        }
        for (FlightSqlEndpointsLocation endpointLoc : ctx.getFlightSqlEndpointsLocations()) {
            TNetworkAddress address = endpointLoc.getResultInternalServiceAddr();
            TUniqueId tid = endpointLoc.getFinstId();
            ArrayList<Expr> resultOutputExprs = endpointLoc.getResultOutputExprs();
            Types.PUniqueId queryId = Types.PUniqueId.newBuilder().setHi(tid.hi).setLo(tid.lo).build();
            try {
                InternalService.PFetchArrowFlightSchemaRequest request
                        = InternalService.PFetchArrowFlightSchemaRequest.newBuilder().setFinstId(queryId).build();
                Future<InternalService.PFetchArrowFlightSchemaResult> future = BackendServiceProxy.getInstance()
                        .fetchArrowFlightSchema(address, request);
                InternalService.PFetchArrowFlightSchemaResult pResult
                        = future.get(timeoutMs, TimeUnit.MILLISECONDS);
                if (pResult == null) {
                    throw new RuntimeException(
                            String.format("fetch arrow flight schema timeout, queryId: %s", DebugUtil.printId(tid)));
                }
                Status resultStatus = new Status(pResult.getStatus());
                if (resultStatus.getErrorCode() != TStatusCode.OK) {
                    throw new RuntimeException(
                            String.format("fetch arrow flight schema failed, queryId: %s, errmsg: %s",
                                    DebugUtil.printId(tid), resultStatus));
                }

                // Only the fields the backend actually reported are set on the public address.
                TNetworkAddress resultPublicAccessAddr = new TNetworkAddress();
                if (pResult.hasBeArrowFlightIp()) {
                    resultPublicAccessAddr.setHostname(pResult.getBeArrowFlightIp().toStringUtf8());
                }
                if (pResult.hasBeArrowFlightPort()) {
                    resultPublicAccessAddr.setPort(pResult.getBeArrowFlightPort());
                }
                endpointLoc.setResultPublicAccessAddr(resultPublicAccessAddr);

                if (!pResult.hasSchema() || pResult.getSchema().size() <= 0) {
                    throw new RuntimeException(
                            String.format("get empty arrow flight schema, queryId: %s", DebugUtil.printId(tid)));
                }
                mergeBackendArrowSchema(pResult.getSchema().toByteArray(), resultOutputExprs, tid, address);
            } catch (RpcException e) {
                throw new RuntimeException(
                        String.format("arrow flight schema fetch catch rpc exception, queryId: %s,backend: %s",
                                DebugUtil.printId(tid), address), e);
            } catch (InterruptedException e) {
                // Restore the interrupt flag so code further up the stack can still observe the interruption.
                Thread.currentThread().interrupt();
                throw new RuntimeException(
                        String.format("arrow flight schema future get interrupted exception, queryId: %s,backend: %s",
                                DebugUtil.printId(tid), address), e);
            } catch (ExecutionException e) {
                throw new RuntimeException(
                        String.format("arrow flight schema future get execution exception, queryId: %s,backend: %s",
                                DebugUtil.printId(tid), address), e);
            } catch (TimeoutException e) {
                throw new RuntimeException(String.format("arrow flight schema fetch timeout, queryId: %s,backend: %s",
                        DebugUtil.printId(tid), address), e);
            }
        }
    }

    /**
     * Decodes the serialized Arrow schema returned by one backend, validates it, and records it
     * in {@link #arrowSchema} if no schema has been recorded yet.
     *
     * @param serializedSchema Arrow IPC stream bytes holding the schema
     * @param resultOutputExprs output expressions of the endpoint; their count must match the field count
     * @param tid fragment instance id, used in error messages
     * @param address backend address, used in error messages
     * @throws RuntimeException wrapping any decoding, validation or resource-release failure
     */
    private void mergeBackendArrowSchema(byte[] serializedSchema, ArrayList<Expr> resultOutputExprs,
            TUniqueId tid, TNetworkAddress address) {
        // try-with-resources closes the reader first and then the allocator, so neither leaks per endpoint.
        // The Schema object taken from the root is kept in a field after both are closed.
        try (RootAllocator rootAllocator = new RootAllocator(Integer.MAX_VALUE);
                ArrowStreamReader arrowStreamReader = new ArrowStreamReader(
                        new ByteArrayInputStream(serializedSchema), rootAllocator)) {
            VectorSchemaRoot root = arrowStreamReader.getVectorSchemaRoot();
            List<FieldVector> fieldVectors = root.getFieldVectors();
            if (fieldVectors.size() != resultOutputExprs.size()) {
                throw new RuntimeException(
                        String.format("Schema size %s' is not equal to arrow field size %s, queryId: %s.",
                                fieldVectors.size(), resultOutputExprs.size(), DebugUtil.printId(tid)));
            }
            Schema schema = root.getSchema();
            if (arrowSchema == null) {
                arrowSchema = schema;
            } else if (!arrowSchema.equals(schema)) {
                throw new RuntimeException(String.format(
                        "The schema returned by results BE is different, first schema: %s, "
                                + "new schema: %s, queryId: %s,backend: %s", arrowSchema, schema,
                        DebugUtil.printId(tid), address));
            }
        } catch (Exception e) {
            // Validation failures thrown above are wrapped here as well, as callers saw before.
            throw new RuntimeException("Read Arrow Flight Schema failed.", e);
        }
    }

    /**
     * Puts the context back into {@code COM_SLEEP}, clears it, finalizes every executor this
     * processor holds and finally calls {@link ConnectContext#remove()}.
     *
     * @throws Exception if finalizing an executor fails
     */
    @Override
    public void close() throws Exception {
        try {
            ctx.setCommand(MysqlCommand.COM_SLEEP);
            ctx.clear();
            for (StmtExecutor asynExecutor : returnResultFromRemoteExecutor) {
                asynExecutor.finalizeQuery();
            }
            returnResultFromRemoteExecutor.clear();
            // prepare() resets executor to null, so it may still be unset when close() runs.
            if (executor != null) {
                executor.finalizeQuery();
            }
        } finally {
            // Runs even when a step above throws.
            ConnectContext.remove();
        }
    }
}