MysqlChannel.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.mysql;
import org.apache.doris.common.ConnectionException;
import org.apache.doris.common.util.NetUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.ConnectProcessor;
import com.google.common.base.Preconditions;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import org.xnio.StreamConnection;
import org.xnio.channels.Channels;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
/**
 * This class is used to read/write MySQL logical packets.
 * The MySQL protocol splits a logical packet whose payload is 16MB (2^24 - 1 bytes) or larger
 * into multiple physical packets.
 * http://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
 */
public class MysqlChannel implements BytesChannel {
    // logger for this class
    private static final Logger LOG = LogManager.getLogger(MysqlChannel.class);
    // Max payload length one MySQL physical packet can hold. A logical packet whose payload reaches
    // this length is split into several physical packets.
    public static final int MAX_PHYSICAL_PACKET_LENGTH = 0xffffff;
    // MySQL packet header length: 3-byte little-endian payload length + 1-byte sequence id.
    protected static final int PACKET_HEADER_LEN = 4;
    // SSL record header length; bytes 3-4 hold the big-endian record body length (see packetLen).
    protected static final int SSL_PACKET_HEADER_LEN = 5;
    // next sequence id to receive or send
    protected int sequenceId;
    // channel connected with client
    private StreamConnection conn;
    // reusable buffer for the MySQL packet header, so it is not allocated for every packet.
    protected ByteBuffer headerByteBuffer;
    // initial buffer used to assemble an incoming packet; replaced by a larger one when needed.
    protected ByteBuffer defaultBuffer;
    // reusable buffer for the SSL record header.
    protected ByteBuffer sslHeaderByteBuffer;
    // scratch buffer for SSL records that continue a MySQL packet started in an earlier record.
    protected ByteBuffer tempBuffer;
    // decrypted bytes left over when one SSL record carries more than one MySQL packet.
    protected ByteBuffer remainingBuffer;
    // outgoing bytes are accumulated here and written to the connection by flush().
    protected ByteBuffer sendBuffer;
    // output buffer of SSLEngine.unwrap (plaintext).
    protected ByteBuffer decryptAppData;
    // output buffer of SSLEngine.wrap (ciphertext).
    protected ByteBuffer encryptNetData;
    // for log and show
    protected String remoteHostPortString;
    protected String remoteIp;
    // set once anything has been written to the connection; cleared by reset().
    protected boolean isSend;
    protected boolean isSslMode;
    protected boolean isSslHandshaking;
    private SSLEngine sslEngine;
    protected volatile MysqlSerializer serializer = MysqlSerializer.newInstance();
    // mysql flag CLIENT_DEPRECATE_EOF
    private boolean clientDeprecatedEOF;
    // mysql flag CLIENT_MULTI_STATEMENTS
    private boolean clientMultiStatements;
    private ConnectContext context;

    protected MysqlChannel() {
        // For DummyMysqlChannel
    }

    /** Records that the client negotiated CLIENT_DEPRECATE_EOF. */
    public void setClientDeprecatedEOF() {
        clientDeprecatedEOF = true;
    }

    /** Returns whether the client negotiated CLIENT_DEPRECATE_EOF. */
    public boolean clientDeprecatedEOF() {
        return clientDeprecatedEOF;
    }

    /** Records that the client negotiated CLIENT_MULTI_STATEMENTS. */
    public void setClientMultiStatements() {
        clientMultiStatements = true;
    }

    /** Returns whether the client negotiated CLIENT_MULTI_STATEMENTS. */
    public boolean clientMultiStatements() {
        return clientMultiStatements;
    }

    /**
     * Creates a channel over an accepted client connection.
     *
     * @param connection the client connection; must not be null
     * @param context    the connection context, used for the read/write timeouts
     */
    public MysqlChannel(StreamConnection connection, ConnectContext context) {
        Preconditions.checkNotNull(connection);
        this.sequenceId = 0;
        this.isSend = false;
        this.remoteHostPortString = "";
        this.remoteIp = "";
        this.conn = connection;

        // if proxy protocol is enabled, the remote address will be got from proxy protocol header
        // and overwrite the original remote address (see setRemoteAddr).
        if (connection.getPeerAddress() instanceof InetSocketAddress) {
            InetSocketAddress address = (InetSocketAddress) connection.getPeerAddress();
            remoteHostPortString = NetUtils
                    .getHostPortInAccessibleFormat(address.getHostString(), address.getPort());
            remoteIp = address.getAddress().getHostAddress();
        } else {
            // Reach here, what's it?
            remoteHostPortString = connection.getPeerAddress().toString();
            remoteIp = connection.getPeerAddress().toString();
        }
        this.defaultBuffer = ByteBuffer.allocate(16 * 1024);
        this.headerByteBuffer = ByteBuffer.allocate(PACKET_HEADER_LEN);
        this.sendBuffer = ByteBuffer.allocate(2 * 1024 * 1024);
        this.context = context;
    }

    /**
     * Allocates the buffers needed for reading SSL records. Must be called before the channel is
     * switched into SSL handshaking or SSL mode.
     */
    public void initSslBuffer() {
        // allocate buffer when needed.
        this.remainingBuffer = ByteBuffer.allocate(16 * 1024);
        // start in "read" state with nothing to read.
        this.remainingBuffer.flip();
        this.tempBuffer = ByteBuffer.allocate(16 * 1024);
        this.sslHeaderByteBuffer = ByteBuffer.allocate(SSL_PACKET_HEADER_LEN);
    }

    public void setSequenceId(int sequenceId) {
        this.sequenceId = sequenceId;
    }

    public String getRemoteIp() {
        return remoteIp;
    }

    /**
     * Installs the SSL engine and sizes the wrap/unwrap buffers from its session.
     */
    public void setSslEngine(SSLEngine sslEngine) {
        this.sslEngine = sslEngine;
        decryptAppData = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize() * 2);
        encryptNetData = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize() * 2);
    }

    /**
     * Switches SSL mode on or off. Enabling SSL mode also ends the handshaking phase.
     */
    public void setSslMode(boolean sslMode) {
        isSslMode = sslMode;
        if (isSslMode) {
            // channel in ssl mode means handshake phase has finished.
            isSslHandshaking = false;
        }
    }

    public void setSslHandshaking(boolean sslHandshaking) {
        isSslHandshaking = sslHandshaking;
    }

    /** Returns the sequence id byte of the last MySQL header read into headerByteBuffer. */
    private int packetId() {
        byte[] header = headerByteBuffer.array();
        return header[3] & 0xFF;
    }

    /**
     * Decodes the body length from the last header read.
     *
     * @param isSslHeader true to decode the SSL record header (big-endian bytes 3-4), false to
     *                    decode the MySQL header (little-endian bytes 0-2)
     */
    private int packetLen(boolean isSslHeader) {
        if (isSslHeader) {
            byte[] header = sslHeaderByteBuffer.array();
            return (header[4] & 0xFF) | ((header[3] & 0XFF) << 8);
        } else {
            byte[] header = headerByteBuffer.array();
            return (header[0] & 0xFF) | ((header[1] & 0XFF) << 8) | ((header[2] & 0XFF) << 16);
        }
    }

    /** Advances the sequence id, wrapping from 255 back to 0. */
    private void accSequenceId() {
        sequenceId++;
        if (sequenceId > 255) {
            sequenceId = 0;
        }
    }

    /** Closes the underlying connection; an IOException is logged and ignored. */
    public void close() {
        try {
            conn.close();
        } catch (IOException e) {
            LOG.warn("Close channel exception, ignore.", e);
        }
    }

    /**
     * Fills {@code dstBuf} up to its limit. Previously decrypted leftover bytes in remainingBuffer
     * are consumed first; otherwise the connection is read with the net read timeout and, in SSL
     * mode for non-header reads, the result is decrypted in place.
     * All packet headers are not encrypted; whether a packet body is encrypted depends on the mode.
     *
     * NOTE(review): when reading from remainingBuffer this assumes it holds at least
     * dstBuf.remaining() bytes; otherwise a BufferUnderflowException is thrown.
     *
     * @return number of bytes read (raw bytes from the network, or bytes copied from
     *         remainingBuffer); 0 on an IOException
     */
    protected int readAll(ByteBuffer dstBuf, boolean isHeader) throws IOException {
        int readLen = 0;
        if (!dstBuf.hasRemaining()) {
            return 0;
        }
        if (remainingBuffer != null && remainingBuffer.hasRemaining()) {
            int oldLen = dstBuf.position();
            while (dstBuf.hasRemaining()) {
                dstBuf.put(remainingBuffer.get());
            }
            return dstBuf.position() - oldLen;
        }
        try {
            while (dstBuf.remaining() != 0) {
                int ret = Channels.readBlocking(conn.getSourceChannel(), dstBuf, context.getNetReadTimeout(),
                        TimeUnit.SECONDS);
                // return -1 when remote peer close the channel
                if (ret == -1) {
                    decryptData(dstBuf, isHeader);
                    return readLen;
                }
                readLen += ret;
            }
            decryptData(dstBuf, isHeader);
        } catch (IOException e) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Read channel exception, ignore.", e);
            }
            return 0;
        }
        return readLen;
    }

    /**
     * Reads raw bytes from the connection until {@code dstBuf} is full, with no decryption.
     *
     * @return number of bytes read; 0 if the peer closed the channel or an IOException occurred
     */
    @Override
    public int read(ByteBuffer dstBuf) {
        int readLen = 0;
        try {
            while (dstBuf.remaining() != 0) {
                int ret = Channels.readBlocking(conn.getSourceChannel(), dstBuf, context.getNetReadTimeout(),
                        TimeUnit.SECONDS);
                // return -1 when remote peer close the channel
                if (ret == -1) {
                    return 0;
                }
                readLen += ret;
            }
        } catch (IOException e) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Read channel exception, ignore.", e);
            }
            return 0;
        }
        return readLen;
    }

    /**
     * Attempts to read exactly one byte within {@code timeoutMs} milliseconds.
     *
     * @param dstBuf buffer with exactly one byte remaining
     * @return the result of the blocking read, or -1 on an IOException
     */
    @Override
    public int testReadWithTimeout(ByteBuffer dstBuf, long timeoutMs) {
        Preconditions.checkArgument(dstBuf.remaining() == 1, dstBuf.remaining());
        try {
            return Channels.readBlocking(conn.getSourceChannel(), dstBuf, timeoutMs, TimeUnit.MILLISECONDS);
        } catch (IOException e) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Read channel exception, ignore.", e);
            }
            return -1;
        }
    }

    /**
     * In SSL mode, decrypts the whole content of {@code dstBuf} (from index 0 up to its current
     * position) in place. On return the buffer is flipped and holds the plaintext. This does
     * nothing outside SSL mode or for header reads.
     */
    protected void decryptData(ByteBuffer dstBuf, boolean isHeader) throws SSLException {
        // after decrypt, we get a mysql packet with mysql header.
        if (!isSslMode || isHeader) {
            return;
        }
        dstBuf.flip();
        decryptAppData.clear();
        // unwrap will remove ssl header.
        while (true) {
            SSLEngineResult result = sslEngine.unwrap(dstBuf, decryptAppData);
            boolean completed = handleUnwrapResult(result);
            // Once the engine reports CLOSED it consumes no more input, so stop even if bytes
            // remain; otherwise this loop would spin forever.
            if (completed && (!dstBuf.hasRemaining()
                    || result.getStatus() == SSLEngineResult.Status.CLOSED)) {
                break;
            }
            // if BUFFER_OVERFLOW, the output buffer was enlarged and we unwrap again.
        }
        decryptAppData.flip();
        dstBuf.clear();
        dstBuf.put(decryptAppData);
        dstBuf.flip();
    }

    /**
     * Reads one logical MySQL protocol packet.
     * NOTE: all of the following code is assumed that the channel is in block mode.
     * In handshaking mode the returned packet includes the header; otherwise it does not.
     *
     * @return the packet, flipped for reading; null if the channel was closed or a length mismatch
     *         was detected
     * @throws IOException on a bad packet sequence id
     */
    public ByteBuffer fetchOnePacket() throws IOException {
        int readLen;
        ByteBuffer result = defaultBuffer;
        result.clear();

        while (true) {
            int packetLen;
            // one SSL packet may include multiple Mysql packets, we use remainingBuffer to store them.
            if ((isSslMode || isSslHandshaking) && !remainingBuffer.hasRemaining()) {
                if (remainingBuffer.position() != 0) {
                    remainingBuffer.clear();
                    remainingBuffer.flip();
                }
                sslHeaderByteBuffer.clear();
                readLen = readAll(sslHeaderByteBuffer, true);
                if (readLen != SSL_PACKET_HEADER_LEN) {
                    // remote has close this channel
                    if (LOG.isDebugEnabled()) {
                        LOG.debug("Receive ssl packet header failed, remote may close the channel.");
                    }
                    return null;
                }
                // when handshaking and ssl mode, sslengine unwrap need a packet with header.
                result.put(sslHeaderByteBuffer.array());
                packetLen = packetLen(true);
            } else {
                headerByteBuffer.clear();
                readLen = readAll(headerByteBuffer, true);
                if (readLen != PACKET_HEADER_LEN) {
                    // remote has close this channel
                    if (LOG.isDebugEnabled()) {
                        LOG.debug("Receive packet header failed, remote may close the channel.");
                    }
                    return null;
                }
                if (packetId() != sequenceId) {
                    LOG.warn("receive packet sequence id[" + packetId() + "] want to get[" + sequenceId + "]");
                    throw new IOException("Bad packet sequence.");
                }
                packetLen = packetLen(false);
            }
            result = expandPacket(result, packetLen);

            // read one physical packet
            // before read, set limit to make read only one packet
            result.limit(result.position() + packetLen);
            readLen = readAll(result, false);
            if (isSslMode && remainingBuffer.position() == 0 && result.hasRemaining()) {
                // result now holds decrypted bytes starting with a MySQL header.
                byte[] header = result.array();
                int packetId = header[3] & 0xFF;
                if (packetId != sequenceId) {
                    LOG.warn("receive packet sequence id[" + packetId + "] want to get[" + sequenceId + "]");
                    throw new IOException("Bad packet sequence.");
                }
                int mysqlPacketLength = (header[0] & 0xFF) | ((header[1] & 0XFF) << 8) | ((header[2] & 0XFF) << 16);
                // remove mysql packet header
                result.position(4);
                result.compact();
                // After compact(), position == number of payload bytes held and limit == capacity,
                // so the payload collected so far must be measured by position, not limit.
                // when encounter large sql query, one mysql packet will be packed as multiple ssl packets.
                // we need to read all ssl packets to combine the complete mysql packet.
                while (mysqlPacketLength > result.position()) {
                    sslHeaderByteBuffer.clear();
                    readLen = readAll(sslHeaderByteBuffer, true);
                    if (readLen != SSL_PACKET_HEADER_LEN) {
                        // remote has close this channel
                        if (LOG.isDebugEnabled()) {
                            LOG.debug("Receive ssl packet header failed, remote may close the channel.");
                        }
                        return null;
                    }
                    tempBuffer.clear();
                    tempBuffer.put(sslHeaderByteBuffer.array());
                    packetLen = packetLen(true);
                    if (LOG.isDebugEnabled()) {
                        LOG.debug("one ssl packet length is: " + packetLen);
                    }
                    tempBuffer = expandPacket(tempBuffer, packetLen);
                    result = expandPacket(result, tempBuffer.capacity());
                    // read one physical packet
                    // before read, set limit to make read only one packet
                    tempBuffer.limit(tempBuffer.position() + packetLen);
                    readLen = readAll(tempBuffer, false);
                    result.put(tempBuffer);
                    result.limit(result.position());
                    if (LOG.isDebugEnabled()) {
                        LOG.debug("result is pos: " + result.position() + ", limit: "
                                + result.limit() + ", capacity: " + result.capacity());
                    }
                }
                if (mysqlPacketLength < result.position()) {
                    // keep the bytes after this MySQL packet for the next fetch.
                    if (LOG.isDebugEnabled()) {
                        LOG.debug("one SSL packet has multiple mysql packets.");
                        LOG.debug("mysql packet length is " + mysqlPacketLength + ", result is pos: "
                                + result.position() + ", limit: " + result.limit()
                                + ", capacity: " + result.capacity());
                    }
                    result.flip();
                    result.position(mysqlPacketLength);
                    remainingBuffer.clear();
                    remainingBuffer.put(result);
                    remainingBuffer.flip();
                }
                result.position(mysqlPacketLength);
            }
            if (readLen != packetLen) {
                LOG.warn("Length of received packet content(" + readLen
                        + ") is not equal with length in head.(" + packetLen + ")");
                return null;
            }
            if (!isSslHandshaking) {
                accSequenceId();
            }
            if (packetLen != MAX_PHYSICAL_PACKET_LENGTH) {
                result.flip();
                break;
            }
        }
        return result;
    }

    /**
     * Ensures {@code result} has room for {@code packetLen} more bytes after its position. If it
     * does not, a larger buffer is allocated and the bytes before the position are copied. The
     * returned buffer's limit is set to position + packetLen.
     */
    @NotNull
    private ByteBuffer expandPacket(ByteBuffer result, int packetLen) {
        if ((result.capacity() - result.position()) < packetLen) {
            // byte buffer is not enough, new one packet
            ByteBuffer tmp;
            if (packetLen < MAX_PHYSICAL_PACKET_LENGTH) {
                // last packet, enough to this packet is OK.
                tmp = ByteBuffer.allocate(packetLen + result.position());
            } else {
                // already have packet, to allocate two packet.
                tmp = ByteBuffer.allocate(2 * packetLen + result.position());
            }
            tmp.put(result.array(), 0, result.position());
            result = tmp;
        }
        result.limit(result.position() + packetLen);
        return result;
    }

    /**
     * Writes {@code buffer} (encrypted first in SSL mode) to the connection and flushes it.
     *
     * @throws ConnectionException if fewer bytes than expected were written
     */
    protected void realNetSend(ByteBuffer buffer) throws IOException {
        buffer = encryptData(buffer);
        long bufLen = buffer.remaining();
        long start = System.currentTimeMillis();
        long writeLen = Channels.writeBlocking(conn.getSinkChannel(), buffer, context.getNetWriteTimeout(),
                TimeUnit.SECONDS);
        if (bufLen != writeLen) {
            long duration = System.currentTimeMillis() - start;
            throw new ConnectionException("Write mysql packet failed.[write=" + writeLen
                    + ", needToWrite=" + bufLen + "], duration: " + duration + " ms");
        }
        Channels.flushBlocking(conn.getSinkChannel(), context.getNetWriteTimeout(), TimeUnit.SECONDS);
        isSend = true;
    }

    /**
     * In SSL mode, encrypts the remaining bytes of {@code dstBuf} into encryptNetData and returns
     * that buffer flipped for reading. Outside SSL mode, returns {@code dstBuf} unchanged.
     */
    protected ByteBuffer encryptData(ByteBuffer dstBuf) throws SSLException {
        if (!isSslMode) {
            return dstBuf;
        }
        encryptNetData.clear();
        while (true) {
            SSLEngineResult result = sslEngine.wrap(dstBuf, encryptNetData);
            boolean completed = handleWrapResult(result);
            // After the engine reports CLOSED it consumes no more input, so stop even if bytes
            // remain; otherwise this loop would spin forever.
            if (completed && (!dstBuf.hasRemaining()
                    || result.getStatus() == SSLEngineResult.Status.CLOSED)) {
                break;
            }
        }
        encryptNetData.flip();
        return encryptNetData;
    }

    /** Sends any buffered outgoing bytes; the send buffer is cleared even if sending fails. */
    public void flush() throws IOException {
        if (null == sendBuffer || sendBuffer.position() == 0) {
            // Nothing to send
            return;
        }
        sendBuffer.flip();
        try {
            realNetSend(sendBuffer);
        } finally {
            sendBuffer.clear();
        }
        isSend = true;
    }

    /**
     * Appends a MySQL packet header: 3-byte little-endian length followed by the current
     * sequence id. {@code isSsl} is currently unused.
     */
    private void writeHeader(int length, boolean isSsl) throws IOException {
        if (null == sendBuffer) {
            return;
        }
        long leftLength = sendBuffer.capacity() - sendBuffer.position();
        if (leftLength < 4) {
            flush();
        }

        long newLen = length;
        for (int i = 0; i < 3; ++i) {
            sendBuffer.put((byte) newLen);
            newLen >>= 8;
        }
        sendBuffer.put((byte) sequenceId);
    }

    /**
     * Appends {@code buffer} to the send buffer. The send buffer is flushed first if there is not
     * enough room; a buffer larger than the whole send buffer is written directly.
     */
    private void writeBuffer(ByteBuffer buffer) throws IOException {
        if (null == sendBuffer) {
            return;
        }
        // If too long for buffer, send buffered data.
        if (sendBuffer.remaining() < buffer.remaining()) {
            // Flush data in buffer.
            flush();
        }
        // Send this buffer if large enough
        if (buffer.remaining() > sendBuffer.remaining()) {
            realNetSend(buffer);
            return;
        }
        // Put it to the send buffer
        sendBuffer.put(buffer);
    }

    /**
     * Queues one logical packet for sending. The packet is split into physical packets of at most
     * MAX_PHYSICAL_PACKET_LENGTH. Outside handshaking mode, each physical packet gets a MySQL
     * header and advances the sequence id. During the SSL handshake the bytes are written as-is:
     * handshake data already carries its SSL header and is already encrypted.
     */
    public void sendOnePacket(ByteBuffer packet) throws IOException {
        // handshake in packet with header and has encrypted, need to send in ssl format
        // ssl mode in packet no header and no encrypted, need to encrypted and add header and send in ssl format
        int bufLen;
        int oldLimit = packet.limit();
        while (oldLimit - packet.position() >= MAX_PHYSICAL_PACKET_LENGTH) {
            bufLen = MAX_PHYSICAL_PACKET_LENGTH;
            packet.limit(packet.position() + bufLen);
            if (isSslHandshaking) {
                writeBuffer(packet);
            } else {
                writeHeader(bufLen, isSslMode);
                writeBuffer(packet);
                accSequenceId();
            }
        }
        if (isSslHandshaking) {
            packet.limit(oldLimit);
            writeBuffer(packet);
        } else {
            writeHeader(oldLimit - packet.position(), isSslMode);
            packet.limit(oldLimit);
            writeBuffer(packet);
            accSequenceId();
        }
    }

    /**
     * Serializes {@code rows} as length-encoded UTF-8 strings (each value via String.valueOf)
     * and queues the result as one packet.
     */
    public void sendOnePacket(Object[] rows) throws IOException {
        ByteBuffer packet;
        serializer.reset();
        for (Object value : rows) {
            byte[] bytes = String.valueOf(value).getBytes(StandardCharsets.UTF_8);
            serializer.writeVInt(bytes.length);
            serializer.writeBytes(bytes);
        }
        packet = serializer.toByteBuffer();
        sendOnePacket(packet);
    }

    /** Queues {@code packet} and immediately flushes the send buffer. */
    public void sendAndFlush(ByteBuffer packet) throws IOException {
        sendOnePacket(packet);
        flush();
    }

    /** Clears the "sent" flag and any unsent buffered data; call this before handling a query. */
    public void reset() {
        isSend = false;
        if (null != sendBuffer) {
            sendBuffer.clear();
        }
    }

    public boolean isSend() {
        return isSend;
    }

    public String getRemoteHostPortString() {
        return remoteHostPortString;
    }

    /** Registers a ReadListener on the connection's source channel and resumes reads. */
    public void startAcceptQuery(ConnectContext connectContext, ConnectProcessor connectProcessor) {
        conn.getSourceChannel().setReadListener(new ReadListener(connectContext, connectProcessor));
        conn.getSourceChannel().resumeReads();
    }

    public void suspendAcceptQuery() {
        conn.getSourceChannel().suspendReads();
    }

    public void resumeAcceptQuery() {
        conn.getSourceChannel().resumeReads();
    }

    public void stopAcceptQuery() throws IOException {
        conn.getSourceChannel().shutdownReads();
    }

    public MysqlSerializer getSerializer() {
        return serializer;
    }

    /**
     * Interprets a wrap result.
     *
     * @return true if the wrap step finished (OK or CLOSED); false if the output buffer was
     *         enlarged after BUFFER_OVERFLOW and wrap must be retried
     * @throws IllegalStateException on BUFFER_UNDERFLOW or an unknown status
     */
    private boolean handleWrapResult(SSLEngineResult sslEngineResult) throws SSLException {
        switch (sslEngineResult.getStatus()) {
            // normal status.
            case OK:
                return true;
            case CLOSED:
                sslEngine.closeOutbound();
                return true;
            case BUFFER_OVERFLOW:
                // Could attempt to drain the encryptNetData buffer of any already obtained
                // data, but we'll just increase it to the size needed.
                ByteBuffer newBuffer = ByteBuffer.allocate(encryptNetData.capacity() * 2);
                encryptNetData.flip();
                newBuffer.put(encryptNetData);
                encryptNetData = newBuffer;
                // retry the operation.
                return false;
            // when wrap BUFFER_UNDERFLOW and other status will not appear.
            case BUFFER_UNDERFLOW:
            default:
                throw new IllegalStateException("invalid wrap status: " + sslEngineResult.getStatus());
        }
    }

    /**
     * Interprets an unwrap result.
     *
     * @return true if the unwrap step finished (OK or CLOSED); false if the output buffer was
     *         enlarged after BUFFER_OVERFLOW and unwrap must be retried
     * @throws IllegalStateException on BUFFER_UNDERFLOW or an unknown status
     */
    private boolean handleUnwrapResult(SSLEngineResult sslEngineResult) {
        switch (sslEngineResult.getStatus()) {
            // normal status.
            case OK:
                return true;
            case CLOSED:
                sslEngine.closeOutbound();
                return true;
            case BUFFER_OVERFLOW:
                // Could attempt to drain the decryptAppData buffer of any already obtained
                // data, but we'll just increase it to the size needed.
                ByteBuffer newAppBuffer = ByteBuffer.allocate(decryptAppData.capacity() * 2);
                decryptAppData.flip();
                newAppBuffer.put(decryptAppData);
                decryptAppData = newAppBuffer;
                // retry the operation.
                return false;
            case BUFFER_UNDERFLOW:
            default:
                throw new IllegalStateException("invalid unwrap status: " + sslEngineResult.getStatus());
        }
    }

    /** Overrides the remote address; for proxy protocol only. */
    public void setRemoteAddr(String ip, int port) {
        this.remoteIp = ip;
        this.remoteHostPortString = NetUtils.getHostPortInAccessibleFormat(ip, port);
    }
}