/*
 * Decompiled with CFR 0.152.
 */
package com.ontotext.forest.graphql.cluster;

import com.ontotext.forest.graphql.GraphDBGraphQLRequestContext;
import com.ontotext.forest.graphql.cluster.RpcGraphQLQueryClient;
import com.ontotext.graphdb.raft.ClusterGroup;
import com.ontotext.graphdb.raft.NodeState;
import com.ontotext.graphdb.raft.grpc.Data;
import com.ontotext.graphdb.raft.grpc.GraphQLQuery;
import com.ontotext.graphdb.raft.grpc.RpcNodeClient;
import com.ontotext.graphdb.raft.grpc.TrackRecordData;
import com.ontotext.graphdb.raft.observe.RaftObserver;
import com.ontotext.graphdb.raft.storage.TransactionLog;
import com.ontotext.graphdb.replicationcluster.LocalConsistency;
import com.ontotext.raft.evaluate.ClosableClusterQueryIterator;
import com.ontotext.raft.evaluate.ClosableMonitoredClusterQueryIterator;
import com.ontotext.trree.RepositoryMonitorTrackRecord;
import com.ontotext.trree.RepositoryMonitorTrackRecordHelper;
import com.ontotext.trree.RepositoryMonitorTrackRecordImpl;
import com.ontotext.trree.monitorRepository.MonitorRepositoryConnection;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import jakarta.annotation.Nullable;
import java.lang.invoke.MethodHandles;
import java.util.Iterator;
import java.util.Objects;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import org.eclipse.rdf4j.repository.RepositoryConnection;
import org.eclipse.rdf4j.repository.manager.RepositoryManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Offloads GraphQL query evaluation from this node to other cluster nodes over RPC.
 *
 * <p>For every query the balancer picks an in-sync remote node with the fewest running queries. Ties are broken
 * by the lower lifetime query total. A remote node is used only if it is not busier than this node. A
 * {@code null} result from {@link #evaluateQuery} means no remote node was chosen and the query is to be handled
 * locally.
 *
 * <p>The set of remote clients follows cluster membership through the {@link RaftObserver} callbacks
 * {@link #nodeAdded} and {@link #nodeRemoved}. The client list and the local query counters use concurrent
 * types, so they can be read and updated from different threads.
 */
public class GraphQLLoadBalancer
implements RaftObserver {
    private static final Logger LOGGER = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
    /** One query client per known remote cluster node. */
    private final Queue<RpcGraphQLQueryClient> queryClients;
    /** Queries currently being evaluated on this (leader) node. */
    private final AtomicInteger leaderQueryCount;
    /** Queries ever evaluated on this (leader) node; only used for debug logging. */
    private final AtomicLong leaderTotalQueryCnt;
    /** Supplies this node's last valid transaction-log index (used for LAST_COMMITTED consistency). */
    private final Supplier<Long> lastValidTransactionSupplier;
    private final ClusterGroup clusterGroup;
    private final RepositoryManager repositoryManager;

    /**
     * Creates a balancer with one (not yet initialised) query client per node of {@code clusterGroup}.
     * Call {@link #start()} to initialise the clients.
     *
     * @param clusterGroup      the cluster nodes and transaction log of this node
     * @param repositoryManager used to open connections to the queried repositories
     */
    public GraphQLLoadBalancer(ClusterGroup clusterGroup, RepositoryManager repositoryManager) {
        this.repositoryManager = repositoryManager;
        this.clusterGroup = clusterGroup;
        this.queryClients = new ConcurrentLinkedDeque<>();
        this.lastValidTransactionSupplier = () -> ((TransactionLog) clusterGroup.getTransactionLog()).getLastValidLog();
        for (RpcNodeClient nodeClient : clusterGroup) {
            this.queryClients.add(new RpcGraphQLQueryClient(nodeClient));
        }
        this.leaderQueryCount = new AtomicInteger(0);
        this.leaderTotalQueryCnt = new AtomicLong(0L);
    }

    /** Initialises every currently known remote query client. */
    public void start() {
        for (RpcGraphQLQueryClient queryClient : this.queryClients) {
            queryClient.init();
        }
    }

    /** Shuts down every currently known remote query client. */
    public void shutdown() {
        for (RpcGraphQLQueryClient queryClient : this.queryClients) {
            queryClient.shutdown();
        }
    }

    /** Records that a query started running on this node; must be paired with {@link #decrementQueryCount()}. */
    public void incrementQueryCount() {
        int count = this.leaderQueryCount.incrementAndGet();
        this.leaderTotalQueryCnt.incrementAndGet();
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Evaluating query on leader node. Total queries running {}", count);
        }
    }

    /** Records that a query previously counted by {@link #incrementQueryCount()} finished on this node. */
    public void decrementQueryCount() {
        int count = this.leaderQueryCount.decrementAndGet();
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Finished evaluating query on leader node. Total queries running {}", count);
        }
    }

    /**
     * Tries to evaluate {@code query} on a remote cluster node.
     *
     * <p>If the repository connection is a {@link MonitorRepositoryConnection}, the query is registered as a
     * GraphQL track record with an aborter that asks the remote node to abort it. The connection then stays open
     * until the returned iterator is closed or the aborter runs. Otherwise the connection is closed right away.
     *
     * @param query       the GraphQL query, including the id of the target repository
     * @param consistency {@link LocalConsistency#LAST_COMMITTED} restricts the choice to nodes whose last log
     *                    index is at least this node's last valid log index; any other value accepts any
     *                    in-sync node
     * @return an iterator over the remote results, or {@code null} when the query should be handled locally
     *         (no suitable remote node, or the remote call was rejected with {@code FAILED_PRECONDITION})
     * @throws StatusRuntimeException if starting the remote evaluation fails with any other gRPC status
     */
    public Iterator<Data> evaluateQuery(GraphQLQuery query, LocalConsistency consistency) {
        RpcGraphQLQueryClient executableClient = this.pickBalancingClient(consistency);
        if (executableClient == null) {
            return null;
        }
        GraphQLLoadBalancer.logEvaluatingQueryMessage("graphql", executableClient);
        RepositoryConnection connection = this.repositoryManager.getRepository(query.getRepository()).getConnection();
        if (connection instanceof MonitorRepositoryConnection) {
            return this.evaluateMonitored(executableClient, connection, query);
        }
        // Without monitoring support the connection was only needed for the instanceof check above.
        connection.close();
        try {
            return new ClosableClusterQueryIterator(executableClient.evaluateGraphQL(query));
        }
        catch (StatusRuntimeException sre) {
            return this.onQueryException(executableClient, sre);
        }
    }

    /**
     * Evaluates {@code query} remotely while tracking it in the repository monitor.
     *
     * <p>On success, the track record and {@code connection} are owned by the returned iterator and by the
     * registered aborter. If anything fails before the iterator is handed out, both are released here so they
     * do not leak. gRPC failures are handled exactly as in the unmonitored path.
     */
    private Iterator<Data> evaluateMonitored(RpcGraphQLQueryClient executableClient, RepositoryConnection connection, GraphQLQuery query) {
        TrackRecord currentTrackRecord = null;
        boolean handedOver = false;
        try {
            currentTrackRecord = this.registerTrackRecord(executableClient, connection, query);
            Iterator<Data> result = new GraphQLClosableMonitoredClusterQueryIterator(
                    executableClient.evaluateGraphQL(this.buildTrackedQuery(query, currentTrackRecord)),
                    currentTrackRecord, connection);
            handedOver = true;
            return result;
        }
        catch (StatusRuntimeException sre) {
            return this.onQueryException(executableClient, sre);
        }
        finally {
            if (!handedOver) {
                GraphQLLoadBalancer.releaseUnused(currentTrackRecord, connection);
            }
        }
    }

    /**
     * Best-effort release of a track record and connection that never reached the caller.
     * Failures are logged rather than thrown so they cannot mask the original error.
     */
    private static void releaseUnused(@Nullable TrackRecord trackRecord, RepositoryConnection connection) {
        if (trackRecord != null) {
            try {
                trackRecord.recordHelper.closeTrackRecord(trackRecord.currentRecord);
            }
            catch (RuntimeException re) {
                LOGGER.warn("Failed to close GraphQL track record", re);
            }
        }
        try {
            connection.close();
        }
        catch (RuntimeException re) {
            LOGGER.warn("Failed to close repository connection", re);
        }
    }

    /** Returns a copy of {@code query} tagged with the track alias of {@code currentTrackRecord}. */
    private GraphQLQuery buildTrackedQuery(GraphQLQuery query, TrackRecord currentTrackRecord) {
        return GraphQLQuery.newBuilder(query).setTrackAlias(currentTrackRecord.trackAlias).build();
    }

    /**
     * Handles a gRPC failure of a remote evaluation.
     *
     * <p>gRPC {@link Status} instances must be compared by {@link Status#getCode()}. A status carrying a
     * description is a different object from the {@code Status.FAILED_PRECONDITION} constant, so an identity
     * comparison would never match.
     *
     * @return {@code null} (meaning "handle locally") when the remote node reported {@code FAILED_PRECONDITION}
     * @throws StatusRuntimeException {@code sre} itself, for any other status code
     */
    @Nullable
    private <E> E onQueryException(RpcGraphQLQueryClient executableClient, StatusRuntimeException sre) {
        if (sre.getStatus().getCode() == Status.Code.FAILED_PRECONDITION) {
            LOGGER.warn("Node {} was unable to execute query due to: {}", executableClient.getAddress(), sre.getStatus().getDescription());
            return null;
        }
        throw sre;
    }

    /**
     * Registers the query in the repository monitor of {@code repositoryCon}.
     *
     * <p>The connection's current track record is marked as GraphQL, and the GraphQL request text is recorded as
     * its query string. An aborter is registered that:
     * <ul>
     *   <li>asks {@code stopQueryClient} to abort the remote query identified by repository id and track alias;</li>
     *   <li>always closes the track record and {@code repositoryCon} afterwards.</li>
     * </ul>
     */
    private TrackRecord registerTrackRecord(RpcGraphQLQueryClient stopQueryClient, RepositoryConnection repositoryCon, GraphQLQuery query) {
        RepositoryMonitorTrackRecordHelper trackRecordHelper = ((MonitorRepositoryConnection) repositoryCon).getSailConnectionImpl().getTrackRecordHelper();
        RepositoryMonitorTrackRecordImpl trackRecord = trackRecordHelper.getTrackRecord();
        trackRecord.setType(RepositoryMonitorTrackRecord.Type.GRAPHQL);
        String trackAlias = trackRecord.getTrackAlias();
        TrackRecordData trackRecordData = TrackRecordData.newBuilder().setRepoId(query.getRepository()).setTrackAlias(trackAlias).build();
        // Captured now so the abort log names the request registered here, whatever context later runs the aborter.
        Object graphqlRequest = GraphDBGraphQLRequestContext.getGraphqlRequest();
        RepositoryMonitorTrackRecordImpl.Aborter aborter = () -> {
            try {
                if (stopQueryClient.abortQuery(trackRecordData)) {
                    LOGGER.info("Successfully canceled query {}", graphqlRequest);
                }
            }
            catch (RuntimeException re) {
                LOGGER.warn("Abort query failed with ", re);
            }
            finally {
                trackRecordHelper.closeTrackRecord(trackRecord);
                repositoryCon.close();
            }
        };
        trackRecordHelper.registerTrackRecordQuery(aborter);
        trackRecordHelper.registerTrackRecordSparqlString(GraphDBGraphQLRequestContext.getGraphqlRequest());
        return new TrackRecord(trackRecordHelper, trackRecord, trackAlias);
    }

    /**
     * Chooses the remote client to run the next query.
     *
     * @return the least-used up-to-date client, or {@code null} if none qualifies or the chosen client has
     *         more running queries than this node (in which case the query is handled locally)
     */
    private RpcGraphQLQueryClient pickBalancingClient(LocalConsistency consistency) {
        // -1 accepts any in-sync node; LAST_COMMITTED requires the node to have caught up with our log.
        long minLogIndex = consistency == LocalConsistency.LAST_COMMITTED ? this.lastValidTransactionSupplier.get() : -1L;
        RpcGraphQLQueryClient pickedClient = null;
        for (RpcGraphQLQueryClient candidate : this.queryClients) {
            if (this.shouldSwapClients(pickedClient, candidate, minLogIndex)) {
                pickedClient = candidate;
            }
        }
        if (pickedClient == null) {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("Unable to pick client, handling locally");
            }
            return null;
        }
        int clientQueryCnt = pickedClient.getQueryCount();
        int leaderQueryCnt = this.leaderQueryCount.get();
        if (clientQueryCnt > leaderQueryCnt) {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("Going to handle query locally. remote(a={}, q={}, total={}), local(q={}, total={})",
                        pickedClient.getAddress(), clientQueryCnt, pickedClient.getTotalQueryCount(), leaderQueryCnt, this.leaderTotalQueryCnt.get());
            }
            return null;
        }
        return pickedClient;
    }

    /** True if {@code newClient} is up to date and a better choice than {@code currentClient} (which may be null). */
    private boolean shouldSwapClients(RpcGraphQLQueryClient currentClient, RpcGraphQLQueryClient newClient, long minLogIndex) {
        return this.isClientUpdated(newClient, minLogIndex) && (currentClient == null || this.isClientLessUsed(currentClient, newClient));
    }

    /** True if {@code client} reports {@code IN_SYNC} and its last log index is at least {@code minLogIndex}. */
    private boolean isClientUpdated(RpcGraphQLQueryClient client, long minLogIndex) {
        return client.fetchStatus() == RpcNodeClient.Status.IN_SYNC && client.getLastLogIndex() >= minLogIndex;
    }

    /**
     * True if {@code newClient} has fewer running queries than {@code currentClient}. When the running counts are
     * equal, it is true if {@code newClient} has a lower lifetime total.
     */
    private boolean isClientLessUsed(RpcGraphQLQueryClient currentClient, RpcGraphQLQueryClient newClient) {
        long currentQueryCnt = currentClient.getQueryCount();
        long newQueryCnt = newClient.getQueryCount();
        if (currentQueryCnt != newQueryCnt) {
            return currentQueryCnt > newQueryCnt;
        }
        return currentClient.getTotalQueryCount() > newClient.getTotalQueryCount();
    }

    /** Node state changes do not affect GraphQL balancing. */
    public void update(NodeState state) {
    }

    /** Leader changes do not affect GraphQL balancing. */
    public void update(@Nullable String leaderRpcAddress, String leaderHttpAddress) {
    }

    /** Term changes do not affect GraphQL balancing. */
    public void update(long term) {
    }

    /**
     * Adds and initialises a query client for a newly joined node; the current node itself is ignored.
     * If the cluster group has no RPC client for {@code rpcAddress}, a warning is logged and nothing is added.
     */
    public void nodeAdded(String rpcAddress, String httpAddress) {
        if (this.clusterGroup.getCurrentAddress().equals(rpcAddress)) {
            return;
        }
        RpcNodeClient rpcNode = this.clusterGroup.getClusterRpcNode(rpcAddress);
        assert (rpcNode != null) : "Cluster group is not updated with the latest changes";
        if (rpcNode == null) {
            // Assertions are usually disabled in production; never wrap a missing node in a client.
            LOGGER.warn("Cannot add GraphQL query client for {}: node is not present in the cluster group", rpcAddress);
            return;
        }
        RpcGraphQLQueryClient queryClient = new RpcGraphQLQueryClient(rpcNode);
        queryClient.init();
        this.queryClients.add(queryClient);
    }

    /** Shuts down and removes every query client whose address equals {@code rpcAddress}. */
    public void nodeRemoved(String rpcAddress, String httpAddress) {
        this.queryClients.removeIf(client -> {
            if (Objects.equals(client.getAddress(), rpcAddress)) {
                client.shutdown();
                return true;
            }
            return false;
        });
    }

    private static void logEvaluatingQueryMessage(String queryType, RpcGraphQLQueryClient client) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Evaluating {} query on node {}. (q={}, total={})", queryType, client.getAddress(), client.getQueryCount(), client.getTotalQueryCount());
        }
    }

    /** Immutable holder for a registered monitor track record and its alias. */
    private static class TrackRecord {
        final RepositoryMonitorTrackRecordHelper recordHelper;
        final RepositoryMonitorTrackRecordImpl currentRecord;
        final String trackAlias;

        private TrackRecord(RepositoryMonitorTrackRecordHelper trackRecordHelper, RepositoryMonitorTrackRecordImpl trackRecord, String trackAlias) {
            this.recordHelper = trackRecordHelper;
            this.currentRecord = trackRecord;
            this.trackAlias = trackAlias;
        }
    }

    /** Monitored remote-result iterator that also closes the repository connection when it is closed. */
    private static class GraphQLClosableMonitoredClusterQueryIterator<T>
    extends ClosableMonitoredClusterQueryIterator<T> {
        private final RepositoryConnection connection;

        public GraphQLClosableMonitoredClusterQueryIterator(Iterator delegate, TrackRecord trackRecord, RepositoryConnection connection) {
            super(delegate, trackRecord.recordHelper, trackRecord.currentRecord);
            this.connection = connection;
        }

        public void close() throws Exception {
            try {
                super.close();
            }
            finally {
                // Closed even if the parent close fails.
                this.connection.close();
            }
        }
    }
}

