/*
 * Decompiled with CFR 0.152.
 */
package com.ontotext.forest.gpt.chat;

import com.azure.ai.openai.assistants.models.AssistantStreamEvent;
import com.azure.ai.openai.assistants.models.CreateRunOptions;
import com.azure.ai.openai.assistants.models.MessageRole;
import com.azure.ai.openai.assistants.models.RequiredAction;
import com.azure.ai.openai.assistants.models.RequiredFunctionToolCall;
import com.azure.ai.openai.assistants.models.RequiredFunctionToolCallDetails;
import com.azure.ai.openai.assistants.models.RequiredToolCall;
import com.azure.ai.openai.assistants.models.RunCompletionUsage;
import com.azure.ai.openai.assistants.models.RunError;
import com.azure.ai.openai.assistants.models.RunStatus;
import com.azure.ai.openai.assistants.models.StreamMessageCreation;
import com.azure.ai.openai.assistants.models.StreamRequiredAction;
import com.azure.ai.openai.assistants.models.StreamThreadRunCreation;
import com.azure.ai.openai.assistants.models.StreamUpdate;
import com.azure.ai.openai.assistants.models.SubmitToolOutputsAction;
import com.azure.ai.openai.assistants.models.ThreadMessage;
import com.azure.ai.openai.assistants.models.ThreadMessageOptions;
import com.azure.ai.openai.assistants.models.ThreadRun;
import com.azure.ai.openai.assistants.models.ToolOutput;
import com.azure.core.util.IterableStream;
import com.ontotext.forest.gpt.chat.GptAssistantChat;
import com.ontotext.forest.gpt.conversations.Usage;
import com.ontotext.forest.gpt.ttyg.AgentConfig;
import com.ontotext.forest.gpt.ttyg.exceptions.GenericServerException;
import com.ontotext.forest.gpt.ttyg.exceptions.ToolConfigException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.eclipse.rdf4j.repository.RepositoryConnection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class ChatRunner {
    private static final Logger LOGGER = LoggerFactory.getLogger(ChatRunner.class);
    private static final String DEFAULT_RUN_ERROR_MESSAGE = "Sorry, something went wrong.";
    private static final String DEFAULT_RUN_ERROR_CODE = "server_error";
    private static final Set<RunStatus> CANCELLABLE_STATUSES = Set.of(RunStatus.IN_PROGRESS, RunStatus.QUEUED, RunStatus.REQUIRES_ACTION);
    /** Combined tool-output size (in chars) above which outputs are shortened before submission. */
    private static final long MAX_TOOL_OUTPUTS_SIZE = 256_000L;
    /** Individual tool outputs of at most this many chars are never shortened. */
    private static final long MIN_SHORTENABLE_OUTPUT_SIZE = 10_000L;

    private final GptAssistantChat chatClient;
    private final AgentConfig agentConfig;
    private final String threadId;
    private final Consumer<String> onFinalStream;
    /** Stream still to be consumed; {@code null} once the run has no further stream. */
    private IterableStream<StreamUpdate> runStream;
    /** Id of the current run, known after the THREAD_RUN_CREATED event has been seen. */
    private String runId;

    /**
     * Posts {@code userMessage} to the thread and starts a streaming run for the configured agent.
     *
     * @param chatClient    client used for all assistant API calls and tool invocations
     * @param threadId      id of the thread the message is added to
     * @param userMessage   text of the user message
     * @param agentConfig   agent whose id is used to create the run
     * @param onFinalStream callback receiving the run id once the run has no further stream
     * @throws NullPointerException if any argument is {@code null}
     */
    ChatRunner(GptAssistantChat chatClient, String threadId, String userMessage, AgentConfig agentConfig, Consumer<String> onFinalStream) {
        this.chatClient = Objects.requireNonNull(chatClient, "chatClient");
        this.threadId = Objects.requireNonNull(threadId, "threadId");
        Objects.requireNonNull(userMessage, "userMessage");
        this.agentConfig = Objects.requireNonNull(agentConfig, "agentConfig");
        this.onFinalStream = Objects.requireNonNull(onFinalStream, "onFinalStream");
        chatClient.assistantsOpenAIClient.createMessage(threadId, new ThreadMessageOptions(MessageRole.USER, userMessage));
        this.runStream = chatClient.assistantsOpenAIClient.createRunStream(threadId, new CreateRunOptions(agentConfig.getId()));
    }

    /**
     * Consumes run streams until one yields messages or the run has no further stream.
     *
     * <p>When the returned result has no continue-run id, the final-stream callback is invoked
     * exactly once. On failure, the callback is invoked and the run is cancelled when its status
     * allows it. Failures of these cleanup steps are attached to the original exception as
     * suppressed exceptions instead of replacing it.
     *
     * @param tzOffset   timezone offset forwarded to tool calls
     * @param connection NOTE(review): currently not used by this class
     * @return the messages, usage and (if the run continues) the run id of this step
     * @throws IllegalStateException if all streams have already been consumed
     */
    ChatRunnerResult getNextResult(Integer tzOffset, RepositoryConnection connection) {
        ChatRunnerResult result;
        try {
            do {
                result = this.processRunStream(tzOffset);
            } while (result.getMessages().isEmpty() && result.getContinueRunId() != null);
        } catch (Throwable e) {
            try {
                this.onFinalStream();
            } catch (Throwable callbackFailure) {
                e.addSuppressed(callbackFailure);
            }
            try {
                this.cancelRunIfNecessary();
            } catch (Throwable cancelFailure) {
                e.addSuppressed(cancelFailure);
            }
            // Precise rethrow: the try block only throws unchecked exceptions.
            throw e;
        }
        // Kept outside the try so a failing callback is not invoked a second time by the cleanup.
        if (result.getContinueRunId() == null) {
            this.onFinalStream();
        }
        return result;
    }

    /**
     * Consumes the pending run stream, collecting completed messages and usage.
     *
     * <p>A REQUIRES_ACTION event invokes the requested tools and installs the resulting
     * tool-output stream as the next stream to consume.
     */
    private ChatRunnerResult processRunStream(Integer tzOffset) {
        IterableStream<StreamUpdate> stream = this.runStream;
        if (stream == null) {
            throw new IllegalStateException("All streams have been consumed");
        }
        // Mark the stream as consumed before iterating it. This way only a REQUIRES_ACTION event can
        // install a follow-up stream, later events cannot discard it, and an empty stream is never
        // re-iterated by getNextResult().
        this.runStream = null;
        if (this.runId == null) {
            LOGGER.debug("Run stream started");
        } else {
            LOGGER.debug("Run stream continued: {}", this.runId);
        }
        List<ThreadMessage> messages = new ArrayList<>();
        Usage usage = new Usage();
        for (StreamUpdate update : stream) {
            AssistantStreamEvent kind = update.getKind();
            // Delta events are not logged individually.
            if (!AssistantStreamEvent.THREAD_MESSAGE_DELTA.equals(kind) && !AssistantStreamEvent.THREAD_RUN_STEP_DELTA.equals(kind)) {
                LOGGER.debug("Run event: {}", kind);
            }
            if (AssistantStreamEvent.THREAD_RUN_CREATED.equals(kind)) {
                this.runId = ((StreamThreadRunCreation) update).getMessage().getId();
            } else if (AssistantStreamEvent.THREAD_MESSAGE_COMPLETED.equals(kind)) {
                messages.add(((StreamMessageCreation) update).getMessage());
            } else if (AssistantStreamEvent.THREAD_RUN_REQUIRES_ACTION.equals(kind)) {
                this.runStream = this.handleRequiredAction((StreamRequiredAction) update, tzOffset);
            } else if (AssistantStreamEvent.THREAD_RUN_COMPLETED.equals(kind)) {
                LOGGER.debug("Run stream completed: {}", this.runId);
                this.processUsageTokens(update, usage);
            } else if (isFailedRunEvent(kind)) {
                this.handleFailedRun(update);
            }
        }
        if (messages.isEmpty()) {
            LOGGER.debug("Chat run stream did not create any messages");
        }
        String continueRunId = this.runStream != null ? this.runId : null;
        return new ChatRunnerResult(messages, continueRunId, usage);
    }

    /** Returns whether {@code kind} signals a failed, cancelled or expired run. */
    private static boolean isFailedRunEvent(AssistantStreamEvent kind) {
        return AssistantStreamEvent.THREAD_RUN_FAILED.equals(kind)
                || AssistantStreamEvent.THREAD_RUN_CANCELLED.equals(kind)
                || AssistantStreamEvent.THREAD_RUN_EXPIRED.equals(kind);
    }

    /**
     * Invokes the tools requested by a REQUIRES_ACTION event and submits their outputs.
     *
     * @return the stream continuing the run after the tool outputs have been submitted
     * @throws GenericServerException if the required action is not a tool-output submission
     */
    private IterableStream<StreamUpdate> handleRequiredAction(StreamRequiredAction actionUpdate, Integer tzOffset) {
        ThreadRun run = actionUpdate.getMessage();
        RequiredAction action = run.getRequiredAction();
        if (action instanceof SubmitToolOutputsAction) {
            List<ToolOutput> outputs = this.invokeAllTools((SubmitToolOutputsAction) action, this.agentConfig, tzOffset);
            return this.chatClient.assistantsOpenAIClient.submitToolOutputsToRunStream(this.threadId, run.getId(), outputs);
        }
        String actionType = action == null ? "null" : action.getClass().getSimpleName();
        throw new GenericServerException("Unexpected required action class: " + actionType);
    }

    /** Invokes every requested tool call in order, then shortens the outputs if they are too large. */
    private List<ToolOutput> invokeAllTools(SubmitToolOutputsAction requiredAction, AgentConfig agentConfig, Integer tzOffset) {
        List<ToolOutput> toolOutputs = new ArrayList<>();
        for (RequiredToolCall toolCall : requiredAction.getSubmitToolOutputs().getToolCalls()) {
            toolOutputs.add(this.invokeTool(toolCall, agentConfig, tzOffset));
        }
        this.shortenToolOutputsIfNeeded(toolOutputs);
        return toolOutputs;
    }

    /**
     * Invokes a single function tool call through the chat client.
     *
     * @throws ToolConfigException if the call is not a function tool call
     */
    private ToolOutput invokeTool(RequiredToolCall toolCall, AgentConfig agentConfig, Integer tzOffset) {
        if (toolCall instanceof RequiredFunctionToolCall) {
            RequiredFunctionToolCallDetails function = ((RequiredFunctionToolCall) toolCall).getFunction();
            String output = this.chatClient.callTool(function.getName(), function.getArguments(), agentConfig, tzOffset);
            return new ToolOutput().setToolCallId(toolCall.getId()).setOutput(output);
        }
        throw new ToolConfigException("Unsupported tool type for call " + toolCall.getId() + ": " + toolCall.getType());
    }

    /** Passes the current run id to the final-stream callback, if a run has been created. */
    private void onFinalStream() {
        if (this.runId != null) {
            this.onFinalStream.accept(this.runId);
        }
    }

    /** Returns the summed length of all outputs; {@code null} outputs count as zero. */
    long getToolOutputsSize(List<ToolOutput> toolOutputs) {
        return toolOutputs.stream().mapToLong(ChatRunner::outputLength).sum();
    }

    /** Length of a tool output's text, treating {@code null} as empty. */
    private static long outputLength(ToolOutput toolOutput) {
        String output = toolOutput.getOutput();
        return output == null ? 0L : output.length();
    }

    /**
     * Drops the last line of {@code toolOutput}.
     *
     * <p>The remaining lines are re-joined with {@code '\n'}, so {@code "\r\n"} and {@code '\r'}
     * terminators are normalized and a trailing terminator is removed. An empty string is
     * returned unchanged.
     */
    String shortenEntryOutput(String toolOutput) {
        List<String> lines = toolOutput.lines().collect(Collectors.toList());
        if (lines.isEmpty()) {
            return toolOutput;
        }
        return String.join("\n", lines.subList(0, lines.size() - 1));
    }

    /**
     * Shortens the outputs in place until their combined size is at most the limit.
     *
     * <p>Each pass drops the last line of every output longer than the minimum shortenable size.
     *
     * @throws GenericServerException if a pass makes no progress while the total is still too large
     */
    void shortenToolOutputsIfNeeded(List<ToolOutput> toolOutputs) {
        long totalSize = this.getToolOutputsSize(toolOutputs);
        while (totalSize > MAX_TOOL_OUTPUTS_SIZE) {
            long sizeBeforePass = totalSize;
            for (ToolOutput toolOutput : toolOutputs) {
                long currentLength = outputLength(toolOutput);
                if (currentLength <= MIN_SHORTENABLE_OUTPUT_SIZE) {
                    continue;
                }
                String shortened = this.shortenEntryOutput(toolOutput.getOutput());
                toolOutput.setOutput(shortened);
                totalSize += shortened.length() - currentLength;
            }
            if (totalSize == sizeBeforePass) {
                throw new GenericServerException("Unable to shorten tool outputs to fit into limit");
            }
        }
    }

    /** Copies the token counts of a completed run into {@code usage}, when they are reported. */
    private void processUsageTokens(StreamUpdate update, Usage usage) {
        RunCompletionUsage usageMessage = ((StreamThreadRunCreation) update).getMessage().getUsage();
        if (usageMessage != null) {
            usage.setCompletionTokens(usageMessage.getCompletionTokens());
            usage.setPromptTokens(usageMessage.getPromptTokens());
            usage.setTotalTokens(usageMessage.getTotalTokens());
        }
    }

    /**
     * Always throws, describing the run's last error (or a default message and code).
     *
     * @throws GenericServerException always
     */
    private void handleFailedRun(StreamUpdate update) {
        ThreadRun run = ((StreamThreadRunCreation) update).getMessage();
        RunError runError = run.getLastError();
        String message = runError != null ? runError.getMessage() : DEFAULT_RUN_ERROR_MESSAGE;
        String code = runError != null ? runError.getCode() : DEFAULT_RUN_ERROR_CODE;
        throw new GenericServerException(String.format("Chat run failed: %s (%s)", message, code));
    }

    /**
     * Best-effort cancellation of the current run when its status is still cancellable.
     * Exceptions from the lookup or cancel calls are logged, not propagated.
     */
    private void cancelRunIfNecessary() {
        if (this.runId == null) {
            return;
        }
        try {
            ThreadRun currentRun = this.chatClient.assistantsOpenAIClient.getRun(this.threadId, this.runId);
            if (currentRun != null && this.isCancellable(currentRun.getStatus())) {
                LOGGER.warn("Cancelling run {} due to processing failure", this.runId);
                this.chatClient.assistantsOpenAIClient.cancelRun(this.threadId, this.runId);
            }
        } catch (Exception cancelException) {
            // Trailing Throwable argument makes SLF4J log the stack trace as well.
            LOGGER.error("Failed to cancel run {}: {}", this.runId, cancelException.getMessage(), cancelException);
        }
    }

    /** Returns whether a run in {@code status} may still be cancelled. */
    private boolean isCancellable(RunStatus status) {
        return status != null && CANCELLABLE_STATUSES.contains(status);
    }

    /** Outcome of one {@link #getNextResult} step. */
    static final class ChatRunnerResult {
        private final List<ThreadMessage> messages;
        private final String continueRunId;
        private final Usage usage;

        /**
         * @param messages      messages completed during this step
         * @param continueRunId run id when the run continues with another stream, else {@code null}
         * @param usage         token usage reported for a completed run (empty otherwise)
         */
        ChatRunnerResult(List<ThreadMessage> messages, String continueRunId, Usage usage) {
            this.messages = messages;
            this.continueRunId = continueRunId;
            this.usage = usage;
        }

        List<ThreadMessage> getMessages() {
            return this.messages;
        }

        String getContinueRunId() {
            return this.continueRunId;
        }

        Usage getUsage() {
            return this.usage;
        }
    }
}

