ProxyProtocolHandler.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.mysql;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
/**
 * Proxy protocol handler.
 * The proxy protocol is a simple protocol to pass client connection information to the server.
 * It is used by some load balancers and proxies to pass the client's IP address and port to the server.
 * The protocol is defined in https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
 * The protocol has two versions: V1 and V2.
 * V1 is a text-based protocol, and V2 is a binary protocol.
 * This class only supports V1; a V2 signature is recognized but rejected with an IOException.
 * A V1 header starts with the ASCII text "PROXY " and takes one of these forms:
 * <pre>
 * {@code PROXY TCP4 <srcip> <dstip> <srcport> <dstport>\r\n}
 * {@code PROXY TCP6 <srcip> <dstip> <srcport> <dstport>\r\n}
 * {@code PROXY UNKNOWN[ <ignored text>]\r\n}
 * </pre>
 * The whole V1 header, including the leading "PROXY " and the trailing "\r\n", must not exceed 107 bytes.
 */
public class ProxyProtocolHandler {
    private static final Logger LOG = LogManager.getLogger(ProxyProtocolHandler.class);
    private static final byte[] V1_HEADER = "PROXY ".getBytes(StandardCharsets.US_ASCII);
    private static final byte[] V2_HEADER
            = new byte[] {0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A};
    private static final String UNKNOWN = "UNKNOWN";
    private static final String TCP4 = "TCP4";
    private static final String TCP6 = "TCP6";
    // Upper bound on a V1 header in bytes, including the leading "PROXY " and the trailing "\r\n".
    private static final int V1_MAX_LENGTH = 107;
    // Largest valid TCP port number.
    private static final int MAX_PORT = 65535;
    // Number of decimal digits in MAX_PORT; longer port tokens are rejected before parsing.
    private static final int MAX_PORT_DIGITS = 5;

    /** Outcome of probing a connection for a proxy protocol header. */
    public enum ProtocolType {
        PROTOCOL_WITH_IP, // header parsed and it carries source/destination addresses and ports
        PROTOCOL_WITHOUT_IP, // header parsed but it carries no addresses (V1 "PROXY UNKNOWN")
        NOT_PROXY_PROTOCOL // no proxy protocol bytes were received from the peer
    }

    /**
     * Result of parsing a proxy protocol header.
     * String fields stay {@code null} and port fields stay {@code -1} when the header did not provide them.
     */
    public static class ProxyProtocolResult {
        public String sourceIP = null;
        public int sourcePort = -1;
        public String destIp = null;
        public int destPort = -1;
        public ProtocolType pType = ProtocolType.PROTOCOL_WITH_IP;

        @Override
        public String toString() {
            return "ProxyProtocolResult{"
                    + "sourceIP='" + sourceIP + '\''
                    + ", sourcePort=" + sourcePort
                    + ", destIp='" + destIp + '\''
                    + ", destPort=" + destPort
                    + ", pType=" + pType
                    + '}';
        }
    }

    /**
     * Detects and parses a proxy protocol header at the start of {@code channel}.
     *
     * <p>A single byte is probed with {@code testReadWithTimeout(buffer, 10)} (timeout unit is defined
     * by BytesChannel — NOTE(review): confirm). If no byte is available the connection is treated as
     * not using the proxy protocol. Otherwise the first byte selects V1 ('P') or V2 (0x0D) handling.
     *
     * @param channel the client channel, positioned at the first byte sent by the peer
     * @return the parse result; its {@code pType} is {@link ProtocolType#NOT_PROXY_PROTOCOL} when the
     *         probe read no bytes
     * @throws IOException if the peer closed the channel, the first byte matches neither version,
     *         the header is malformed ({@link ProtocolException}), or a V2 header is received
     */
    public static ProxyProtocolResult handle(BytesChannel channel) throws IOException {
        // First read 1 byte to see if it is V1 or V2
        ByteBuffer buffer = ByteBuffer.allocate(1);
        int readLen = channel.testReadWithTimeout(buffer, 10);
        if (readLen == -1) {
            throw new IOException("Remote peer closed the channel, ignore.");
        } else if (readLen == 0) {
            // 0 means remote peer does not send proxy protocol content.
            ProxyProtocolResult result = new ProxyProtocolResult();
            result.pType = ProtocolType.NOT_PROXY_PROTOCOL;
            return result;
        } else if (readLen != 1) {
            throw new IOException("Invalid proxy protocol, expect incoming bytes first");
        }
        buffer.flip();
        byte firstByte = buffer.get();
        if (firstByte == V1_HEADER[0]) {
            return handleV1(channel);
        } else if (firstByte == V2_HEADER[0]) {
            return handleV2(channel);
        } else {
            throw new IOException("Invalid proxy protocol header in first bytes: " + firstByte + ".");
        }
    }

    /**
     * Parses the remainder of a V1 header after its first byte ('P') has been consumed.
     *
     * @param channel channel positioned just after the leading 'P'
     * @return the source/destination addresses and ports for TCP4/TCP6, or a result whose {@code pType}
     *         is {@link ProtocolType#PROTOCOL_WITHOUT_IP} for "PROXY UNKNOWN"
     * @throws IOException if the header does not start with "PROXY ", is malformed, reaches
     *         {@value #V1_MAX_LENGTH} bytes without a terminating "\r\n", or the stream ends early
     */
    private static ProxyProtocolResult handleV1(BytesChannel channel) throws IOException {
        ProxyProtocolResult result = new ProxyProtocolResult();
        int byteCount = 1; // the leading 'P' has already been read
        boolean parsingUnknown = false; // true once "UNKNOWN" is seen: ignore everything up to "\r\n"
        boolean carriageFound = false; // true once the terminating '\r' has been seen
        String protocol = null;
        StringBuilder token = new StringBuilder();

        // Read the remaining bytes of "PROXY " and verify that they really are "ROXY ".
        int restLen = V1_HEADER.length - 1;
        ByteBuffer buffer = ByteBuffer.allocate(restLen);
        int readLen = channel.read(buffer);
        if (readLen != restLen) {
            throw new IOException("Invalid proxy protocol v1, expected \"PROXY \"");
        }
        byteCount += readLen;
        buffer.flip();
        StringBuilder debugInfo = new StringBuilder().append((char) V1_HEADER[0]);
        for (int i = 1; i < V1_HEADER.length; i++) {
            byte b = buffer.get();
            debugInfo.append(toChar(b));
            if (b != V1_HEADER[i]) {
                throw new ProtocolException("Invalid proxy protocol v1, expected \"PROXY \"",
                        debugInfo.toString());
            }
        }

        // Parse the rest of the line one byte at a time until "\r\n".
        buffer = ByteBuffer.allocate(1);
        while (true) {
            buffer.clear();
            channel.read(buffer);
            buffer.flip();
            if (!buffer.hasRemaining()) {
                throw new ProtocolException("Invalid proxy protocol v1, unexpected end of stream",
                        debugInfo.toString());
            }
            char c = toChar(buffer.get());
            debugInfo.append(c);
            if (parsingUnknown) {
                // Found "PROXY UNKNOWN": ignore any other bytes until "\r\n".
                if (c == '\r') {
                    carriageFound = true;
                } else if (c == '\n') {
                    if (!carriageFound) {
                        throw new ProtocolException("Invalid proxy protocol v1. '\\r' is not found before '\\n'",
                                debugInfo.toString());
                    }
                    result.pType = ProtocolType.PROTOCOL_WITHOUT_IP;
                    return result;
                } else if (carriageFound) {
                    throw new ProtocolException("Invalid proxy protocol v1. "
                            + "'\\r' should follow with '\\n', but see: " + c + ".", debugInfo.toString());
                }
            } else if (carriageFound) {
                // '\r' was seen after a complete TCP4/TCP6 line; only '\n' may follow.
                if (c != '\n') {
                    throw new ProtocolException("Invalid proxy protocol v1. "
                            + "'\\r' should follow with '\\n', but see: " + c + ".", debugInfo.toString());
                }
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Finish parsing proxy protocol v1. result: {}", result);
                }
                return result;
            } else {
                switch (c) {
                    case ' ':
                        // A space terminates the current field: protocol, src ip, dst ip, then src port.
                        if (result.sourcePort != -1 || token.length() == 0) {
                            throw new ProtocolException("Invalid proxy protocol v1. expecting a '\\r' or a '\\n'",
                                    debugInfo.toString());
                        } else if (protocol == null) {
                            protocol = token.toString();
                            if (protocol.equals(UNKNOWN)) {
                                parsingUnknown = true;
                            } else if (!protocol.equals(TCP4) && !protocol.equals(TCP6)) {
                                throw new ProtocolException("Invalid proxy protocol v1. expecting TCP4/TCP6/UNKNOWN."
                                        + " See: " + protocol + ".", debugInfo.toString());
                            }
                        } else if (result.sourceIP == null) {
                            result.sourceIP = token.toString();
                        } else if (result.destIp == null) {
                            result.destIp = token.toString();
                        } else {
                            result.sourcePort = parsePort(token.toString(), debugInfo);
                        }
                        token.setLength(0);
                        break;
                    case '\r':
                        // '\r' is only legal after the destination port, or directly after "UNKNOWN".
                        if (result.destPort == -1 && result.sourcePort != -1 && token.length() > 0) {
                            result.destPort = parsePort(token.toString(), debugInfo);
                            token.setLength(0);
                            carriageFound = true;
                        } else if (protocol == null && UNKNOWN.equals(token.toString())) {
                            parsingUnknown = true;
                            carriageFound = true;
                        } else {
                            throw new ProtocolException(
                                    "Invalid proxy protocol v1. Already see '\\r' but no valid info",
                                    debugInfo.toString());
                        }
                        break;
                    case '\n':
                        throw new ProtocolException("Invalid proxy protocol v1. '\\r' is not found before '\\n'",
                                debugInfo.toString());
                    default:
                        token.append(c);
                }
            }
            byteCount++;
            if (byteCount >= V1_MAX_LENGTH) {
                throw new ProtocolException("Invalid proxy protocol v1, max length(" + V1_MAX_LENGTH + ") exceeds",
                        debugInfo.toString());
            }
        }
    }

    /**
     * Parses a V1 port field, which must be 1 to 5 ASCII decimal digits with a value of at most 65535.
     * Invalid input is reported as a {@link ProtocolException} rather than a NumberFormatException, so
     * callers only have to handle IOException.
     *
     * @param value the raw port token
     * @param debugInfo the header bytes seen so far, appended to the error message
     * @return the port number
     * @throws ProtocolException if {@code value} is not a valid port
     */
    private static int parsePort(String value, StringBuilder debugInfo) throws ProtocolException {
        boolean wellFormed = !value.isEmpty() && value.length() <= MAX_PORT_DIGITS;
        for (int i = 0; wellFormed && i < value.length(); i++) {
            char ch = value.charAt(i);
            wellFormed = ch >= '0' && ch <= '9';
        }
        // At most MAX_PORT_DIGITS plain digits, so parseInt can neither fail nor overflow here.
        int port = wellFormed ? Integer.parseInt(value) : -1;
        if (port < 0 || port > MAX_PORT) {
            throw new ProtocolException("Invalid proxy protocol v1. invalid port: " + value,
                    debugInfo.toString());
        }
        return port;
    }

    /** Maps a raw byte to the char with the same unsigned value, avoiding sign extension of bytes >= 0x80. */
    private static char toChar(byte b) {
        return (char) (b & 0xFF);
    }

    /**
     * Placeholder for V2 parsing. It always throws, because V2 is not supported yet.
     *
     * @param channel channel positioned just after the first V2 signature byte (unused)
     * @throws IOException always
     */
    private static ProxyProtocolResult handleV2(BytesChannel channel) throws IOException {
        throw new IOException("proxy protocol v2 is not supported yet");
    }

    /**
     * Thrown for a malformed proxy protocol header.
     * The message is suffixed with ": " and the given protocol text (the header bytes seen so far).
     */
    public static class ProtocolException extends IOException {
        public ProtocolException(String message, String protocolStr) {
            super(message + ": " + protocolStr);
        }
    }
}