/*
 * Decompiled with CFR 0.152.
 */
package com.ontotext.embeddings;

import ai.graphwise.transformer.GraphwiseTransformer;
import ai.graphwise.transformer.InferenceServiceGrpc;
import com.ontotext.Config;
import com.ontotext.embeddings.security.AuthClientInterceptor;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.exception.AuthenticationException;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import io.grpc.Channel;
import io.grpc.ClientInterceptor;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import java.io.Closeable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

public class GraphwiseTransformerClient
implements EmbeddingModel,
Closeable {
    public static final String MODEL_NAME_PROPERTY = "graphwise.transformer.embedding.model.name";
    public static final String MODEL_NAME_DEFAULT = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2";
    public static final String ADDRESS_PROPERTY = "graphwise.transformer.address";
    public static final String ADDRESS_DEFAULT = "localhost:5050";
    public static final String BATCH_SIZE_PROPERTY = "graphwise.transformer.batch.size";
    public static final int BATCH_SIZE_DEFAULT = 256;
    public static final String AUTH_TOKEN_SECRET_PROPERTY = "graphwise.transformer.auth.token.secret";
    public static final String THREAD_POOL_SIZE_PROPERTY = "graphwise.transformer.thread.pool.size";
    public static final int MAX_MESSAGE_SIZE = 4096000;
    public static final int MESSAGE_FRAMING_OVERHEAD = 32;
    public static final int FLOAT_BYTES = 4;
    private static final String MODEL_NAME = Config.getProperty("graphwise.transformer.embedding.model.name", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2");
    private static final String ADDRESS = Config.getProperty("graphwise.transformer.address", "localhost:5050");
    private static final int BATCH_SIZE = Config.getPropertyInt("graphwise.transformer.batch.size", 256) * 1024;
    private static final int THREAD_POOL_SIZE = Config.getPropertyInt("graphwise.transformer.thread.pool.size", -1);
    private final ManagedChannel channel = this.buildChannel();
    private final InferenceServiceGrpc.InferenceServiceBlockingStub stub = InferenceServiceGrpc.newBlockingStub((Channel)this.channel);
    private final ExecutorService executor = this.createExecutor();
    private final int embeddingSize = this.dimension() * 4;

    @Override
    public Response<List<Embedding>> embedAll(List<TextSegment> segments) {
        List<List<TextSegment>> batches = this.createBatches(segments);
        List<CompletableFuture> futures = batches.stream().map(batch -> CompletableFuture.supplyAsync(() -> {
            GraphwiseTransformer.SentenceRequest request = GraphwiseTransformer.SentenceRequest.newBuilder().setModelName(MODEL_NAME).addAllTexts(batch.stream().map(TextSegment::text).toList()).build();
            GraphwiseTransformer.SentenceResponse response = this.stub.embedSentence(request);
            return this.toLangchainEmbeddings(response.getEmbeddingsList());
        }, this.executor)).toList();
        try {
            List allEmbeddings = futures.stream().map(CompletableFuture::join).flatMap(Collection::stream).toList();
            return Response.from(allEmbeddings);
        }
        catch (CompletionException e) {
            StatusRuntimeException sre;
            Throwable cause = e.getCause();
            if (cause instanceof StatusRuntimeException && (sre = (StatusRuntimeException)cause).getStatus().getCode() == Status.Code.UNAUTHENTICATED) {
                throw new AuthenticationException("Authentication failed due to: ", cause);
            }
            throw e;
        }
    }

    @Override
    public void close() {
        this.channel.shutdown();
        try {
            if (!this.channel.awaitTermination(5L, TimeUnit.SECONDS)) {
                this.channel.shutdownNow();
            }
        }
        catch (InterruptedException e) {
            this.channel.shutdownNow();
            Thread.currentThread().interrupt();
        }
        finally {
            this.executor.shutdownNow();
        }
    }

    private ManagedChannel buildChannel() {
        String[] parts = ADDRESS.split(":");
        String host = parts[0];
        int port = Integer.parseInt(parts[1]);
        String secret = Config.getProperty(AUTH_TOKEN_SECRET_PROPERTY);
        ManagedChannelBuilder builder = ManagedChannelBuilder.forAddress((String)host, (int)port).usePlaintext();
        if (secret != null && !secret.isEmpty()) {
            builder = builder.intercept(new ClientInterceptor[]{new AuthClientInterceptor(secret)});
        }
        return builder.build();
    }

    private ExecutorService createExecutor() {
        int threadPoolSize = THREAD_POOL_SIZE >= 0 ? THREAD_POOL_SIZE : Runtime.getRuntime().availableProcessors();
        return new ThreadPoolExecutor(Math.min(threadPoolSize, 2), threadPoolSize, 30L, TimeUnit.SECONDS, new LinkedBlockingQueue<Runnable>(threadPoolSize * 2), r -> {
            Thread t = new Thread(r);
            t.setDaemon(true);
            t.setName("graphwise-embedding-" + MODEL_NAME);
            return t;
        }, new ThreadPoolExecutor.CallerRunsPolicy());
    }

    private List<List<TextSegment>> createBatches(List<TextSegment> segments) {
        ArrayList<List<TextSegment>> chunks = new ArrayList<List<TextSegment>>();
        ArrayList<TextSegment> currentBatch = new ArrayList<TextSegment>();
        int currentBytes = 0;
        for (TextSegment segment : segments) {
            int size = segment.text().getBytes(StandardCharsets.UTF_8).length + 32;
            int requestSize = currentBytes + size;
            int responseSize = (currentBatch.size() + 1) * this.embeddingSize + requestSize;
            if (requestSize > BATCH_SIZE || responseSize > 4096000) {
                if (!currentBatch.isEmpty()) {
                    chunks.add(List.copyOf(currentBatch));
                    currentBatch.clear();
                    currentBytes = 0;
                }
                if (size > BATCH_SIZE) {
                    if (size > 4096000) {
                        String trimmed = this.trim(segment.text());
                        segment = new TextSegment(trimmed, segment.metadata());
                    }
                    chunks.add(List.of(segment));
                    continue;
                }
            }
            currentBatch.add(segment);
            currentBytes += size;
        }
        if (!currentBatch.isEmpty()) {
            chunks.add(List.copyOf(currentBatch));
        }
        return chunks;
    }

    private String trim(String string) {
        return string.substring(0, 1000000);
    }

    private List<Embedding> toLangchainEmbeddings(List<GraphwiseTransformer.Embedding> protoEmbeddings) {
        return protoEmbeddings.stream().map(protoEmb -> {
            List<Float> list = protoEmb.getEmbeddingList();
            float[] vector = new float[list.size()];
            for (int i = 0; i < list.size(); ++i) {
                vector[i] = list.get(i).floatValue();
            }
            return new Embedding(vector);
        }).toList();
    }
}

