AuthenticationIntegrationRuntime.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.authentication;

import org.apache.doris.authentication.handler.AuthenticationOutcome;
import org.apache.doris.authentication.handler.AuthenticationPluginManager;
import org.apache.doris.authentication.spi.AuthenticationPlugin;
import org.apache.doris.common.Config;

import com.google.common.base.Strings;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.Closeable;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Runtime manager for AUTHENTICATION INTEGRATION.
 *
 * <p>Delegates plugin creation, installation, lookup and removal to an
 * {@link AuthenticationPluginManager} and records, per integration name, a {@link RuntimeState}
 * together with the message of the failure that marked the integration as broken.
 */
public class AuthenticationIntegrationRuntime {
    private static final Logger LOG = LogManager.getLogger(AuthenticationIntegrationRuntime.class);

    /** State recorded for an integration by this runtime. */
    public enum RuntimeState {
        /** Set after the integration's plugin has been installed into the plugin manager. */
        AVAILABLE,
        /** Set when preparing the integration during replay failed with an AuthenticationException. */
        BROKEN
    }

    /**
     * An integration paired with a plugin instance that has been created but not necessarily
     * installed yet. Closing it closes the plugin.
     */
    public static final class PreparedAuthenticationIntegration implements Closeable {
        private final AuthenticationIntegration integration;
        private final AuthenticationPlugin plugin;

        private PreparedAuthenticationIntegration(AuthenticationIntegration integration, AuthenticationPlugin plugin) {
            this.integration = Objects.requireNonNull(integration, "integration");
            this.plugin = Objects.requireNonNull(plugin, "plugin");
        }

        /** Returns the integration this preparation was built from. */
        public AuthenticationIntegration getIntegration() {
            return integration;
        }

        /** Returns the plugin instance created for the integration. */
        public AuthenticationPlugin getPlugin() {
            return plugin;
        }

        /** Closes the underlying plugin. */
        @Override
        public void close() throws IOException {
            plugin.close();
        }
    }

    private final AuthenticationPluginManager pluginManager;
    // Keyed by integration name. The two maps are updated one after the other, not atomically.
    private final Map<String, RuntimeState> runtimeStates = new ConcurrentHashMap<>();
    private final Map<String, String> brokenReasons = new ConcurrentHashMap<>();

    /** Creates a runtime backed by a new {@link AuthenticationPluginManager}. */
    public AuthenticationIntegrationRuntime() {
        this(new AuthenticationPluginManager());
    }

    /**
     * Creates a runtime backed by the given plugin manager.
     *
     * @param pluginManager plugin manager to delegate to; must not be null
     */
    public AuthenticationIntegrationRuntime(AuthenticationPluginManager pluginManager) {
        this.pluginManager = Objects.requireNonNull(pluginManager, "pluginManager");
    }

    /**
     * Builds the integration described by {@code meta}, makes sure a plugin factory for its type is
     * loaded, and creates a plugin for it. The plugin is not installed by this method; pass the result
     * to {@link #activatePreparedAuthenticationIntegration} or
     * {@link #discardPreparedAuthenticationIntegration}.
     *
     * @param meta integration metadata
     * @return the integration together with its newly created plugin
     * @throws AuthenticationException if no plugin factory is available for the type or plugin
     *         creation fails
     */
    public PreparedAuthenticationIntegration prepareAuthenticationIntegration(AuthenticationIntegrationMeta meta)
            throws AuthenticationException {
        AuthenticationIntegration integration = toIntegration(meta);
        ensurePluginFactoryLoaded(integration.getType());
        AuthenticationPlugin plugin = pluginManager.createPlugin(integration);
        return new PreparedAuthenticationIntegration(integration, plugin);
    }

    /**
     * Installs a prepared plugin into the plugin manager, marks the integration
     * {@link RuntimeState#AVAILABLE} and clears any recorded broken reason.
     *
     * @param prepared result of {@link #prepareAuthenticationIntegration}; must not be null
     */
    public void activatePreparedAuthenticationIntegration(PreparedAuthenticationIntegration prepared) {
        Objects.requireNonNull(prepared, "prepared");
        AuthenticationIntegration integration = prepared.getIntegration();
        pluginManager.installPlugin(integration, prepared.getPlugin());
        runtimeStates.put(integration.getName(), RuntimeState.AVAILABLE);
        brokenReasons.remove(integration.getName());
    }

    /**
     * Closes a prepared integration that will not be activated. A {@code null} argument is ignored.
     * Closing is best-effort: an {@link IOException} is logged and not rethrown.
     *
     * @param prepared prepared integration to close, may be null
     */
    public void discardPreparedAuthenticationIntegration(PreparedAuthenticationIntegration prepared) {
        if (prepared == null) {
            return;
        }
        try {
            prepared.close();
        } catch (IOException e) {
            // Not expected from AuthenticationPlugin.close(); the catch exists to satisfy Closeable.
            // Keep the discard best-effort, but leave a trace instead of dropping the error.
            LOG.warn("Failed to close discarded authentication integration '{}'",
                    prepared.getIntegration().getName(), e);
        }
    }

    /**
     * Removes the plugin of the named integration from the plugin manager and forgets its runtime
     * state and broken reason.
     *
     * @param integrationName name of the integration to remove
     */
    public void removeAuthenticationIntegration(String integrationName) {
        pluginManager.removePlugin(integrationName);
        runtimeStates.remove(integrationName);
        brokenReasons.remove(integrationName);
    }

    /**
     * Replaces the runtime plugin of an integration with one built from {@code meta}.
     *
     * <p>The existing plugin is removed first. If preparing the new one fails with an
     * {@link AuthenticationException}, the integration is marked {@link RuntimeState#BROKEN} and the
     * exception is not propagated. A plugin that was created but not activated is always closed.
     *
     * @param meta integration metadata to apply
     */
    public void replayUpsertAuthenticationIntegration(AuthenticationIntegrationMeta meta) {
        String integrationName = meta.getName();
        pluginManager.removePlugin(integrationName);
        PreparedAuthenticationIntegration prepared = null;
        // Track activation locally. Consulting runtimeStates here would be wrong: a stale AVAILABLE
        // entry left by an earlier upsert would hide a failed activation and leak the new plugin.
        boolean activated = false;
        try {
            prepared = prepareAuthenticationIntegration(meta);
            activatePreparedAuthenticationIntegration(prepared);
            activated = true;
        } catch (AuthenticationException e) {
            markBroken(integrationName, e);
        } finally {
            if (prepared != null && !activated) {
                discardPreparedAuthenticationIntegration(prepared);
            }
        }
    }

    /**
     * Drops all cached plugins and recorded states, then replays every integration in
     * {@code snapshot} through {@link #replayUpsertAuthenticationIntegration}.
     *
     * @param snapshot integrations to rebuild; only the map's values are used
     */
    public void rebuildAuthenticationIntegrations(Map<String, AuthenticationIntegrationMeta> snapshot) {
        pluginManager.clearCache();
        runtimeStates.clear();
        brokenReasons.clear();
        for (AuthenticationIntegrationMeta meta : snapshot.values()) {
            replayUpsertAuthenticationIntegration(meta);
        }
    }

    /**
     * Runs {@code request} through the integrations of {@code chain} in order.
     *
     * <p>Integrations whose plugin does not support the request are skipped. The first non-failure
     * outcome is returned immediately. A failure is returned immediately when its failure type does
     * not allow continuing in the chain; otherwise the next integration is tried and, if the chain
     * is exhausted, the last failure is returned.
     *
     * @param chain ordered integrations to try; must not be null or empty
     * @param request authentication request; must not be null
     * @return the outcome of the deciding integration
     * @throws AuthenticationException if the chain is empty, if a plugin factory for an integration's
     *         type cannot be loaded, or if no integration produced an outcome
     */
    public AuthenticationOutcome authenticate(List<AuthenticationIntegrationMeta> chain, AuthenticationRequest request)
            throws AuthenticationException {
        Objects.requireNonNull(chain, "chain");
        Objects.requireNonNull(request, "request");
        if (chain.isEmpty()) {
            throw new AuthenticationException(
                    "authentication chain is empty",
                    AuthenticationFailureType.MISCONFIGURED);
        }

        AuthenticationOutcome lastFailure = null;
        boolean anySupported = false;
        for (AuthenticationIntegrationMeta meta : chain) {
            AuthenticationIntegration integration = toIntegration(meta);
            // A missing plugin factory is not converted into a chain failure: it propagates and
            // aborts the whole authentication attempt.
            ensurePluginFactoryLoaded(integration.getType());
            AuthenticationPlugin plugin;
            try {
                plugin = pluginManager.getPlugin(integration);
            } catch (AuthenticationException e) {
                // Failing to obtain the plugin counts as a failure of this integration.
                AuthenticationResult result = AuthenticationResult.failure(e);
                AuthenticationOutcome outcome = AuthenticationOutcome.of(integration, result);
                lastFailure = outcome;
                if (!shouldContinueInChain(result)) {
                    return outcome;
                }
                continue;
            }
            if (!plugin.supports(request)) {
                continue;
            }
            anySupported = true;

            AuthenticationResult result;
            try {
                result = plugin.authenticate(request, integration);
            } catch (AuthenticationException e) {
                result = AuthenticationResult.failure(e);
            }

            AuthenticationOutcome outcome = AuthenticationOutcome.of(integration, result);
            if (!outcome.isFailure()) {
                return outcome;
            }
            lastFailure = outcome;
            if (!shouldContinueInChain(result)) {
                return outcome;
            }
        }

        if (lastFailure != null) {
            return lastFailure;
        }
        if (!anySupported) {
            throw new AuthenticationException(
                    "No authentication integration supports request for user: " + request.getUsername(),
                    AuthenticationFailureType.MISCONFIGURED);
        }
        // Defensive fallback: a supporting plugin either returned above or recorded lastFailure.
        throw new AuthenticationException(
                "Authentication failed for user: " + request.getUsername(),
                AuthenticationFailureType.ACCESS_DENIED);
    }

    /**
     * Returns the state recorded for the named integration.
     *
     * @param integrationName integration name
     * @return the recorded state, or {@code null} if none is recorded
     */
    public RuntimeState getRuntimeState(String integrationName) {
        return runtimeStates.get(integrationName);
    }

    /**
     * Returns the message recorded when the named integration was marked broken.
     *
     * @param integrationName integration name
     * @return the recorded message, or {@code null} if none is recorded
     */
    public String getBrokenReason(String integrationName) {
        return brokenReasons.get(integrationName);
    }

    /**
     * Ensures the plugin manager has a factory for {@code pluginType}, loading plugins from
     * {@code Config.authentication_plugins_dir} if it does not.
     *
     * @throws AuthenticationException with type MISCONFIGURED if loading fails or the factory is
     *         still missing afterwards
     */
    private void ensurePluginFactoryLoaded(String pluginType) throws AuthenticationException {
        if (pluginManager.hasFactory(pluginType)) {
            return;
        }

        try {
            Path pluginRoot = Paths.get(Config.authentication_plugins_dir);
            pluginManager.loadAll(Collections.singletonList(pluginRoot), getClass().getClassLoader());
        } catch (AuthenticationException e) {
            throw new AuthenticationException(
                    "Failed to load authentication plugins for type '" + pluginType + "': " + e.getMessage(),
                    e,
                    AuthenticationFailureType.MISCONFIGURED);
        }

        // Loading succeeded but may not have provided a factory for this particular type.
        if (!pluginManager.hasFactory(pluginType)) {
            throw new AuthenticationException(
                    "No authentication plugin factory found for type: " + pluginType,
                    AuthenticationFailureType.MISCONFIGURED);
        }
    }

    /** Records the integration as BROKEN with the exception message as reason and logs a warning. */
    private void markBroken(String integrationName, AuthenticationException exception) {
        runtimeStates.put(integrationName, RuntimeState.BROKEN);
        // ConcurrentHashMap rejects null values, so a missing message is stored as "".
        brokenReasons.put(integrationName, Strings.nullToEmpty(exception.getMessage()));
        LOG.warn("Authentication integration '{}' is broken: {}", integrationName, exception.getMessage(), exception);
    }

    /**
     * Returns whether the chain should move on to the next integration: only for a failure whose
     * exception is present and whose failure type allows continuing.
     */
    private static boolean shouldContinueInChain(AuthenticationResult result) {
        if (!result.isFailure()) {
            return false;
        }
        AuthenticationException exception = result.getException();
        return exception != null && exception.getFailureType().shouldContinueInChain();
    }

    /** Copies name, type, properties and comment from the metadata into an AuthenticationIntegration. */
    private static AuthenticationIntegration toIntegration(AuthenticationIntegrationMeta meta) {
        return AuthenticationIntegration.builder()
                .name(meta.getName())
                .type(meta.getType())
                .properties(meta.getProperties())
                .comment(meta.getComment())
                .build();
    }
}