MysqlProto.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.mysql;

import org.apache.doris.catalog.Env;
import org.apache.doris.cloud.catalog.CloudEnv;
import org.apache.doris.common.Config;
import org.apache.doris.common.DdlException;
import org.apache.doris.common.ErrorCode;
import org.apache.doris.common.ErrorReport;
import org.apache.doris.datasource.CatalogIf;
import org.apache.doris.datasource.InternalCatalog;
import org.apache.doris.qe.ConnectContext;

import com.google.common.base.Strings;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.IOException;
import java.nio.ByteBuffer;

// MySQL protocol util
public class MysqlProto {
    private static final Logger LOG = LogManager.getLogger(MysqlProto.class);
    public static final boolean SERVER_USE_SSL = Config.enable_ssl;


    private static String parseUser(ConnectContext context, byte[] scramble, String user) {
        String usePasswd = scramble.length == 0 ? "NO" : "YES";

        String tmpUser = user;
        if (tmpUser == null || tmpUser.isEmpty()) {
            ErrorReport.report(ErrorCode.ERR_ACCESS_DENIED_ERROR, "anonym@" + context.getRemoteIP(), usePasswd);
            return null;
        }

        // check workload group level. user name may contains workload group level.
        // eg:
        // ...@user_name#HIGH
        // set workload group if it is valid, or just ignore it
        String[] strList = tmpUser.split("#", 2);
        if (strList.length > 1) {
            tmpUser = strList[0];
        }

        context.setQualifiedUser(tmpUser);
        return tmpUser;
    }

    // send response packet(OK/EOF/ERR).
    // before call this function, should set information in state of ConnectContext
    public static void sendResponsePacket(ConnectContext context) throws IOException {
        MysqlChannel channel = context.getMysqlChannel();
        MysqlSerializer serializer = channel.getSerializer();
        MysqlPacket packet = context.getState().toResponsePacket();

        // send response packet to client
        serializer.reset();
        packet.writeTo(serializer);
        channel.sendAndFlush(serializer.toByteBuffer());
    }

    /**
     * negotiate with client, use MySQL protocol
     * server ---handshake---> client
     * server <--- authenticate --- client
     * server --- response(OK/ERR) ---> client
     * Exception:
     * IOException:
     */
    public static boolean negotiate(ConnectContext context) throws IOException {
        MysqlChannel channel = context.getMysqlChannel();
        MysqlSerializer serializer = channel.getSerializer();
        context.getState().setOk();

        // Server send handshake packet to client.
        serializer.reset();
        MysqlHandshakePacket handshakePacket = new MysqlHandshakePacket(context.getConnectionId());
        handshakePacket.writeTo(serializer);
        context.setMysqlHandshakePacket(handshakePacket);
        try {
            channel.sendAndFlush(serializer.toByteBuffer());
        } catch (IOException e) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Send and flush channel exception, ignore.", e);
            }
            return false;
        }

        // Server receive request packet from client, we need to determine which request type it is.
        ByteBuffer clientRequestPacket = channel.fetchOnePacket();
        MysqlCapability capability = new MysqlCapability(MysqlProto.readLowestInt4(clientRequestPacket));

        // Server receive SSL connection request packet from client.
        ByteBuffer sslConnectionRequest;
        // Server receive authenticate packet from client.
        ByteBuffer handshakeResponse;

        if (capability.isClientUseSsl()) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("client is using ssl connection.");
            }
            // During development, we set SSL mode to true by default.
            if (SERVER_USE_SSL) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("server is also using ssl connection. Will use ssl mode for data exchange.");
                }
                MysqlSslContext mysqlSslContext = context.getMysqlSslContext();
                mysqlSslContext.init();
                channel.initSslBuffer();
                sslConnectionRequest = clientRequestPacket;
                if (sslConnectionRequest == null) {
                    // receive response failed.
                    return false;
                }
                MysqlSslPacket sslPacket = new MysqlSslPacket();
                if (!sslPacket.readFrom(sslConnectionRequest)) {
                    ErrorReport.report(ErrorCode.ERR_NOT_SUPPORTED_AUTH_MODE);
                    sendResponsePacket(context);
                    return false;
                }
                // try to establish ssl connection.
                try {
                    // set channel to handshake mode to process data packet as ssl packet.
                    channel.setSslHandshaking(true);
                    // The ssl handshake phase still uses plaintext.
                    if (!mysqlSslContext.sslExchange(channel)) {
                        ErrorReport.report(ErrorCode.ERR_NOT_SUPPORTED_AUTH_MODE);
                        sendResponsePacket(context);
                        return false;
                    }
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
                // if the exchange is successful, the channel will switch to ssl communication mode
                // which means all data after this moment will be ciphertext.

                // Set channel mode to ssl mode to handle socket packet in ssl format.
                channel.setSslMode(true);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("switch to ssl mode.");
                }
                handshakeResponse = channel.fetchOnePacket();
            } else {
                handshakeResponse = clientRequestPacket;
            }
        } else {
            handshakeResponse = clientRequestPacket;
        }

        if (handshakeResponse == null) {
            // receive response failed.
            return false;
        }
        if (capability.isDeprecatedEOF()) {
            context.getMysqlChannel().setClientDeprecatedEOF();
        }

        // we do not save client capability to context, so here we save CLIENT_MULTI_STATEMENTS to MysqlChannel
        if (capability.isClientMultiStatements()) {
            context.getMysqlChannel().setClientMultiStatements();
        }

        MysqlAuthPacket authPacket = new MysqlAuthPacket();
        if (!authPacket.readFrom(handshakeResponse)) {
            ErrorReport.report(ErrorCode.ERR_NOT_SUPPORTED_AUTH_MODE);
            sendResponsePacket(context);
            return false;
        }

        // check capability
        if (!MysqlCapability.isCompatible(context.getServerCapability(), authPacket.getCapability())) {
            // TODO: client return capability can not support
            ErrorReport.report(ErrorCode.ERR_NOT_SUPPORTED_AUTH_MODE);
            sendResponsePacket(context);
            return false;
        }

        // change the capability of serializer
        context.setCapability(context.getServerCapability());
        serializer.setCapability(context.getCapability());

        String qualifiedUser = parseUser(context, authPacket.getAuthResponse(), authPacket.getUser());
        if (qualifiedUser == null) {
            sendResponsePacket(context);
            return false;
        }

        //  authenticate
        if (!Env.getCurrentEnv().getAuthenticatorManager()
                .authenticate(context, qualifiedUser, channel, serializer, authPacket, handshakePacket)) {
            return false;
        }

        // try to change catalog, if default_init_catalog inside user property is not 'internal'
        try {
            String userInitCatalog = Env.getCurrentEnv().getAuth().getInitCatalog(context.getQualifiedUser());
            if (userInitCatalog != null && userInitCatalog != InternalCatalog.INTERNAL_CATALOG_NAME) {
                CatalogIf catalogIf = context.getEnv().getCatalogMgr().getCatalog(userInitCatalog);
                if (catalogIf != null) {
                    context.getEnv().changeCatalog(context, userInitCatalog);
                }
            }
        } catch (DdlException e) {
            context.getState().setError(e.getMysqlErrorCode(), e.getMessage());
            sendResponsePacket(context);
            return false;
        }

        // set database
        String db = authPacket.getDb();
        if (!Strings.isNullOrEmpty(db)) {
            String catalogName = null;
            String dbName = null;
            String[] dbNames = db.split("\\.");
            if (dbNames.length == 1) {
                dbName = db;
            } else if (dbNames.length == 2) {
                catalogName = dbNames[0];
                dbName = dbNames[1];
            } else if (dbNames.length > 2) {
                context.getState().setError(ErrorCode.ERR_BAD_DB_ERROR, "Only one dot can be in the name: " + db);
                return false;
            }

            // mysql -d
            if (Config.isCloudMode()) {
                try {
                    dbName = ((CloudEnv) Env.getCurrentEnv()).analyzeCloudCluster(dbName, context);
                } catch (DdlException e) {
                    context.getState().setError(e.getMysqlErrorCode(), e.getMessage());
                    sendResponsePacket(context);
                    return false;
                }

                if (dbName == null || dbName.isEmpty()) {
                    return true;
                }
            }

            String dbFullName = dbName;

            // check catalog and db exists
            if (catalogName != null) {
                CatalogIf catalogIf = context.getEnv().getCatalogMgr().getCatalog(catalogName);
                if (catalogIf == null) {
                    context.getState()
                            .setError(ErrorCode.ERR_BAD_DB_ERROR, ErrorCode.ERR_BAD_DB_ERROR.formatErrorMsg(db));
                    return false;
                }
                if (catalogIf.getDbNullable(dbFullName) == null) {
                    context.getState()
                            .setError(ErrorCode.ERR_BAD_DB_ERROR, ErrorCode.ERR_BAD_DB_ERROR.formatErrorMsg(db));
                    return false;
                }
            }
            try {
                if (catalogName != null) {
                    context.getEnv().changeCatalog(context, catalogName);
                }
                Env.getCurrentEnv().changeDb(context, dbFullName);
            } catch (DdlException e) {
                context.getState().setError(e.getMysqlErrorCode(), e.getMessage());
                sendResponsePacket(context);
                return false;
            }
        }

        // set resource tag if has
        context.setComputeGroup(Env.getCurrentEnv().getAuth().getComputeGroup(qualifiedUser));
        return true;
    }

    public static byte readByte(ByteBuffer buffer) {
        return buffer.get();
    }

    public static byte readByteAt(ByteBuffer buffer, int index) {
        return buffer.get(index);
    }

    public static int readInt1(ByteBuffer buffer) {
        return readByte(buffer) & 0XFF;
    }

    public static int readInt2(ByteBuffer buffer) {
        return (readByte(buffer) & 0xFF) | ((readByte(buffer) & 0xFF) << 8);
    }

    public static int readInt3(ByteBuffer buffer) {
        return (readByte(buffer) & 0xFF) | ((readByte(buffer) & 0xFF) << 8) | ((readByte(
                buffer) & 0xFF) << 16);
    }

    public static int readLowestInt4(ByteBuffer buffer) {
        return (readByteAt(buffer, 0) & 0xFF) | ((readByteAt(buffer, 1) & 0xFF) << 8) | ((readByteAt(
                buffer, 2) & 0xFF) << 16) | ((readByteAt(buffer, 3) & 0XFF) << 24);
    }

    public static int readInt4(ByteBuffer buffer) {
        return (readByte(buffer) & 0xFF) | ((readByte(buffer) & 0xFF) << 8) | ((readByte(
                buffer) & 0xFF) << 16) | ((readByte(buffer) & 0XFF) << 24);
    }

    public static long readInt6(ByteBuffer buffer) {
        return (readInt4(buffer) & 0XFFFFFFFFL) | (((long) readInt2(buffer)) << 32);
    }

    public static long readInt8(ByteBuffer buffer) {
        return (readInt4(buffer) & 0XFFFFFFFFL) | (((long) readInt4(buffer)) << 32);
    }

    public static long readVInt(ByteBuffer buffer) {
        int b = readInt1(buffer);

        if (b < 251) {
            return b;
        }
        if (b == 252) {
            return readInt2(buffer);
        }
        if (b == 253) {
            return readInt3(buffer);
        }
        if (b == 254) {
            return readInt8(buffer);
        }
        if (b == 251) {
            throw new NullPointerException();
        }
        return 0;
    }

    public static byte[] readFixedString(ByteBuffer buffer, int len) {
        byte[] buf = new byte[len];
        buffer.get(buf);
        return buf;
    }

    public static byte[] readEofString(ByteBuffer buffer) {
        byte[] buf = new byte[buffer.remaining()];
        buffer.get(buf);
        return buf;
    }

    public static byte[] readLenEncodedString(ByteBuffer buffer) {
        long length = readVInt(buffer);
        byte[] buf = new byte[(int) length];
        buffer.get(buf);
        return buf;
    }

    public static byte[] readNulTerminateString(ByteBuffer buffer) {
        int oldPos = buffer.position();
        int nullPos = oldPos;
        for (nullPos = oldPos; nullPos < buffer.limit(); ++nullPos) {
            if (buffer.get(nullPos) == 0) {
                break;
            }
        }
        byte[] buf = new byte[nullPos - oldPos];
        buffer.get(buf);
        // skip null byte.
        buffer.get();
        return buf;
    }
}