MysqlConnectProcessor.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.qe;
import org.apache.doris.analysis.StatementBase;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.MysqlColType;
import org.apache.doris.cloud.catalog.CloudEnv;
import org.apache.doris.common.AuthenticationException;
import org.apache.doris.common.Config;
import org.apache.doris.common.ConnectionException;
import org.apache.doris.common.DdlException;
import org.apache.doris.common.ErrorCode;
import org.apache.doris.common.ErrorReport;
import org.apache.doris.datasource.CatalogIf;
import org.apache.doris.datasource.InternalCatalog;
import org.apache.doris.mysql.MysqlChannel;
import org.apache.doris.mysql.MysqlCommand;
import org.apache.doris.mysql.MysqlHandshakePacket;
import org.apache.doris.mysql.MysqlProto;
import org.apache.doris.mysql.MysqlSerializer;
import org.apache.doris.mysql.privilege.Auth;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.glue.LogicalPlanAdapter;
import org.apache.doris.nereids.trees.expressions.Placeholder;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.PlaceholderId;
import org.apache.doris.nereids.trees.plans.commands.ExecuteCommand;
import org.apache.doris.nereids.trees.plans.commands.PrepareCommand;
import org.apache.doris.thrift.TUniqueId;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.AsynchronousCloseException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
/**
* Process one mysql connection, receive one packet, process, send one packet.
*/
public class MysqlConnectProcessor extends ConnectProcessor {
private static final Logger LOG = LogManager.getLogger(MysqlConnectProcessor.class);
private ByteBuffer packetBuf;
public MysqlConnectProcessor(ConnectContext context) {
super(context);
connectType = ConnectType.MYSQL;
}
// COM_INIT_DB: change current database of this session.
private void handleInitDb() {
String fullDbName = new String(packetBuf.array(), 1, packetBuf.limit() - 1);
handleInitDb(fullDbName);
}
private void handleStmtClose() {
packetBuf = packetBuf.order(ByteOrder.LITTLE_ENDIAN);
int stmtId = packetBuf.getInt();
handleStmtClose(stmtId);
}
private String getPacket() {
byte[] bytes = packetBuf.array();
StringBuilder printB = new StringBuilder();
for (byte b : bytes) {
if (Character.isLetterOrDigit((char) b & 0xFF)) {
char x = (char) b;
printB.append(x);
} else {
printB.append("0x" + Integer.toHexString(b & 0xFF));
}
printB.append(" ");
}
return printB.toString().substring(0, 200);
}
private void debugPacket() {
if (LOG.isDebugEnabled()) {
LOG.debug("debug packet {}", getPacket());
}
}
private String getHexStr(ByteBuffer packetBuf) {
byte[] bytes = packetBuf.array();
StringBuilder hex = new StringBuilder();
for (int i = packetBuf.position(); i < packetBuf.limit(); ++i) {
hex.append(String.format("%02X ", bytes[i]));
}
return hex.toString();
}
@Override
protected void handleExecute(PrepareCommand prepareCommand, long stmtId, PreparedStatementContext prepCtx,
ByteBuffer packetBuf, TUniqueId queryId) {
int paramCount = prepareCommand.placeholderCount();
LOG.debug("execute prepared statement {}, paramCount {}", stmtId, paramCount);
// null bitmap
String stmtStr = "";
try {
StatementContext statementContext = prepCtx.statementContext;
if (paramCount > 0) {
if (LOG.isDebugEnabled()) {
LOG.debug("execute param buf: {}, array: {}", packetBuf, getHexStr(packetBuf));
}
if (!ctx.isProxy()) {
ctx.setPrepareExecuteBuffer(packetBuf.duplicate());
}
byte[] nullbitmapData = new byte[(paramCount + 7) / 8];
packetBuf.get(nullbitmapData);
// new_params_bind_flag
if ((int) packetBuf.get() != 0) {
List<Placeholder> typedPlaceholders = new ArrayList<>();
// parse params's types
for (int i = 0; i < paramCount; ++i) {
int typeCode = packetBuf.getChar();
LOG.debug("code {}", typeCode);
// assign type to placeholders
typedPlaceholders.add(
prepareCommand.getPlaceholders().get(i).withNewMysqlColType(typeCode));
}
// rewrite with new prepared statment with type info in placeholders
prepCtx.command = prepareCommand.withPlaceholders(typedPlaceholders);
prepareCommand = (PrepareCommand) prepCtx.command;
}
// parse param data
for (int i = 0; i < paramCount; ++i) {
PlaceholderId exprId = prepareCommand.getPlaceholders().get(i).getPlaceholderId();
if (isNull(nullbitmapData, i)) {
statementContext.getIdToPlaceholderRealExpr().put(exprId,
new org.apache.doris.nereids.trees.expressions.literal.NullLiteral());
continue;
}
MysqlColType type = prepareCommand.getPlaceholders().get(i).getMysqlColType();
boolean isUnsigned = prepareCommand.getPlaceholders().get(i).isUnsigned();
Literal l = Literal.getLiteralByMysqlType(type, isUnsigned, packetBuf);
statementContext.getIdToPlaceholderRealExpr().put(exprId, l);
}
}
ExecuteCommand executeStmt = new ExecuteCommand(String.valueOf(stmtId), prepareCommand, statementContext);
// TODO set real origin statement
if (LOG.isDebugEnabled()) {
LOG.debug("executeStmt {}", executeStmt);
}
StatementBase stmt = new LogicalPlanAdapter(executeStmt, statementContext);
stmt.setOrigStmt(prepareCommand.getOriginalStmt());
executor = new StmtExecutor(ctx, stmt);
ctx.setExecutor(executor);
if (null != queryId) {
executor.execute(queryId);
} else {
executor.execute();
}
if (ctx.getSessionVariable().isEnablePreparedStmtAuditLog()) {
stmtStr = executeStmt.toSql();
stmtStr = stmtStr + " /*originalSql = " + prepareCommand.getOriginalStmt().originStmt + "*/";
}
} catch (Throwable e) {
// Catch all throwable.
// If reach here, maybe doris bug.
LOG.warn("Process one query failed because unknown reason: ", e);
ctx.getState().setError(ErrorCode.ERR_UNKNOWN_ERROR,
e.getClass().getSimpleName() + ", msg: " + e.getMessage());
}
if (ctx.getSessionVariable().isEnablePreparedStmtAuditLog()) {
auditAfterExec(stmtStr, executor.getParsedStmt(), executor.getQueryStatisticsForAuditLog(), true);
}
}
// process COM_EXECUTE, parse binary row data
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
private void handleExecute() {
if (LOG.isDebugEnabled()) {
debugPacket();
}
packetBuf = packetBuf.order(ByteOrder.LITTLE_ENDIAN);
// parse stmt_id, flags, params
int stmtId = packetBuf.getInt();
// flag
packetBuf.get();
// iteration_count always 1,
packetBuf.getInt();
if (LOG.isDebugEnabled()) {
LOG.debug("execute prepared statement {}", stmtId);
}
ctx.setStartTime();
// nereids
PreparedStatementContext preparedStatementContext = ctx.getPreparedStementContext(String.valueOf(stmtId));
if (preparedStatementContext == null) {
LOG.warn("No such statement in context, stmtId:{}", stmtId);
ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR,
"msg: Not supported such prepared statement");
return;
}
handleExecute(preparedStatementContext.command, stmtId, preparedStatementContext, packetBuf, null);
}
// Process COM_QUERY statement,
private void handleQuery() throws ConnectionException {
// convert statement to Java string
byte[] bytes = packetBuf.array();
int ending = packetBuf.limit() - 1;
while (ending >= 1 && bytes[ending] == '\0') {
ending--;
}
String originStmt = new String(bytes, 1, ending, StandardCharsets.UTF_8);
handleQuery(originStmt);
}
private void dispatch() throws IOException {
int code = packetBuf.get();
MysqlCommand command = MysqlCommand.fromCode(code);
if (command == null) {
ErrorReport.report(ErrorCode.ERR_UNKNOWN_COM_ERROR);
ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR, "Unknown command(" + code + ")");
LOG.warn("Unknown command({})", code);
return;
}
if (LOG.isDebugEnabled()) {
LOG.debug("handle command {}", command);
}
ctx.setCommand(command);
ctx.setStartTime();
switch (command) {
case COM_INIT_DB:
handleInitDb();
break;
case COM_QUIT:
// COM_QUIT: set killed flag and then return OK packet.
handleQuit();
break;
case COM_QUERY:
case COM_STMT_PREPARE:
handleQuery();
break;
case COM_STMT_EXECUTE:
handleExecute();
break;
case COM_FIELD_LIST:
handleFieldList();
break;
case COM_PING:
// process COM_PING statement, do nothing, just return one OK packet.
handlePing();
break;
case COM_STATISTICS:
handleStatistics();
break;
case COM_DEBUG:
handleDebug();
break;
case COM_CHANGE_USER:
handleChangeUser();
break;
case COM_STMT_RESET:
handleStmtReset();
break;
case COM_STMT_CLOSE:
handleStmtClose();
break;
case COM_SET_OPTION:
handleSetOption();
break;
case COM_RESET_CONNECTION:
handleResetConnection();
break;
default:
ctx.getState().setError(ErrorCode.ERR_UNKNOWN_COM_ERROR, "Unsupported command(" + command + ")");
LOG.warn("Unsupported command(" + command + ")");
if (command.equals(MysqlCommand.COM_SLEEP)) {
LOG.warn("COM_SLEEP packet: [{}]", getPacket());
}
break;
}
}
private void handleFieldList() throws ConnectionException {
String tableName = new String(MysqlProto.readNulTerminateString(packetBuf), StandardCharsets.UTF_8);
handleFieldList(tableName);
}
private void handleChangeUser() throws IOException {
// Random bytes generated when creating connection.
byte[] authPluginData = getConnectContext().getAuthPluginData();
Preconditions.checkNotNull(authPluginData, "Auth plugin data is null.");
String userName = new String(MysqlProto.readNulTerminateString(packetBuf));
int passwordLen = MysqlProto.readInt1(packetBuf);
byte[] password = MysqlProto.readFixedString(packetBuf, passwordLen);
String db = new String(MysqlProto.readNulTerminateString(packetBuf));
// Read the character set.
MysqlProto.readInt2(packetBuf);
String authPluginName = new String(MysqlProto.readNulTerminateString(packetBuf));
// Send Protocol::AuthSwitchRequest to client if auth plugin name is not mysql_native_password
if (!MysqlHandshakePacket.AUTH_PLUGIN_NAME.equals(authPluginName)) {
MysqlChannel channel = ctx.mysqlChannel;
MysqlSerializer serializer = MysqlSerializer.newInstance();
serializer.writeInt1((byte) 0xfe);
serializer.writeNulTerminateString(MysqlHandshakePacket.AUTH_PLUGIN_NAME);
serializer.writeBytes(authPluginData);
serializer.writeInt1(0);
channel.sendAndFlush(serializer.toByteBuffer());
// Server receive auth switch response packet from client.
ByteBuffer authSwitchResponse = channel.fetchOnePacket();
int length = authSwitchResponse.limit();
password = new byte[length];
System.arraycopy(authSwitchResponse.array(), 0, password, 0, length);
}
// For safety, not allowed to change to root or admin.
if (Auth.ROOT_USER.equals(userName) || Auth.ADMIN_USER.equals(userName)) {
ctx.getState().setError(ErrorCode.ERR_ACCESS_DENIED_ERROR, "Change to root or admin is forbidden");
return;
}
// Check password.
List<UserIdentity> currentUserIdentity = Lists.newArrayList();
try {
Env.getCurrentEnv().getAuth()
.checkPassword(userName, ctx.remoteIP, password, authPluginData, currentUserIdentity);
} catch (AuthenticationException e) {
ctx.getState().setError(ErrorCode.ERR_ACCESS_DENIED_ERROR, "Authentication failed.");
return;
}
ctx.setCurrentUserIdentity(currentUserIdentity.get(0));
ctx.setQualifiedUser(userName);
// Change default db if set.
if (Strings.isNullOrEmpty(db)) {
ctx.changeDefaultCatalog(InternalCatalog.INTERNAL_CATALOG_NAME);
} else {
String catalogName = null;
String dbName = null;
String[] dbNames = db.split("\\.");
if (dbNames.length == 1) {
dbName = db;
} else if (dbNames.length == 2) {
catalogName = dbNames[0];
dbName = dbNames[1];
} else if (dbNames.length > 2) {
ctx.getState().setError(ErrorCode.ERR_BAD_DB_ERROR, "Only one dot can be in the name: " + db);
return;
}
if (Config.isCloudMode()) {
try {
dbName = ((CloudEnv) Env.getCurrentEnv()).analyzeCloudCluster(dbName, ctx);
} catch (DdlException e) {
ctx.getState().setError(e.getMysqlErrorCode(), e.getMessage());
return;
}
}
// check catalog and db exists
if (catalogName != null) {
CatalogIf catalogIf = ctx.getEnv().getCatalogMgr().getCatalog(catalogName);
if (catalogIf == null) {
ctx.getState().setError(ErrorCode.ERR_BAD_DB_ERROR, "No match catalog in doris: " + db);
return;
}
if (catalogIf.getDbNullable(dbName) == null) {
ctx.getState().setError(ErrorCode.ERR_BAD_DB_ERROR, "No match database in doris: " + db);
return;
}
}
try {
if (catalogName != null) {
ctx.getEnv().changeCatalog(ctx, catalogName);
}
Env.getCurrentEnv().changeDb(ctx, dbName);
} catch (DdlException e) {
ctx.getState().setError(e.getMysqlErrorCode(), e.getMessage());
return;
}
}
ctx.getState().setOk();
}
private void handleSetOption() {
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_set_option.html
int optionOperation = MysqlProto.readInt2(packetBuf);
LOG.debug("option_operation {}", optionOperation);
// Do nothing for now.
// https://dev.mysql.com/doc/c-api/8.0/en/mysql-set-server-option.html
ctx.getState().setOk();
}
// Process a MySQL request
public void processOnce() throws IOException {
// set status of query to OK.
ctx.getState().reset();
ctx.setGroupCommit(false);
executor = null;
// reset sequence id of MySQL protocol
final MysqlChannel channel = ctx.getMysqlChannel();
channel.setSequenceId(0);
// read packet from channel
try {
packetBuf = channel.fetchOnePacket();
if (packetBuf == null) {
LOG.warn("Null packet received from network. remote: {}", channel.getRemoteHostPortString());
throw new IOException("Error happened when receiving packet.");
}
if (!packetBuf.hasRemaining()) {
LOG.info("No more data to be read. Close connection. remote={}", channel.getRemoteHostPortString());
ctx.setKilled();
return;
}
} catch (AsynchronousCloseException e) {
// when this happened, timeout checker close this channel
// killed flag in ctx has been already set, just return
return;
}
// dispatch
dispatch();
// finalize
finalizeCommand();
ctx.setCommand(MysqlCommand.COM_SLEEP);
ctx.clear();
executor = null;
}
public void loop() {
while (!ctx.isKilled()) {
try {
processOnce();
} catch (Exception e) {
// TODO(zhaochun): something wrong
LOG.warn("Exception happened in one session(" + ctx + ").", e);
ctx.setKilled();
break;
}
}
}
}