FlightRemoteIpServerStreamTracer.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.service.arrowflight.auth2;

import io.grpc.Attributes;
import io.grpc.Context;
import io.grpc.Grpc;
import io.grpc.Metadata;
import io.grpc.ServerStreamTracer;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;

/**
 * Captures the gRPC peer address before Arrow Flight header authentication runs.
 * Arrow registers header authentication ahead of user interceptors, so use ServerStreamTracer to
 * seed the remote IP into the gRPC Context for Basic credential validation.
 */
public class FlightRemoteIpServerStreamTracer extends ServerStreamTracer {
    static final String UNKNOWN_REMOTE_IP = "0.0.0.0";
    private static final Context.Key<RemoteIpHolder> REMOTE_IP_CONTEXT_KEY =
            Context.key("doris.arrow.flight.remote_ip");

    @Override
    public Context filterContext(Context context) {
        return context.withValue(REMOTE_IP_CONTEXT_KEY, new RemoteIpHolder());
    }

    @Override
    public void serverCallStarted(ServerCallInfo<?, ?> callInfo) {
        RemoteIpHolder holder = REMOTE_IP_CONTEXT_KEY.get();
        if (holder == null) {
            return;
        }

        Attributes attributes = callInfo.getAttributes();
        SocketAddress remoteAddress = attributes == null ? null : attributes.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR);
        holder.setRemoteIp(extractRemoteIp(remoteAddress));
    }

    public static String getRemoteIp() {
        RemoteIpHolder holder = REMOTE_IP_CONTEXT_KEY.get();
        if (holder == null) {
            return UNKNOWN_REMOTE_IP;
        }
        return holder.getRemoteIp();
    }

    static String extractRemoteIp(SocketAddress remoteAddress) {
        if (!(remoteAddress instanceof InetSocketAddress)) {
            return UNKNOWN_REMOTE_IP;
        }

        InetSocketAddress inetSocketAddress = (InetSocketAddress) remoteAddress;
        InetAddress address = inetSocketAddress.getAddress();
        if (address != null && isNotEmpty(address.getHostAddress())) {
            return address.getHostAddress();
        }
        if (isNotEmpty(inetSocketAddress.getHostString())) {
            return inetSocketAddress.getHostString();
        }
        return UNKNOWN_REMOTE_IP;
    }

    private static boolean isNotEmpty(String value) {
        return value != null && !value.isEmpty();
    }

    public static class Factory extends ServerStreamTracer.Factory {
        @Override
        public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) {
            return new FlightRemoteIpServerStreamTracer();
        }
    }

    private static class RemoteIpHolder {
        private volatile String remoteIp = UNKNOWN_REMOTE_IP;

        String getRemoteIp() {
            return remoteIp;
        }

        void setRemoteIp(String remoteIp) {
            this.remoteIp = isNotEmpty(remoteIp) ? remoteIp : UNKNOWN_REMOTE_IP;
        }
    }
}