/*
 * Decompiled with CFR 0.152.
 */
package com.ontotext.forest.gpt.chat.completions;

import com.ontotext.forest.gpt.chat.GptChatMessage;
import com.ontotext.forest.gpt.chat.completions.ConversationService;
import com.ontotext.forest.gpt.chat.completions.persistence.ConversationPersisted;
import com.ontotext.forest.gpt.chat.completions.persistence.ConversationResponsePersisted;
import com.ontotext.forest.gpt.ttyg.AgentConfig;
import com.ontotext.graphdb.configs.LLMConfig;
import com.ontotext.graphdb.gpt.ChatModelFactory;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.memory.chat.TokenWindowChatMemory;
import dev.langchain4j.model.TokenCountEstimator;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import javax.annotation.concurrent.ThreadSafe;

@ThreadSafe
public class MemoryService
implements ChatMemoryProvider {
    private static final String LLM_API = LLMConfig.getLLMApi();
    private final ConversationService conversationService;
    private final Map<String, ToolExecutionResultAwareChatMemory> conversationMemory;
    private final Map<String, String> conversationToLastAgentMap;
    private final Map<String, String> lastAgentToConversationMap;
    private final Map<String, TokenCountEstimator> tokenCountEstimatorInstances;

    public MemoryService(ConversationService conversationService) {
        this.conversationService = conversationService;
        this.conversationMemory = new HashMap<String, ToolExecutionResultAwareChatMemory>();
        this.conversationToLastAgentMap = new HashMap<String, String>();
        this.lastAgentToConversationMap = new HashMap<String, String>();
        this.tokenCountEstimatorInstances = new HashMap<String, TokenCountEstimator>();
    }

    public synchronized ChatMemory get(Object threadId) {
        return this.conversationMemory.get(threadId);
    }

    public synchronized ChatMemory getOrCreateMemory(String threadId, AgentConfig agentConfig) {
        if (agentConfig.getId().equals(this.conversationToLastAgentMap.get(threadId))) {
            ToolExecutionResultAwareChatMemory currentMemory = this.conversationMemory.get(threadId);
            if (currentMemory.getContextSize() == agentConfig.getContextSize()) {
                return currentMemory;
            }
            return this.createMemoryFromConversation(threadId, agentConfig);
        }
        return this.createMemoryFromConversation(threadId, agentConfig);
    }

    public synchronized String getAgentForMemory(String threadId) {
        return this.conversationToLastAgentMap.get(threadId);
    }

    private ChatMemory createMemoryFromConversation(String threadId, AgentConfig agentConfig) {
        ToolExecutionResultAwareChatMemory currentAgentMemory = this.conversationMemory.get(threadId);
        if (currentAgentMemory == null || currentAgentMemory.getContextSize() != agentConfig.getContextSize()) {
            currentAgentMemory = this.createNewMemory(agentConfig);
        }
        currentAgentMemory.clear();
        this.loadMessagesForConversation(threadId).forEach(currentAgentMemory::add);
        this.conversationToLastAgentMap.put(threadId, agentConfig.getId());
        this.lastAgentToConversationMap.put(agentConfig.getId(), threadId);
        this.conversationMemory.put(threadId, currentAgentMemory);
        return currentAgentMemory;
    }

    public synchronized void rebuildMemoryIfNeeded(String threadId, AgentConfig agentConfig) {
        ToolExecutionResultAwareChatMemory memory = this.conversationMemory.get(threadId);
        if (memory.containsToolExecutionResponse()) {
            this.createMemoryFromConversation(threadId, agentConfig);
            memory.setContainsToolExecutionResponse(false);
        }
    }

    private List<ChatMessage> loadMessagesForConversation(String threadId) {
        LinkedList<ChatMessage> chatMessages = new LinkedList<ChatMessage>();
        ConversationPersisted conversationPersisted = this.conversationService.getConversation(threadId);
        for (ConversationResponsePersisted conversationResponsePersisted : conversationPersisted.getConversationResponses()) {
            for (GptChatMessage message : conversationResponsePersisted.getMessages()) {
                if (message.getRole().equalsIgnoreCase("assistant")) {
                    chatMessages.add((ChatMessage)AiMessage.aiMessage((String)message.getMessage()));
                    continue;
                }
                if (!message.getRole().equalsIgnoreCase("user")) continue;
                chatMessages.add((ChatMessage)UserMessage.userMessage((String)message.getMessage()));
            }
        }
        return chatMessages;
    }

    private TokenCountEstimator getTokenCountEstimatorInstance(String modelName) {
        this.tokenCountEstimatorInstances.putIfAbsent(LLM_API, ChatModelFactory.getTokenEstimator((String)modelName));
        return this.tokenCountEstimatorInstances.get(LLM_API);
    }

    private ToolExecutionResultAwareChatMemory createNewMemory(AgentConfig agentConfig) {
        TokenWindowChatMemory memory = TokenWindowChatMemory.withMaxTokens((int)agentConfig.getContextSize(), (TokenCountEstimator)this.getTokenCountEstimatorInstance(agentConfig.getModel()));
        return new ToolExecutionResultAwareChatMemory((ChatMemory)memory, agentConfig.getContextSize());
    }

    public synchronized void removeAgentMemory(String agentId) {
        String threadId = this.lastAgentToConversationMap.remove(agentId);
        this.conversationToLastAgentMap.remove(threadId);
        this.conversationMemory.remove(threadId);
    }

    public synchronized void removeConversationMemory(String threadId) {
        this.conversationToLastAgentMap.remove(threadId);
        this.conversationMemory.remove(threadId);
    }

    private static class ToolExecutionResultAwareChatMemory
    implements ChatMemory {
        private final ChatMemory delegate;
        private final int contextSize;
        private boolean containsToolExecutionResponse;

        public ToolExecutionResultAwareChatMemory(ChatMemory chatMemory, int contextSize) {
            this.delegate = chatMemory;
            this.contextSize = contextSize;
            this.containsToolExecutionResponse = false;
        }

        public Object id() {
            return this.delegate.id();
        }

        public void add(ChatMessage message) {
            if (message instanceof ToolExecutionResultMessage) {
                this.containsToolExecutionResponse = true;
            }
            this.delegate.add(message);
        }

        public List<ChatMessage> messages() {
            return this.delegate.messages();
        }

        public void clear() {
            this.delegate.clear();
        }

        public boolean containsToolExecutionResponse() {
            return this.containsToolExecutionResponse;
        }

        public void setContainsToolExecutionResponse(boolean shouldRebuild) {
            this.containsToolExecutionResponse = shouldRebuild;
        }

        public int getContextSize() {
            return this.contextSize;
        }
    }
}

