LLMResource.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.catalog;

import org.apache.doris.common.DdlException;
import org.apache.doris.common.proc.BaseProcResult;
import org.apache.doris.datasource.property.constants.LLMProperties;
import org.apache.doris.thrift.TLLMResource;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.gson.annotations.SerializedName;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

/**
 * LLM Resource
 * <p>
 * Syntax:
 * CREATE RESOURCE "deepseek-chat"
 * PROPERTIES
 * (
 * 'type' = 'llm',
 * 'llm.provider_type' = 'deepseek',
 * 'llm.endpoint' = 'https://api.deepseek.com/chat/completions',
 * 'llm.model_name' = 'deepseek-chat',
 * 'llm.api_key' = 'sk-xxx',
 * 'llm.temperature' = '0.7',
 * 'llm.max_token' = '1024',
 * 'llm.max_retries' = '3',
 * 'llm.retry_delay_second' = '1'
 * );
 * <p>
 */

public class LLMResource extends Resource {
    private static final Logger LOG = LogManager.getLogger(LLMResource.class);
    @SerializedName(value = "properties")
    private Map<String, String> properties;

    public LLMResource() {
        super();
    }

    public LLMResource(String name) {
        super(name, ResourceType.LLM);
        properties = Maps.newHashMap();
    }

    @Override
    protected void setProperties(ImmutableMap<String, String> newProperties) throws DdlException {
        Preconditions.checkState(newProperties != null);
        this.properties = Maps.newHashMap(newProperties);

        LLMProperties.requiredLLMProperties(properties);

        boolean needCheck = isNeedCheck(properties);
        if (LOG.isDebugEnabled()) {
            LOG.debug("LLM resource need check validity: {}", needCheck);
        }
        if (needCheck) {
            pingLLM(properties);
        }

        LLMProperties.optionalLLMProperties(this.properties);
    }

    protected static void pingLLM(Map<String, String> properties) throws DdlException {
        try {
            HttpURLConnection connection = getHttpURLConnection(properties);

            int responseCode = connection.getResponseCode();

            if (responseCode == HttpURLConnection.HTTP_OK) {
                LOG.info("Successfully connected to LLM API at {}", properties.get(LLMProperties.ENDPOINT));
            } else {
                StringBuilder response = new StringBuilder();
                try (BufferedReader br = new BufferedReader(
                        new InputStreamReader(connection.getErrorStream(), StandardCharsets.UTF_8))) {
                    String responseLine;
                    while ((responseLine = br.readLine()) != null) {
                        response.append(responseLine.trim());
                    }
                }

                throw new DdlException("Failed to connect to LLM API: HTTP " + responseCode
                        + ". Response: " + response);
            }

        } catch (IOException e) {
            throw new DdlException("Failed to connect to LLM API: " + e.getMessage());
        }
    }

    private static HttpURLConnection getHttpURLConnection(Map<String, String> properties) throws IOException {
        String endpoint = properties.get(LLMProperties.ENDPOINT);
        String providerType = properties.get(LLMProperties.PROVIDER_TYPE).toLowerCase();
        String modelName = properties.get(LLMProperties.MODEL_NAME);
        String apiKey = properties.get(LLMProperties.API_KEY);
        String anthropicVersion = properties.get(LLMProperties.ANTHROPIC_VERSION);

        URL url = new URL(endpoint);
        HttpURLConnection connection = (HttpURLConnection) url.openConnection();
        connection.setRequestMethod("POST");
        connection.setRequestProperty("Content-Type", "application/json");

        if ("gemini".equalsIgnoreCase(providerType)) {
            connection.setRequestProperty("x-goog-api-key", apiKey);
            connection.setRequestProperty("Content-Type", "application/json");
        } else if ("anthropic".equalsIgnoreCase(providerType)) {
            connection.setRequestProperty("x-api-key", apiKey);
            connection.setRequestProperty("anthropic-version", anthropicVersion);
            connection.setRequestProperty("Content-Type", "application/json");
        } else if (!"local".equalsIgnoreCase(providerType)) {
            connection.setRequestProperty("Content-Type", "application/json");
            connection.setRequestProperty("Authorization", "Bearer " + apiKey);
        }

        connection.setDoOutput(true);
        connection.setConnectTimeout(10000);
        connection.setReadTimeout(10000);

        String testPrompt = buildTestPrompt(providerType, modelName);

        try (OutputStream os = connection.getOutputStream()) {
            byte[] input = testPrompt.getBytes(StandardCharsets.UTF_8);
            os.write(input, 0, input.length);
        }
        return connection;
    }

    private static String buildTestPrompt(String providerType, String modelName) {
        switch (providerType) {
            case "openai":
            case "deepseek":
            case "moonshot":
            case "zhipu":
            case "qwen":
            case "minimax":
            case "local":
                return "{\"model\":\"" + modelName + "\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello\"}],"
                        + "\"max_tokens\":5}";
            case "anthropic":
                return "{\"model\":\"" + modelName + ",\"max_tokens\":5,\"messages\":[{\"role\":\"user\",\"content\":"
                        + "[{\"type\":\"text\",\"text\":\"Hello\"}]}]}";
            case "gemini":
                return "{\"contents\":[{\"parts\":[{\"text\":\"Hello\"}]}],\"generationConfig\":"
                    + "{\"maxOutputTokens\":5}}";
            default:
                return "{\"messages\":[{\"role\":\"user\",\"content\":\"Hello\"}],"
                        + "\"max_tokens\":5}";
        }
    }

    public String getProperty(String propertyKey) {
        return properties.get(propertyKey);
    }

    private boolean isNeedCheck(Map<String, String> newProperties) {
        boolean needCheck = !this.properties.containsKey(LLMProperties.VALIDITY_CHECK)
                || Boolean.parseBoolean(this.properties.get(LLMProperties.VALIDITY_CHECK));

        if (newProperties != null && newProperties.containsKey(LLMProperties.VALIDITY_CHECK)) {
            needCheck = Boolean.parseBoolean(newProperties.get(LLMProperties.VALIDITY_CHECK));
        }

        if ("LOCAL".equalsIgnoreCase(this.properties.getOrDefault(LLMProperties.PROVIDER_TYPE, ""))) {
            needCheck = false;
        }
        return needCheck;
    }

    @Override
    public void modifyProperties(Map<String, String> properties) throws DdlException {
        boolean needCheck = isNeedCheck(properties);
        if (LOG.isDebugEnabled()) {
            LOG.debug("LLM resource need check validity: {}", needCheck);
        }

        if (needCheck) {
            Map<String, String> changedProperties = new HashMap<>(this.properties);
            changedProperties.putAll(properties);
            LLMProperties.requiredLLMProperties(changedProperties);
            pingLLM(changedProperties);
        }

        // modify properties
        writeLock();
        for (Map.Entry<String, String> kv : properties.entrySet()) {
            replaceIfEffectiveValue(this.properties, kv.getKey(), kv.getValue());
            if (kv.getKey().equals(LLMProperties.API_KEY)) {
                this.properties.put(kv.getKey(), kv.getValue());
            }
        }
        ++version;
        writeUnlock();
        super.modifyProperties(properties);
    }

    @Override
    public Map<String, String> getCopiedProperties() {
        return Maps.newHashMap(properties);
    }

    @Override
    protected void getProcNodeData(BaseProcResult result) {
        String lowerCaseType = type.name().toLowerCase();
        result.addRow(Lists.newArrayList(name, lowerCaseType, "id", String.valueOf(id)));
        readLock();
        result.addRow(Lists.newArrayList(name, lowerCaseType, "version", String.valueOf(version)));
        for (Map.Entry<String, String> entry : properties.entrySet()) {
            if (entry.getKey().equals(LLMProperties.API_KEY)) {
                result.addRow(Lists.newArrayList(name, lowerCaseType, entry.getKey(), "******"));
            } else {
                result.addRow(Lists.newArrayList(name, lowerCaseType, entry.getKey(), entry.getValue()));
            }
        }
        readUnlock();
    }

    public TLLMResource toThrift() throws NumberFormatException {
        TLLMResource tLLMResource = new TLLMResource();
        tLLMResource.setProviderType(properties.get(LLMProperties.PROVIDER_TYPE));
        tLLMResource.setEndpoint(properties.get(LLMProperties.ENDPOINT));
        tLLMResource.setApiKey(properties.get(LLMProperties.API_KEY));
        tLLMResource.setModelName(properties.get(LLMProperties.MODEL_NAME));
        tLLMResource.setAnthropicVersion(properties.get(LLMProperties.ANTHROPIC_VERSION));

        try {
            tLLMResource.setTemperature(Double.parseDouble(properties.get(LLMProperties.TEMPERATURE)));
        } catch (NumberFormatException e) {
            throw new NumberFormatException("Failed to parse temperature: "
                                            + properties.get(LLMProperties.TEMPERATURE));
        }
        try {
            tLLMResource.setMaxTokens(Long.parseLong(properties.get(LLMProperties.MAX_TOKEN)));
        } catch (NumberFormatException e) {
            throw new NumberFormatException("Failed to parse max_token: "
                                            + properties.get(LLMProperties.MAX_TOKEN));
        }
        try {
            tLLMResource.setMaxRetries(Long.parseLong(properties.get(LLMProperties.MAX_RETRIES)));
        } catch (NumberFormatException e) {
            throw new NumberFormatException("Failed to parse max_retries: "
                                            + properties.get(LLMProperties.MAX_RETRIES));
        }
        try {
            tLLMResource.setRetryDelaySecond(Long.parseLong(properties.get(LLMProperties.RETRY_DELAY_SECOND)));
        } catch (NumberFormatException e) {
            throw new NumberFormatException("Failed to parse retry_delay_second: "
                                            + properties.get(LLMProperties.RETRY_DELAY_SECOND));
        }

        return tLLMResource;
    }
}