/*
 * Decompiled with CFR 0.152.
 */
package com.ontotext.graphdb.gpt;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.google.common.annotations.VisibleForTesting;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import com.knuddels.jtokkit.api.EncodingType;
import com.knuddels.jtokkit.api.ModelType;
import com.ontotext.graphdb.Config;
import com.ontotext.graphdb.configs.LLMConfig;
import com.ontotext.graphdb.configs.SystemConfig;
import com.ontotext.graphdb.gpt.ChatModelFactory;
import com.ontotext.graphdb.gpt.GptClient;
import com.ontotext.graphdb.gpt.quota.GraphDBTokenQuotaManager;
import com.ontotext.graphdb.gpt.quota.QuotaUtils;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.exception.AuthenticationException;
import dev.langchain4j.model.chat.ChatModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.TokenUsage;
import java.io.IOException;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.StringUtils;
import org.eclipse.rdf4j.query.QueryEvaluationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GptDefaultClient
implements GptClient {
    public static final String UNSUPPORTED_LLM_MODEL_TYPE_WARNING = "Unsupported LLM model type: '%s'. Using default values";
    public static final String GPT_MODEL_PROPERTY = "graphdb.gpt-sparql.model";
    public static final String GPT_API_VERSION = "graphdb.gpt-sparql.api-version";
    public static final String LEGACY_GPT_MODEL_PROPERTY = "graphdb.gpt.model";
    private final Logger logger = LoggerFactory.getLogger(GptDefaultClient.class);
    private final ObjectMapper objectMapper = new ObjectMapper();
    private final String llmModel = this.getLLMFunctionsModel();
    private int llmModelMaxTokens;
    private Encoding gptEncoding;
    private RequestExecutor requestExecutor;

    @VisibleForTesting
    GptDefaultClient() {
        EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
        try {
            ModelType modelType = StringUtils.isNotEmpty((CharSequence)LLMConfig.getLLMUrl()) && ModelType.fromName((String)this.llmModel).isEmpty() ? ModelType.GPT_4O : (ModelType)ModelType.fromName((String)this.llmModel).orElseThrow();
            this.setupLLMModelConfiguration(modelType, registry);
        }
        catch (NoSuchElementException noSuchElementException) {
            try {
                String adjustedModel = this.sanitizeModel();
                ModelType modelType = (ModelType)ModelType.fromName((String)adjustedModel).orElseThrow();
                this.setupLLMModelConfiguration(modelType, registry);
            }
            catch (NoSuchElementException noSuchElementException2) {
                String error = String.format(UNSUPPORTED_LLM_MODEL_TYPE_WARNING, this.llmModel);
                this.logger.warn(error);
                this.llmModelMaxTokens = 64000;
                this.gptEncoding = registry.getEncoding(EncodingType.R50K_BASE);
            }
        }
        this.objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
        this.objectMapper.setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE);
        this.objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        this.requestExecutor = this::defaultChatCompletionRequestExecute;
    }

    @Override
    public String chat(List<String> systemMessages, List<String> userMessages, List<String> userMessageLines, Double temperature) {
        this.logger.debug("Calling GPT...");
        ArrayList<UserMessage> messages = new ArrayList<UserMessage>();
        AtomicInteger numSystemTokens = new AtomicInteger();
        systemMessages.forEach(sm -> {
            this.logger.debug("SYSTEM: {}", sm);
            messages.add((UserMessage)new SystemMessage(sm));
            numSystemTokens.addAndGet(this.numTokens((String)sm));
        });
        this.logger.debug("--- System message tokens: {}", (Object)numSystemTokens);
        AtomicInteger numUserTokens = new AtomicInteger();
        userMessages.forEach(sm -> {
            this.logger.debug("USER: {}", sm);
            messages.add(new UserMessage(sm));
            numUserTokens.addAndGet(this.numTokens((String)sm));
        });
        this.logger.debug("--- User message tokens: {}", (Object)numSystemTokens);
        if (!userMessageLines.isEmpty()) {
            int remainingUserLimit = this.llmModelMaxTokens - numSystemTokens.get() - numUserTokens.get() - 1000;
            String userMessage = this.getUserMessageWithLimit(remainingUserLimit, userMessageLines);
            messages.add(new UserMessage(userMessage));
        }
        return this.chat(messages, this.llmModel, temperature);
    }

    @Override
    public String chat(List<? extends ChatMessage> messages, String model, Double temperature) {
        GraphDBTokenQuotaManager.getInstance().validateQuotaOrThrow();
        ChatRequest chatCompletionRequest = ChatRequest.builder().modelName(model).temperature(temperature).messages(messages).build();
        ChatResponse chatCompletionResult = this.executeChatCompletionRequest(chatCompletionRequest);
        String result = chatCompletionResult.aiMessage().text();
        this.logger.debug("GPT response: {}", (Object)result);
        return result;
    }

    public void setMockRequestExecutor(RequestExecutor requestExecutor) {
        this.requestExecutor = requestExecutor != null ? requestExecutor : this::defaultChatCompletionRequestExecute;
    }

    private int numTokens(String message) {
        return this.gptEncoding.countTokens(message);
    }

    private String getUserMessageWithLimit(int remainingTokenLimit, List<String> userMessageLines) {
        int userMessageLimit;
        int originalRemainingUserLimit = remainingTokenLimit;
        for (userMessageLimit = 0; userMessageLimit < userMessageLines.size() && remainingTokenLimit >= 0; remainingTokenLimit -= this.numTokens(userMessageLines.get(userMessageLimit) + "\n"), ++userMessageLimit) {
        }
        while (userMessageLimit > 0) {
            List<String> reducedUserMessageLines = userMessageLines.subList(0, userMessageLimit);
            String userMessage = String.join((CharSequence)"\n", reducedUserMessageLines);
            int numUserTokens = this.numTokens(userMessage);
            if (numUserTokens > originalRemainingUserLimit) {
                --userMessageLimit;
                continue;
            }
            if (userMessageLimit < userMessageLines.size()) {
                this.logger.warn("Using {} of {} user message lines to fit token limit", (Object)userMessageLimit, (Object)userMessageLines.size());
            }
            this.logger.debug("USER: {}", (Object)userMessage);
            this.logger.debug("--- User message tokens: {}", (Object)numUserTokens);
            return userMessage;
        }
        this.logger.warn("User message cannot fit into token limit");
        return "";
    }

    private ChatResponse executeChatCompletionRequest(ChatRequest chatCompletionRequest) {
        try {
            ChatResponse chatCompletionResult = this.requestExecutor.execute(chatCompletionRequest);
            TokenUsage usage = chatCompletionResult.tokenUsage();
            if (usage != null) {
                QuotaUtils.consumeTokens(usage.inputTokenCount().intValue(), usage.outputTokenCount().intValue());
            }
            return chatCompletionResult;
        }
        catch (IOException e) {
            throw new QueryEvaluationException((Throwable)e);
        }
        catch (AuthenticationException e) {
            throw new QueryEvaluationException(this.extractErrorMessage(e.getMessage()));
        }
        catch (RuntimeException e) {
            Throwable cause = e.getCause();
            if (cause instanceof SocketTimeoutException) {
                throw new QueryEvaluationException("GPT timed out");
            }
            throw e;
        }
    }

    private ChatResponse defaultChatCompletionRequestExecute(ChatRequest chatCompletionRequest) {
        ChatModel chatModel = GptDefaultClient.getChatModel(chatCompletionRequest);
        return chatModel.chat(chatCompletionRequest.messages());
    }

    private String extractErrorMessage(String json) {
        try {
            JsonNode root = this.objectMapper.readTree(json);
            return root.path("error").path("message").asText();
        }
        catch (Exception e) {
            return "Unknown error: " + e.getMessage();
        }
    }

    private String getLLMFunctionsModel() {
        String model = Config.getPropertyOrLegacyProperty((String)GPT_MODEL_PROPERTY, (String)LEGACY_GPT_MODEL_PROPERTY);
        if (model == null) {
            return SystemConfig.getDefaultLlmModel();
        }
        return model;
    }

    private String sanitizeModel() {
        if (this.llmModel != null) {
            return this.llmModel.replaceFirst("(-([0-9]{4}|[0-9]{4}-[0-9]{2}-[0-9]{2}|preview))+$", "");
        }
        return null;
    }

    private void setupLLMModelConfiguration(ModelType modelType, EncodingRegistry registry) {
        this.llmModelMaxTokens = modelType.getMaxContextLength();
        this.gptEncoding = registry.getEncodingForModel(modelType);
    }

    private static ChatModel getChatModel(ChatRequest chatCompletionRequest) {
        String modelName = chatCompletionRequest.modelName();
        double temperature = chatCompletionRequest.temperature() != null ? chatCompletionRequest.temperature() : 0.7;
        double topP = chatCompletionRequest.topP() != null ? chatCompletionRequest.topP() : 1.0;
        String api = LLMConfig.getLLMApi();
        if (StringUtils.isEmpty((CharSequence)api) || api.toLowerCase(Locale.ROOT).startsWith("openai-assistants")) {
            return LLMConfig.isAzure() ? ChatModelFactory.createAzureOpenAIChatModel(modelName, temperature, topP) : ChatModelFactory.createOpenAIChatModel(modelName, temperature, topP);
        }
        return ChatModelFactory.createChatModel(modelName, temperature, topP);
    }

    public static interface RequestExecutor {
        public ChatResponse execute(ChatRequest var1) throws IOException;
    }
}

