/*
 * Decompiled with CFR 0.152.
 */
package com.ontotext.forest.gpt.chat;

import com.azure.ai.openai.assistants.models.AssistantStreamEvent;
import com.azure.ai.openai.assistants.models.CreateRunOptions;
import com.azure.ai.openai.assistants.models.MessageRole;
import com.azure.ai.openai.assistants.models.RequiredAction;
import com.azure.ai.openai.assistants.models.RequiredFunctionToolCall;
import com.azure.ai.openai.assistants.models.RequiredFunctionToolCallDetails;
import com.azure.ai.openai.assistants.models.RequiredToolCall;
import com.azure.ai.openai.assistants.models.RunCompletionUsage;
import com.azure.ai.openai.assistants.models.RunError;
import com.azure.ai.openai.assistants.models.RunStatus;
import com.azure.ai.openai.assistants.models.StreamMessageCreation;
import com.azure.ai.openai.assistants.models.StreamRequiredAction;
import com.azure.ai.openai.assistants.models.StreamThreadRunCreation;
import com.azure.ai.openai.assistants.models.StreamUpdate;
import com.azure.ai.openai.assistants.models.SubmitToolOutputsAction;
import com.azure.ai.openai.assistants.models.ThreadMessage;
import com.azure.ai.openai.assistants.models.ThreadMessageOptions;
import com.azure.ai.openai.assistants.models.ThreadRun;
import com.azure.ai.openai.assistants.models.ToolOutput;
import com.azure.core.exception.AzureException;
import com.azure.core.exception.HttpResponseException;
import com.azure.core.util.IterableStream;
import com.ontotext.forest.gpt.chat.GptAssistantChat;
import com.ontotext.forest.gpt.conversations.CancelResponse;
import com.ontotext.forest.gpt.conversations.Usage;
import com.ontotext.forest.gpt.ttyg.AgentConfig;
import com.ontotext.forest.gpt.ttyg.AzureUtil;
import com.ontotext.forest.gpt.ttyg.exceptions.GenericServerException;
import com.ontotext.forest.gpt.ttyg.exceptions.ToolConfigException;
import dev.failsafe.Failsafe;
import dev.failsafe.Policy;
import dev.failsafe.RetryPolicy;
import dev.failsafe.RetryPolicyBuilder;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Drives one assistant run on an existing thread: posts the user message, consumes the
 * streamed run updates, executes the tools the run asks for, and supports cancelling the run.
 *
 * <p>A run is consumed as a sequence of streams. The first stream is opened by the constructor;
 * each time the run requires tool outputs, the outputs are submitted and the stream returned by
 * that submission becomes the next one to consume. {@link #getNextResult(Integer)} keeps
 * consuming streams until at least one message is available or the run has no further stream.
 *
 * <p>{@code runId} is {@code volatile} and {@link #cancel()} serialises on the instance monitor;
 * no other thread-safety measures are taken in this class.
 */
class ChatRunner {
    private static final Logger LOGGER = LoggerFactory.getLogger(ChatRunner.class);

    static final String TIMEOUT_CANCEL_ERROR_MESSAGE = "Timeout waiting for chat to start.";
    public static final String INTERRUPTED_ERROR_MESSAGE = "Interrupted waiting for chat to start in conversation ";
    public static final String EXECUTION_ERROR_MESSAGE = "Execution exception during chat in conversation ";
    private static final String DEFAULT_RUN_ERROR_MESSAGE = "Sorry, something went wrong.";
    private static final String DEFAULT_RUN_ERROR_CODE = "server_error";

    /** Fragment of the service error message returned when tool outputs are sent to a cancelled run. */
    private static final String CANCELLED_RUN_REJECTS_OUTPUTS = "Runs in status \"cancelled\" do not accept tool outputs";

    /** Combined length (in chars) above which tool outputs get shortened before submission. */
    private static final long MAX_TOTAL_TOOL_OUTPUT_LENGTH = 256000L;

    /** Individual tool outputs at or below this length (in chars) are never shortened. */
    private static final long MIN_SHORTENABLE_OUTPUT_LENGTH = 10000L;

    private static final Set<RunStatus> CANCELLABLE_STATUSES = Set.of(RunStatus.IN_PROGRESS, RunStatus.QUEUED, RunStatus.REQUIRES_ACTION);
    private static final Set<RunStatus> TERMINAL_STATUSES_NO_CANCEL = Set.of(RunStatus.COMPLETED, RunStatus.FAILED, RunStatus.EXPIRED);

    private final GptAssistantChat chatClient;
    private final AgentConfig agentConfig;
    private final String threadId;
    private final Consumer<String> onFinalStream;

    /** Stream to consume next; {@code null} once every stream of the run has been consumed. */
    private IterableStream<StreamUpdate> runStream;

    /** Id of the run, learnt from the {@code THREAD_RUN_CREATED} event; {@code null} until then. */
    private volatile String runId;

    /**
     * Adds {@code userMessage} to the thread and opens the first run stream for the agent.
     *
     * @param chatClient    client used for all service calls and tool invocations
     * @param threadId      id of the thread the run belongs to
     * @param userMessage   text posted to the thread with the USER role
     * @param agentConfig   agent configuration; its id is used to create the run
     * @param onFinalStream callback that receives the run id once no further stream will follow
     * @throws NullPointerException if any argument is {@code null}
     */
    ChatRunner(GptAssistantChat chatClient, String threadId, String userMessage, AgentConfig agentConfig, Consumer<String> onFinalStream) {
        this.chatClient = Objects.requireNonNull(chatClient, "chatClient");
        this.threadId = Objects.requireNonNull(threadId, "threadId");
        Objects.requireNonNull(userMessage, "userMessage");
        this.agentConfig = Objects.requireNonNull(agentConfig, "agentConfig");
        this.onFinalStream = Objects.requireNonNull(onFinalStream, "onFinalStream");
        chatClient.assistantsOpenAIClient.createMessage(threadId, new ThreadMessageOptions(MessageRole.USER, userMessage));
        this.runStream = chatClient.assistantsOpenAIClient.createRunStream(threadId, new CreateRunOptions(agentConfig.getId()));
    }

    /**
     * Consumes run streams until one yields messages or the run has no continuation.
     *
     * <p>When the returned result has no continuation the completion callback is invoked once.
     * If processing fails, the callback is invoked, a best-effort cancellation of the run is
     * attempted, and the original failure is rethrown; failures raised by that cleanup are
     * attached to it as suppressed exceptions instead of replacing it.
     *
     * @param tzOffset value forwarded unchanged to every tool invocation
     * @return the messages and usage collected, plus the run id to continue with (or {@code null})
     * @throws IllegalStateException if every stream of the run has already been consumed
     */
    ChatRunnerResult getNextResult(Integer tzOffset) {
        ChatRunnerResult result;
        try {
            do {
                result = this.processRunStream(tzOffset);
            } while (result.getMessages().isEmpty() && result.getContinueRunId() != null);
        } catch (RuntimeException | Error e) {
            try {
                this.onFinalStream();
            } catch (RuntimeException callbackFailure) {
                e.addSuppressed(callbackFailure);
            }
            try {
                this.cancelRunIfNecessary();
            } catch (RuntimeException cancelFailure) {
                e.addSuppressed(cancelFailure);
            }
            throw e;
        }
        // Kept outside the try block so a failing callback is not invoked a second time.
        if (result.getContinueRunId() == null) {
            this.onFinalStream();
        }
        return result;
    }

    /**
     * Consumes the pending stream completely and collects its completed messages and usage.
     *
     * @param tzOffset value forwarded unchanged to every tool invocation
     * @return the collected result; its continuation id is non-null only when a follow-up
     *         stream was installed while handling a required action
     * @throws IllegalStateException if there is no pending stream
     * @throws GenericServerException if the run failed or expired
     */
    private ChatRunnerResult processRunStream(Integer tzOffset) {
        IterableStream<StreamUpdate> currentStream = this.runStream;
        if (currentStream == null) {
            throw new IllegalStateException("All streams have been consumed");
        }
        // Detach the stream before iterating: an empty stream must not be re-read forever, and
        // events arriving after a required action must not discard the continuation stream.
        this.runStream = null;

        if (this.runId == null) {
            LOGGER.debug("Run stream started");
        } else {
            LOGGER.debug("Run stream continued: {}", this.runId);
        }

        List<ThreadMessage> messages = new ArrayList<>();
        Usage usage = new Usage();
        for (StreamUpdate update : currentStream) {
            AssistantStreamEvent kind = update.getKind();
            boolean isDelta = kind.equals(AssistantStreamEvent.THREAD_MESSAGE_DELTA)
                    || kind.equals(AssistantStreamEvent.THREAD_RUN_STEP_DELTA);
            if (!isDelta) {
                // Delta events are skipped here to keep the debug log readable.
                LOGGER.debug("Run event: {}", kind);
            }

            if (kind.equals(AssistantStreamEvent.THREAD_RUN_CREATED)) {
                this.runId = ((StreamThreadRunCreation) update).getMessage().getId();
            } else if (kind.equals(AssistantStreamEvent.THREAD_MESSAGE_COMPLETED)) {
                messages.add(((StreamMessageCreation) update).getMessage());
            } else if (kind.equals(AssistantStreamEvent.THREAD_RUN_REQUIRES_ACTION)) {
                this.runStream = this.handleRequiredAction((StreamRequiredAction) update, tzOffset);
            } else if (kind.equals(AssistantStreamEvent.THREAD_RUN_COMPLETED)) {
                LOGGER.debug("Run stream completed: {}", this.runId);
                this.processUsageTokens(update, usage);
            } else if (kind.equals(AssistantStreamEvent.THREAD_RUN_CANCELLED)) {
                this.processUsageTokens(update, usage);
                // A cancelled run produces no messages and has nothing left to continue.
                messages.clear();
                this.runStream = null;
                break;
            } else if (kind.equals(AssistantStreamEvent.THREAD_RUN_FAILED)
                    || kind.equals(AssistantStreamEvent.THREAD_RUN_EXPIRED)) {
                this.processUsageTokens(update, usage);
                this.handleFailedRun(update);
            }
        }

        if (messages.isEmpty()) {
            LOGGER.debug("Chat run stream did not create any messages");
        }
        String continueRunId = this.runStream != null ? this.runId : null;
        return new ChatRunnerResult(messages, continueRunId, usage);
    }

    /**
     * Executes the tools requested by the run and submits their outputs.
     *
     * @param actionUpdate the update carrying the run that requires action
     * @param tzOffset     value forwarded unchanged to every tool invocation
     * @return the stream continuing the run, or {@code null} when the service rejected the
     *         outputs because the run had been cancelled in the meantime
     * @throws GenericServerException if the required action is missing or of an unexpected type
     */
    private IterableStream<StreamUpdate> handleRequiredAction(StreamRequiredAction actionUpdate, Integer tzOffset) {
        ThreadRun run = actionUpdate.getMessage();
        RequiredAction action = run.getRequiredAction();
        if (!(action instanceof SubmitToolOutputsAction)) {
            String actionType = action == null ? "null" : action.getClass().getSimpleName();
            throw new GenericServerException("Unexpected required action class: " + actionType);
        }

        List<ToolOutput> outputs = this.invokeAllTools((SubmitToolOutputsAction) action, tzOffset);
        try {
            return this.chatClient.assistantsOpenAIClient.submitToolOutputsToRunStream(this.threadId, run.getId(), outputs);
        } catch (HttpResponseException e) {
            String errorMessage = AzureUtil.getAzureErrorMessage((AzureException) e);
            if (errorMessage != null && errorMessage.contains(CANCELLED_RUN_REJECTS_OUTPUTS)) {
                LOGGER.warn("Run {} was cancelled while processing. Skipping tool output submission.", run.getId());
                return null;
            }
            throw e;
        }
    }

    /**
     * Invokes every requested tool call in order and shortens the outputs if they are too large.
     *
     * @param requiredAction action listing the tool calls to execute
     * @param tzOffset       value forwarded unchanged to every tool invocation
     * @return one output per tool call, in request order
     */
    private List<ToolOutput> invokeAllTools(SubmitToolOutputsAction requiredAction, Integer tzOffset) {
        List<ToolOutput> toolOutputs = new ArrayList<>();
        for (RequiredToolCall toolCall : requiredAction.getSubmitToolOutputs().getToolCalls()) {
            toolOutputs.add(this.invokeTool(toolCall, tzOffset));
        }
        this.shortenToolOutputsIfNeeded(toolOutputs);
        return toolOutputs;
    }

    /**
     * Executes a single function tool call through the chat client.
     *
     * @param toolCall the call requested by the run
     * @param tzOffset value forwarded unchanged to the tool invocation
     * @return the output tagged with the id of the tool call
     * @throws ToolConfigException if the call is not a function tool call
     */
    private ToolOutput invokeTool(RequiredToolCall toolCall, Integer tzOffset) {
        if (!(toolCall instanceof RequiredFunctionToolCall)) {
            throw new ToolConfigException("Unsupported tool type for call " + toolCall.getId() + ": " + toolCall.getType());
        }
        RequiredFunctionToolCallDetails function = ((RequiredFunctionToolCall) toolCall).getFunction();
        String output = this.chatClient.callTool(function.getName(), function.getArguments(), this.agentConfig, tzOffset, this.threadId);
        return new ToolOutput().setToolCallId(toolCall.getId()).setOutput(output);
    }

    /** Passes the run id to the completion callback; does nothing while no run id is known. */
    private void onFinalStream() {
        String currentRunId = this.runId;
        if (currentRunId != null) {
            this.onFinalStream.accept(currentRunId);
        }
    }

    /**
     * Sums the lengths of all tool outputs.
     *
     * @param toolOutputs outputs to measure
     * @return the combined length in chars; a {@code null} output counts as zero
     */
    long getToolOutputsSize(List<ToolOutput> toolOutputs) {
        long totalSize = 0L;
        for (ToolOutput toolOutput : toolOutputs) {
            String output = toolOutput.getOutput();
            if (output != null) {
                totalSize += output.length();
            }
        }
        return totalSize;
    }

    /**
     * Removes the last line of a tool output.
     *
     * <p>The remaining lines are re-joined with {@code "\n"}, so other line terminators and a
     * trailing terminator are not preserved.
     *
     * @param toolOutput text to shorten
     * @return the text without its last line; the empty string for single-line or empty input
     */
    String shortenEntryOutput(String toolOutput) {
        List<String> lines = toolOutput.lines().collect(Collectors.toList());
        if (lines.isEmpty()) {
            // Nothing to drop; subList(0, -1) would throw.
            return "";
        }
        return String.join("\n", lines.subList(0, lines.size() - 1));
    }

    /**
     * Shortens tool outputs in place until their combined length is within the limit.
     *
     * <p>Each pass drops the last line of every output longer than the per-output threshold;
     * passes repeat while the combined length still exceeds the limit.
     *
     * @param toolOutputs outputs to shorten; modified in place
     * @throws GenericServerException if a pass could not reduce the combined length
     */
    void shortenToolOutputsIfNeeded(List<ToolOutput> toolOutputs) {
        long totalSize = this.getToolOutputsSize(toolOutputs);
        while (totalSize > MAX_TOTAL_TOOL_OUTPUT_LENGTH) {
            long sizeBeforePass = totalSize;
            for (ToolOutput toolOutput : toolOutputs) {
                String currentOutput = toolOutput.getOutput();
                if (currentOutput == null || currentOutput.length() <= MIN_SHORTENABLE_OUTPUT_LENGTH) {
                    continue;
                }
                String shortenedOutput = this.shortenEntryOutput(currentOutput);
                toolOutput.setOutput(shortenedOutput);
                totalSize += (long) shortenedOutput.length() - currentOutput.length();
            }
            if (totalSize == sizeBeforePass) {
                // No output was long enough to shorten, so further passes cannot help.
                throw new GenericServerException("Unable to shorten tool outputs to fit into limit");
            }
        }
    }

    /**
     * Copies the token counts reported by a run event into {@code usage}, if the event has any.
     *
     * @param update run event; cast to {@link StreamThreadRunCreation} to reach the run
     * @param usage  accumulator to update
     */
    private void processUsageTokens(StreamUpdate update, Usage usage) {
        RunCompletionUsage reported = ((StreamThreadRunCreation) update).getMessage().getUsage();
        if (reported == null) {
            return;
        }
        usage.setCompletionTokens(reported.getCompletionTokens());
        usage.setPromptTokens(reported.getPromptTokens());
        usage.setTotalTokens(reported.getTotalTokens());
    }

    /**
     * Turns a failed or expired run event into an exception; never returns normally.
     *
     * @param update run event; cast to {@link StreamThreadRunCreation} to reach the run
     * @throws GenericServerException always, carrying the run's last error or default values
     */
    private void handleFailedRun(StreamUpdate update) {
        RunError runError = ((StreamThreadRunCreation) update).getMessage().getLastError();
        String message = runError != null ? runError.getMessage() : DEFAULT_RUN_ERROR_MESSAGE;
        String code = runError != null ? runError.getCode() : DEFAULT_RUN_ERROR_CODE;
        throw new GenericServerException(String.format("Chat run failed: %s (%s)", message, code));
    }

    /**
     * Best-effort cancellation after a processing failure: cancels the run if it is still in a
     * cancellable status. Exceptions raised while doing so are logged, not propagated.
     */
    private void cancelRunIfNecessary() {
        String currentRunId = this.runId;
        if (currentRunId == null) {
            return;
        }
        try {
            ThreadRun currentRun = this.chatClient.assistantsOpenAIClient.getRun(this.threadId, currentRunId);
            if (currentRun != null && this.isCancellable(currentRun.getStatus())) {
                LOGGER.warn("Cancelling run {} due to processing failure", currentRunId);
                this.chatClient.assistantsOpenAIClient.cancelRun(this.threadId, currentRunId);
            }
        } catch (Exception cancelException) {
            // The trailing throwable makes the logger record the stack trace as well.
            LOGGER.error("Failed to cancel run {}: {}", currentRunId, cancelException.getMessage(), cancelException);
        }
    }

    /** Returns whether {@code status} is one of the statuses in which a run may be cancelled. */
    private boolean isCancellable(RunStatus status) {
        // Set.of(...) rejects null look-ups, hence the explicit null check.
        return status != null && CANCELLABLE_STATUSES.contains(status);
    }

    /** Returns whether {@code status} is completed, failed or expired. */
    private boolean isTerminal(RunStatus status) {
        return status != null && TERMINAL_STATUSES_NO_CANCEL.contains(status);
    }

    /**
     * Cancels the current run on behalf of the user and reports the resulting status.
     *
     * @return a response holding the final status, with an explanatory text for cancelled,
     *         expired and failed runs
     * @throws NullPointerException if no final status could be determined for the run
     */
    CancelResponse cancel() {
        LOGGER.info("Attempting to cancel current chat run.");
        String currentRunId = this.runId;
        RunStatus runStatus = this.cancelRun(currentRunId);
        if (runStatus == RunStatus.CANCELLED) {
            return new CancelResponse(runStatus, "Request cancelled by the user.");
        }
        if (runStatus == RunStatus.EXPIRED) {
            return new CancelResponse(runStatus, "Request expired before completion.");
        }
        if (runStatus == RunStatus.FAILED) {
            return new CancelResponse(runStatus, "Request failed due to an error.");
        }
        return new CancelResponse(Objects.requireNonNull(runStatus, "Could not determine the final status of run " + currentRunId));
    }

    /**
     * Requests cancellation of a run and waits until it reaches a final status.
     *
     * @param currentRunId id of the run to cancel
     * @return the final status, or {@code null} if the run was not found or did not reach a
     *         terminal or cancelled status
     */
    private synchronized RunStatus cancelRun(String currentRunId) {
        ThreadRun threadRun = this.chatClient.assistantsOpenAIClient.getRun(this.threadId, currentRunId);
        if (threadRun == null) {
            LOGGER.warn("Run {} not found, possibly already removed.", currentRunId);
            return null;
        }

        RunStatus initialStatus = threadRun.getStatus();
        if (this.processRunStatus(currentRunId, initialStatus)) {
            return initialStatus;
        }

        if (this.isCancellable(initialStatus)) {
            try {
                this.chatClient.assistantsOpenAIClient.cancelRun(this.threadId, currentRunId);
            } catch (RuntimeException e) {
                // The run may have finished between the status check and the cancel request;
                // re-read it and only surface the failure if the run is still not finished.
                ThreadRun refreshedRun;
                try {
                    refreshedRun = this.chatClient.assistantsOpenAIClient.getRun(this.threadId, currentRunId);
                } catch (RuntimeException lookupFailure) {
                    e.addSuppressed(lookupFailure);
                    throw e;
                }
                RunStatus refreshedStatus = refreshedRun != null ? refreshedRun.getStatus() : null;
                if (this.processRunStatus(currentRunId, refreshedStatus)) {
                    return refreshedStatus;
                }
                throw e;
            }
        }

        RunStatus polledStatus = this.pollForCancellation(currentRunId);
        return this.processRunStatus(currentRunId, polledStatus) ? polledStatus : null;
    }

    /**
     * Logs the outcome if the run is finished.
     *
     * @param currentRunId id of the run, used in the log message
     * @param finalStatus  status to inspect; may be {@code null}
     * @return {@code true} if the status is terminal or cancelled, {@code false} otherwise
     */
    private boolean processRunStatus(String currentRunId, RunStatus finalStatus) {
        if (this.isTerminal(finalStatus)) {
            LOGGER.info("Run {} can not be cancelled as it is in a terminal state: {}", currentRunId, finalStatus.getValue());
            return true;
        }
        if (finalStatus == RunStatus.CANCELLED) {
            LOGGER.info("Chat run {} was successfully cancelled.", currentRunId);
            return true;
        }
        return false;
    }

    /**
     * Polls the run status until it is terminal or cancelled.
     *
     * @param polledRunId id of the run to poll
     * @return the last status read; {@link RunStatus#FAILED} if the run can no longer be found
     */
    private RunStatus pollForCancellation(String polledRunId) {
        return this.pollWithRetryPolicy(() -> {
            ThreadRun run = this.chatClient.assistantsOpenAIClient.getRun(this.threadId, polledRunId);
            return run == null ? RunStatus.FAILED : run.getStatus();
        }, polledRunId);
    }

    /**
     * Evaluates {@code condition} under the retry policy of {@link #retryPolicyRunStatus()}.
     *
     * <p>If polling fails with an exception, the run is read once more and its current status
     * is returned instead.
     *
     * @param condition   supplier producing the current run status
     * @param polledRunId id of the run, used for the fallback read and in log messages
     * @return the status that ended the polling, or the fallback status after a failure
     */
    private RunStatus pollWithRetryPolicy(Supplier<RunStatus> condition, String polledRunId) {
        try {
            return Failsafe.with(this.retryPolicyRunStatus())
                    .onFailure(event -> {
                        if (event.getException() instanceof InterruptedException) {
                            // Restore the interrupt flag for the caller.
                            Thread.currentThread().interrupt();
                        }
                    })
                    .get(condition::get);
        } catch (Exception e) {
            ThreadRun run = this.chatClient.assistantsOpenAIClient.getRun(this.threadId, polledRunId);
            if (run == null) {
                LOGGER.warn("waiting for run to be cancelled failed due to: {}", e.getMessage(), e);
                return RunStatus.FAILED;
            }
            LOGGER.warn("waiting for run {} to be cancelled failed due to: {}", polledRunId, e.getMessage(), e);
            return run.getStatus();
        }
    }

    /**
     * Builds the polling policy: retry while the status is neither cancelled nor terminal,
     * backing off from 1 to 30 seconds with factor 2 and no limit on the number of retries.
     */
    private RetryPolicy<RunStatus> retryPolicyRunStatus() {
        return RetryPolicy.<RunStatus>builder()
                .handleResultIf(status -> status != RunStatus.CANCELLED && !this.isTerminal(status))
                .withBackoff(1L, 30L, ChronoUnit.SECONDS, 2.0)
                .withMaxRetries(-1)
                .build();
    }

    /** Outcome of consuming run streams: the messages, the continuation run id and token usage. */
    static class ChatRunnerResult {
        private final List<ThreadMessage> messages;
        private final String continueRunId;
        private final Usage usage;

        /**
         * @param messages      messages completed while consuming the stream(s)
         * @param continueRunId id of the run to continue with, or {@code null} if none
         * @param usage         token usage collected from the run events
         */
        ChatRunnerResult(List<ThreadMessage> messages, String continueRunId, Usage usage) {
            this.messages = messages;
            this.continueRunId = continueRunId;
            this.usage = usage;
        }

        /** Returns the completed messages; the list is the one passed to the constructor. */
        List<ThreadMessage> getMessages() {
            return this.messages;
        }

        /** Returns the id of the run that still has a stream to consume, or {@code null}. */
        String getContinueRunId() {
            return this.continueRunId;
        }

        /** Returns the collected token usage. */
        Usage getUsage() {
            return this.usage;
        }
    }
}

