/*
 * Decompiled with CFR 0.152.
 */
package com.ontotext.forest.mcp.server;

import com.ontotext.forest.mcp.server.GraphDBMcpServerSession;
import com.ontotext.graphdb.Config;
import com.ontotext.trree.statistics.SystemStatisticsCollector;
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.KeepAliveScheduler;
import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.RouterFunctions;
import org.springframework.web.servlet.function.ServerRequest;
import org.springframework.web.servlet.function.ServerResponse;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * MCP server transport provider exposing the HTTP+SSE transport through Spring WebMVC functional
 * routes.
 *
 * <p>A client opens an SSE stream with {@code GET <sseEndpoint>}. The provider creates and
 * registers a new session, then sends an {@value #ENDPOINT_EVENT_TYPE} event whose data is the
 * URL the client must {@code POST} JSON-RPC messages to
 * ({@code <baseUrl><messageEndpoint>?sessionId=<id>}). Server-to-client messages are written to
 * the session's SSE stream as {@value #MESSAGE_EVENT_TYPE} events.
 *
 * <p>Session housekeeping is driven by GraphDB configuration properties:
 * <ul>
 *   <li>{@code graphdb.mcp.server.max.sessions} (default 100): when this many sessions are
 *       registered, sessions are evicted in {@code lastActivity} order before a new one is
 *       accepted. Evicted sessions leave the registry only when their SSE stream completes, so
 *       this is a soft limit.</li>
 *   <li>{@code graphdb.mcp.server.idle.sessions.timeout} (minutes, default 30): a background
 *       task closes sessions that report themselves idle for longer than this. Idle sessions are
 *       also excluded from keep-alive traffic.</li>
 *   <li>{@code graphdb.mcp.server.sessions.keepalive.interval} (milliseconds, default 0):
 *       keep-alive interval used by {@link Builder}; zero, negative or null disables
 *       keep-alive.</li>
 * </ul>
 */
public class GraphDBWebMvcSseServerTransportProvider
implements McpServerTransportProvider {
    private static final Logger logger = LoggerFactory.getLogger(GraphDBWebMvcSseServerTransportProvider.class);
    public static final String MESSAGE_EVENT_TYPE = "message";
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";
    public static final String DEFAULT_SSE_ENDPOINT = "/sse";
    private final McpJsonMapper jsonMapper;
    private final String messageEndpoint;
    /** Fallback base URL, used only when it cannot be derived from the request. */
    private final String baseUrl;
    private final RouterFunction<ServerResponse> routerFunction;
    private McpServerSession.Factory sessionFactory;
    /** Registered sessions keyed by session id. */
    protected final ConcurrentHashMap<String, GraphDBMcpServerSession> sessions = new ConcurrentHashMap<>();
    private final McpTransportContextExtractor<ServerRequest> contextExtractor;
    private volatile boolean isClosing = false;
    /** Null when keep-alive is disabled. */
    private KeepAliveScheduler keepAliveScheduler;
    private ScheduledExecutorService idleCleanupExecutor;
    private final Duration idleSessionTimeout = Duration.ofMinutes(Config.getPropertyAsLong("graphdb.mcp.server.idle.sessions.timeout", 30L));
    private final int maxSessions = Config.getPropertyAsInt("graphdb.mcp.server.max.sessions", 100);
    /** Number of entries added to {@link #sessions}; kept in step by register/deregister. */
    private final AtomicInteger sessionCount = new AtomicInteger();

    private GraphDBWebMvcSseServerTransportProvider(McpJsonMapper jsonMapper, String baseUrl, String messageEndpoint,
            String sseEndpoint, Duration keepAliveInterval,
            McpTransportContextExtractor<ServerRequest> contextExtractor) {
        Assert.notNull(jsonMapper, "McpJsonMapper must not be null");
        Assert.notNull(baseUrl, "Message base URL must not be null");
        Assert.notNull(messageEndpoint, "Message endpoint must not be null");
        Assert.notNull(sseEndpoint, "SSE endpoint must not be null");
        Assert.notNull(contextExtractor, "Context extractor must not be null");
        this.jsonMapper = jsonMapper;
        this.baseUrl = baseUrl;
        this.messageEndpoint = messageEndpoint;
        this.contextExtractor = contextExtractor;
        this.routerFunction = RouterFunctions.route()
                .GET(sseEndpoint, this::handleSseConnection)
                .POST(this.messageEndpoint, this::handleMessage)
                .build();
        if (isKeepAliveEnabled(keepAliveInterval)) {
            // Keep-alive goes only to non-idle sessions, so idle ones can still time out.
            this.keepAliveScheduler = KeepAliveScheduler.builder(() -> this.isClosing
                            ? Flux.empty()
                            : Flux.fromIterable(this.sessions.values())
                                    .filter(s -> !s.isIdle(this.idleSessionTimeout))
                                    .map(GraphDBMcpServerSession::unwrap))
                    .interval(keepAliveInterval)
                    .build();
            this.keepAliveScheduler.start();
        }
        this.startPeriodicIdleCleanup();
    }

    /**
     * Decides whether keep-alive should run for the given interval.
     *
     * <p>Compares by value rather than against the {@code Duration.ZERO} instance. Null and
     * negative intervals count as "disabled" instead of being handed to the scheduler.
     */
    private static boolean isKeepAliveEnabled(Duration interval) {
        return interval != null && !interval.isZero() && !interval.isNegative();
    }

    /** Returns the MCP protocol versions this transport advertises. */
    public List<String> protocolVersions() {
        return List.of("2024-11-05");
    }

    /** Sets the factory used to create a server session for every new SSE connection. */
    public void setSessionFactory(McpServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    /**
     * Sends a notification to every registered session.
     *
     * <p>Per-session failures are logged and swallowed, so one broken session does not fail the
     * broadcast.
     */
    public Mono<Void> notifyClients(String method, Object params) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
        return Flux.fromIterable(this.sessions.values())
                .flatMap(session -> session.sendNotification(method, params)
                        .doOnError(e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()))
                        .onErrorComplete())
                .then();
    }

    /**
     * Marks the provider as closing and gracefully closes every registered session.
     *
     * <p>When the returned {@code Mono} terminates (successfully or with an error), the session
     * registry is cleared and the keep-alive scheduler and idle-cleanup executor are shut down.
     * A failing session close therefore no longer leaks those threads.
     */
    public Mono<Void> closeGracefully() {
        return Flux.fromIterable(this.sessions.values())
                .doFirst(() -> {
                    this.isClosing = true;
                    logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
                })
                .flatMap(GraphDBMcpServerSession::closeGracefully)
                .then()
                .doOnSuccess(v -> logger.debug("Graceful shutdown completed"))
                .doOnTerminate(this::releaseResources);
    }

    /** Clears the session registry and stops the background schedulers. */
    private void releaseResources() {
        this.sessions.clear();
        // deregisterSession only decrements for sessions still present in the map, so the
        // counter has to be reset explicitly once the map is cleared.
        this.sessionCount.set(0);
        if (this.keepAliveScheduler != null) {
            this.keepAliveScheduler.shutdown();
        }
        try {
            this.idleCleanupExecutor.shutdown();
            if (!this.idleCleanupExecutor.awaitTermination(5L, TimeUnit.SECONDS)) {
                this.idleCleanupExecutor.shutdownNow();
            }
        }
        catch (InterruptedException e) {
            this.idleCleanupExecutor.shutdownNow();
            Thread.currentThread().interrupt();
        }
        catch (Exception e) {
            logger.warn("Scheduler shutdown failed: {}", e.getMessage(), e);
        }
    }

    /** Returns the WebMVC routes (SSE GET endpoint and message POST endpoint) to register. */
    public RouterFunction<ServerResponse> getRouterFunction() {
        return this.routerFunction;
    }

    /**
     * Opens a new SSE stream, registers a session for it, and sends the endpoint event.
     *
     * @return a 503 while shutting down, the SSE response otherwise, or a 500 if the SSE
     *         response could not be created
     */
    private ServerResponse handleSseConnection(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
        }
        // At capacity: evict before admitting, so the new session does not push the registry past
        // maxSessions. Evicted sessions are only deregistered when their stream completes.
        if (this.sessionCount.get() >= this.maxSessions) {
            this.evictMostIdleSessions();
        }
        String sessionId = UUID.randomUUID().toString();
        logger.debug("Creating new SSE connection for session: {}", sessionId);
        String resolvedBaseUrl = this.resolveBaseUrl(request);
        try {
            return ServerResponse.sse(sseBuilder -> {
                sseBuilder.onComplete(() -> {
                    logger.debug("SSE connection completed for session: {}", sessionId);
                    this.deregisterSession(sessionId);
                });
                sseBuilder.onTimeout(() -> {
                    logger.debug("SSE connection timed out for session: {}", sessionId);
                    this.deregisterSession(sessionId);
                });
                WebMvcMcpSessionTransport sessionTransport = new WebMvcMcpSessionTransport(sessionId, sseBuilder);
                GraphDBMcpServerSession session = new GraphDBMcpServerSession(this.sessionFactory.create(sessionTransport));
                // Count only genuinely new entries, mirroring deregisterSession.
                if (this.sessions.put(sessionId, session) == null) {
                    this.sessionCount.incrementAndGet();
                }
                try {
                    sseBuilder.id(sessionId).event(ENDPOINT_EVENT_TYPE).data(resolvedBaseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
                }
                catch (Exception e) {
                    logger.error("Failed to send initial endpoint event: {}", e.getMessage());
                    sseBuilder.error(e);
                }
            }, Duration.ZERO);
        }
        catch (Exception e) {
            logger.error("Failed to send initial endpoint event to session {}: {}", sessionId, e.getMessage());
            this.deregisterSession(sessionId);
            return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build();
        }
    }

    /**
     * Handles a JSON-RPC message POSTed for the session named by the {@code sessionId} query
     * parameter.
     *
     * @return 503 while shutting down, 400 if the session id is missing or the body cannot be
     *         parsed, 404 for an unknown session, 500 for other failures, 200 otherwise
     */
    private ServerResponse handleMessage(ServerRequest request) {
        if (this.isClosing) {
            return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
        }
        String sessionId = request.param("sessionId").orElse(null);
        if (sessionId == null) {
            return ServerResponse.badRequest().body(new McpError("Session ID missing in message endpoint"));
        }
        GraphDBMcpServerSession session = this.sessions.get(sessionId);
        if (session == null) {
            return ServerResponse.status(HttpStatus.NOT_FOUND).body(new McpError("Session not found: " + sessionId));
        }
        try {
            McpTransportContext transportContext = this.contextExtractor.extract(request);
            String body = request.body(String.class);
            McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.jsonMapper, body);
            if (message instanceof McpSchema.JSONRPCRequest jsonRpcRequest) {
                // NOTE(review): this counts every JSON-RPC request (initialize, ping, tools/list, ...),
                // not only tools/call, despite the counter's name -- confirm that is intended.
                SystemStatisticsCollector.getInstance().incrementTtygMcpToolCall();
                if (!session.isDelegateInitialized() && !"initialize".equals(jsonRpcRequest.method())) {
                    // A non-initialize request arrived for an uninitialized session: synthesize the
                    // notifications/initialized notification first (presumably to tolerate clients
                    // that skip the handshake -- confirm).
                    session.handle((McpSchema.JSONRPCMessage) new McpSchema.JSONRPCNotification("2.0", "notifications/initialized", Map.of())).block();
                }
            }
            session.handle(message).contextWrite(ctx -> ctx.put("MCP_TRANSPORT_CONTEXT", transportContext)).block();
            return ServerResponse.ok().build();
        }
        catch (IOException | IllegalArgumentException e) {
            logger.error("Failed to deserialize message: {}", e.getMessage());
            return ServerResponse.badRequest().body(new McpError("Invalid message format"));
        }
        catch (Exception e) {
            logger.error("Error handling message: {}", e.getMessage(), e);
            return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage()));
        }
    }

    /**
     * Determines the base URL advertised to the client in the endpoint event.
     *
     * <p>The sources are tried in order:
     * <ol>
     *   <li>the {@code X-GraphDB-Proxied-From} header, with any {@code cluster-proxy;} marker
     *       removed (scheme, host and port only; no path);</li>
     *   <li>the request URI's scheme, host and port plus the servlet context path;</li>
     *   <li>the configured {@link #baseUrl}.</li>
     * </ol>
     */
    String resolveBaseUrl(ServerRequest request) {
        String proxiedHeader = request.headers().asHttpHeaders().getFirst("X-GraphDB-Proxied-From");
        if (StringUtils.isNotEmpty(proxiedHeader)) {
            try {
                return this.extractBaseUri(proxiedHeader.replace("cluster-proxy;", ""));
            }
            catch (Exception e) {
                logger.warn("Failed to parse X-GRAPHDB-PROXIED-FROM header. Falling back to request URI. Cause: {}", e.getMessage());
            }
        }
        try {
            return this.extractBaseUri(request.uri())
                    .path(request.requestPath().contextPath().value())
                    .build()
                    .toUriString();
        }
        catch (Exception e) {
            logger.warn("Failed to resolve base URL from request URI. Falling back to default. Cause: {}", e.getMessage());
            return this.baseUrl;
        }
    }

    /** Parses {@code uriString} and returns its scheme/host/port as a URI string. */
    private String extractBaseUri(String uriString) {
        return this.extractBaseUri(URI.create(uriString)).build().toUriString();
    }

    /** Copies scheme, host and, when present, explicit port of {@code uri} into a new builder. */
    private UriComponentsBuilder extractBaseUri(URI uri) {
        UriComponentsBuilder builder = UriComponentsBuilder.newInstance().scheme(uri.getScheme()).host(uri.getHost());
        int port = uri.getPort();
        if (port != -1) {
            builder.port(port);
        }
        return builder;
    }

    /** Returns a new {@link Builder} pre-populated with GraphDB defaults. */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Evicts up to half of {@link #maxSessions} non-closing sessions, ordered by ascending
     * {@code lastActivity}.
     *
     * <p>The batch size is clamped to at least one. With {@code maxSessions == 1}, where
     * {@code maxSessions / 2 == 0}, this still frees a slot, and a negative configured value
     * cannot make {@code Stream.limit} throw.
     */
    private void evictMostIdleSessions() {
        long batchSize = Math.max(1, this.maxSessions / 2);
        this.sessions.values().stream()
                .filter(session -> !session.isClosing())
                .sorted(Comparator.comparing(GraphDBMcpServerSession::lastActivity))
                .limit(batchSize)
                .forEach(session -> this.evict(session, "Evicting idle MCP session {} to make room"));
    }

    /**
     * Creates the idle-cleanup executor (daemon thread {@code mcp-idle-cleanup}) and schedules
     * {@link #cleanupIdleSessions()} once per {@link #idleSessionTimeout}.
     *
     * <p>A non-positive timeout cannot be used as a {@code scheduleAtFixedRate} period, so in that
     * case periodic cleanup is disabled with a warning. The executor is still created because
     * {@link #closeGracefully()} shuts it down unconditionally.
     */
    private void startPeriodicIdleCleanup() {
        this.idleCleanupExecutor = Executors.newSingleThreadScheduledExecutor(runnable -> {
            Thread thread = new Thread(runnable, "mcp-idle-cleanup");
            thread.setDaemon(true);
            return thread;
        });
        long periodMillis = this.idleSessionTimeout.toMillis();
        if (periodMillis <= 0) {
            logger.warn("Idle MCP session cleanup is disabled because the configured timeout is not positive: {}", this.idleSessionTimeout);
            return;
        }
        this.idleCleanupExecutor.scheduleAtFixedRate(this::cleanupIdleSessions, periodMillis, periodMillis, TimeUnit.MILLISECONDS);
    }

    /**
     * Closes every non-closing session that is idle for longer than {@link #idleSessionTimeout}.
     *
     * <p>Runs on the scheduled executor. Any exception escaping a {@code scheduleAtFixedRate}
     * task silently cancels all later runs, so failures are logged here instead of propagated.
     */
    private void cleanupIdleSessions() {
        try {
            this.sessions.values().stream()
                    .filter(session -> session.isIdle(this.idleSessionTimeout) && !session.isClosing())
                    .forEach(session -> this.evict(session, "Closing idle MCP session {}"));
        }
        catch (RuntimeException e) {
            logger.warn("Idle MCP session cleanup failed: {}", e.getMessage(), e);
        }
    }

    /**
     * Closes {@code session} if this caller wins the {@code tryEvict()} race.
     *
     * <p>A failing close is logged rather than propagated, so one bad session does not abort an
     * eviction or cleanup pass.
     *
     * @param logMessage SLF4J debug message with one {@code {}} placeholder for the session id
     */
    private void evict(GraphDBMcpServerSession session, String logMessage) {
        if (!session.tryEvict()) {
            return;
        }
        logger.debug(logMessage, session.getId());
        try {
            session.close();
        }
        catch (RuntimeException e) {
            logger.warn("Failed to close MCP session {}: {}", session.getId(), e.getMessage(), e);
        }
    }

    /** Removes a session from the registry, decrementing the counter only if it was present. */
    private void deregisterSession(String sessionId) {
        if (this.sessions.remove(sessionId) != null) {
            this.sessionCount.decrementAndGet();
        }
    }

    /**
     * Builder for {@link GraphDBWebMvcSseServerTransportProvider}.
     *
     * <p>The defaults are: base URL from {@code Config.getExternalUrl(null)}, message endpoint
     * {@code /mcp/message}, SSE endpoint {@code /mcp/sse}, keep-alive interval from
     * {@code graphdb.mcp.server.sessions.keepalive.interval} (milliseconds, default 0), and a
     * context extractor that always yields {@code McpTransportContext.EMPTY}. If no JSON mapper is
     * set, {@code McpJsonMapper.getDefault()} is used.
     */
    public static class Builder {
        private McpJsonMapper jsonMapper;
        private String baseUrl = Config.getExternalUrl(null);
        private String messageEndpoint = "/mcp/message";
        private String sseEndpoint = "/mcp/sse";
        private Duration keepAliveInterval = Duration.of(Config.getPropertyAsLong("graphdb.mcp.server.sessions.keepalive.interval", 0L), ChronoUnit.MILLIS);
        private McpTransportContextExtractor<ServerRequest> contextExtractor = serverRequest -> McpTransportContext.EMPTY;

        /** Sets the JSON mapper used to (de)serialize JSON-RPC messages; must not be null. */
        public Builder jsonMapper(McpJsonMapper jsonMapper) {
            Assert.notNull(jsonMapper, "McpJsonMapper must not be null");
            this.jsonMapper = jsonMapper;
            return this;
        }

        /** Sets the fallback base URL, used when none can be derived from a request; must not be null. */
        public Builder baseUrl(String baseUrl) {
            Assert.notNull(baseUrl, "Base URL must not be null");
            this.baseUrl = baseUrl;
            return this;
        }

        /** Sets the path clients POST messages to; must contain text. */
        public Builder messageEndpoint(String messageEndpoint) {
            Assert.hasText(messageEndpoint, "Message endpoint must not be empty");
            this.messageEndpoint = messageEndpoint;
            return this;
        }

        /** Sets the path clients open the SSE stream on; must contain text. */
        public Builder sseEndpoint(String sseEndpoint) {
            Assert.hasText(sseEndpoint, "SSE endpoint must not be empty");
            this.sseEndpoint = sseEndpoint;
            return this;
        }

        /** Sets the keep-alive interval; zero, negative or null disables keep-alive. */
        public Builder keepAliveInterval(Duration keepAliveInterval) {
            this.keepAliveInterval = keepAliveInterval;
            return this;
        }

        /** Sets the extractor that derives an {@code McpTransportContext} from each POST request. */
        public Builder contextExtractor(McpTransportContextExtractor<ServerRequest> contextExtractor) {
            Assert.notNull(contextExtractor, "contextExtractor must not be null");
            this.contextExtractor = contextExtractor;
            return this;
        }

        /**
         * Builds the provider. This starts its background idle-cleanup task and, if enabled, the
         * keep-alive scheduler.
         *
         * @throws IllegalStateException if the message endpoint is unset
         */
        public GraphDBWebMvcSseServerTransportProvider build() {
            if (this.messageEndpoint == null) {
                throw new IllegalStateException("MessageEndpoint must be set");
            }
            McpJsonMapper mapper = this.jsonMapper == null ? McpJsonMapper.getDefault() : this.jsonMapper;
            return new GraphDBWebMvcSseServerTransportProvider(mapper, this.baseUrl, this.messageEndpoint, this.sseEndpoint, this.keepAliveInterval, this.contextExtractor);
        }
    }

    /**
     * Per-session transport that writes server messages to one SSE stream.
     *
     * <p>Every access to {@link #sseBuilder} is guarded by {@link #sseBuilderLock}, so message
     * sends and stream completion never interleave.
     */
    private class WebMvcMcpSessionTransport
    implements McpServerTransport {
        private final String sessionId;
        private final ServerResponse.SseBuilder sseBuilder;
        private final ReentrantLock sseBuilderLock = new ReentrantLock();

        WebMvcMcpSessionTransport(String sessionId, ServerResponse.SseBuilder sseBuilder) {
            this.sessionId = sessionId;
            this.sseBuilder = sseBuilder;
            logger.debug("Session transport {} initialized with SSE builder", sessionId);
        }

        /**
         * Serializes {@code message} to JSON and writes it as a {@value #MESSAGE_EVENT_TYPE} event.
         *
         * <p>On failure the error is logged and forwarded to the SSE stream. The returned
         * {@code Mono} still completes normally.
         */
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return Mono.fromRunnable(() -> {
                this.sseBuilderLock.lock();
                try {
                    String jsonText = GraphDBWebMvcSseServerTransportProvider.this.jsonMapper.writeValueAsString(message);
                    this.sseBuilder.id(this.sessionId).event(GraphDBWebMvcSseServerTransportProvider.MESSAGE_EVENT_TYPE).data(jsonText);
                    logger.debug("Message sent to session {}", this.sessionId);
                }
                catch (Exception e) {
                    logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
                    this.sseBuilder.error(e);
                }
                finally {
                    this.sseBuilderLock.unlock();
                }
            });
        }

        /** Converts {@code data} to {@code T} using the provider's JSON mapper. */
        public <T> T unmarshalFrom(Object data, TypeRef<T> typeRef) {
            return GraphDBWebMvcSseServerTransportProvider.this.jsonMapper.convertValue(data, typeRef);
        }

        /** Lazily completes the SSE stream when the returned {@code Mono} is subscribed. */
        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(() -> {
                logger.debug("Closing session transport: {}", this.sessionId);
                this.completeSseBuilder();
            });
        }

        /** Completes the SSE stream immediately. */
        public void close() {
            this.completeSseBuilder();
        }

        /** Completes the SSE stream under the builder lock; failures are logged, not thrown. */
        private void completeSseBuilder() {
            this.sseBuilderLock.lock();
            try {
                this.sseBuilder.complete();
                logger.debug("Successfully completed SSE builder for session {}", this.sessionId);
            }
            catch (Exception e) {
                logger.warn("Failed to complete SSE builder for session {}: {}", this.sessionId, e.getMessage());
            }
            finally {
                this.sseBuilderLock.unlock();
            }
        }
    }
}

