/*
 * Decompiled with CFR 0.152.
 */
package com.ontotext.forest.gpt.ttyg.tools;

import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.annotations.VisibleForTesting;
import com.ontotext.forest.gpt.ttyg.AgentInstructions;
import com.ontotext.forest.gpt.ttyg.RDFUtil;
import com.ontotext.forest.gpt.ttyg.ToolCallContext;
import com.ontotext.forest.gpt.ttyg.exceptions.ToolConfigException;
import com.ontotext.forest.gpt.ttyg.tools.BaseTool;
import com.ontotext.forest.gpt.ttyg.tools.ParameterDefinition;
import com.ontotext.forest.gpt.ttyg.tools.SPARQLWellKnownOntology;
import com.ontotext.forest.gpt.ttyg.tools.ToolResponse;
import com.ontotext.forest.gpt.ttyg.tools.ToolType;
import com.ontotext.graphdb.configs.SystemConfig;
import java.io.IOException;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.eclipse.rdf4j.model.IRI;
import org.eclipse.rdf4j.model.Namespace;
import org.eclipse.rdf4j.model.Resource;
import org.eclipse.rdf4j.model.Statement;
import org.eclipse.rdf4j.model.Value;
import org.eclipse.rdf4j.query.MalformedQueryException;
import org.eclipse.rdf4j.query.Query;
import org.eclipse.rdf4j.query.QueryResult;
import org.eclipse.rdf4j.query.algebra.QueryModelVisitor;
import org.eclipse.rdf4j.query.algebra.Service;
import org.eclipse.rdf4j.query.algebra.TupleExpr;
import org.eclipse.rdf4j.query.algebra.Var;
import org.eclipse.rdf4j.query.algebra.helpers.AbstractQueryModelVisitor;
import org.eclipse.rdf4j.query.parser.sparql.ast.VisitorException;
import org.eclipse.rdf4j.repository.RepositoryConnection;
import org.eclipse.rdf4j.repository.RepositoryResult;
import org.eclipse.rdf4j.repository.sail.SailQuery;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SPARQLQueryTool
extends BaseTool {
    private static final String METADATA_PREFIX = "prefix.";
    public static final String SPARQL_TOOL_DESCRIPTION = "Query GraphDB by SPARQL SELECT, CONSTRUCT or DESCRIBE and return result.";
    private static final Logger log = LoggerFactory.getLogger(SPARQLQueryTool.class);
    private String ontologyGraph;
    private String ontologyQuery;
    private boolean addMissingNamespaces;
    private volatile String ontology;

    public SPARQLQueryTool() {
        super(ToolType.SPARQL_QUERY);
    }

    @VisibleForTesting
    public SPARQLQueryTool(String ontologyGraph, String ontologyQuery) {
        this();
        this.setEnabled(true);
        this.setOntologyGraph(ontologyGraph);
        this.setOntologyQuery(ontologyQuery);
    }

    @Override
    public void setFromLegacyUiObject(JsonNode node) {
        this.setEnabled(true);
        this.setOntologyGraph(node.path("ontologyGraph").asText(null));
        this.setOntologyQuery(node.path("sparqlQuery").asText(null));
        this.setAddMissingNamespaces(node.path("addMissingNamespaces").asBoolean(false));
    }

    @Override
    public void writeToLegacyUiObject(JsonGenerator gen) throws IOException {
        gen.writeStringField("method", this.getType().getLegacyUiMethodKey());
        gen.writeStringField("ontologyGraph", this.getOntologyGraph());
        gen.writeStringField("sparqlQuery", this.getOntologyQuery());
        gen.writeBooleanField("addMissingNamespaces", this.isAddMissingNamespaces());
    }

    public String getOntologyGraph() {
        return this.ontologyGraph;
    }

    void setOntologyGraph(String ontologyGraph) {
        this.ontologyGraph = ontologyGraph;
    }

    public String getOntologyQuery() {
        return this.ontologyQuery;
    }

    void setOntologyQuery(String ontologyQuery) {
        this.ontologyQuery = ontologyQuery;
    }

    public boolean isAddMissingNamespaces() {
        return this.addMissingNamespaces;
    }

    public void setAddMissingNamespaces(boolean addMissingNamespaces) {
        this.addMissingNamespaces = addMissingNamespaces;
    }

    @Override
    public void validate(ToolCallContext agentRepository) {
        if (StringUtils.isBlank((CharSequence)this.ontologyGraph) && StringUtils.isBlank((CharSequence)this.ontologyQuery)) {
            throw new ToolConfigException("SPARQL method requires an ontology graph or query.");
        }
        this.extractOntology(agentRepository.connectionInstance());
    }

    @Override
    public String getDescription() {
        return SPARQL_TOOL_DESCRIPTION;
    }

    @Override
    public ParameterDefinition getParameterSchema() {
        return ParameterDefinition.singleParameter("query", "SPARQL query");
    }

    @Override
    public String getNativeQuery(Map<String, Object> parameters) {
        return this.fixSparqlQuery(super.getRawQuery(parameters));
    }

    @Override
    public ToolResponse call(Map<String, Object> parameters, ToolCallContext agentRepository) {
        RepositoryConnection connection = agentRepository.connectionInstance();
        Map<Object, Object> namespaces = this.addMissingNamespaces ? this.addMissingNamespaces(connection, this.getNativeQuery(parameters)) : new LinkedHashMap();
        String output = this.evaluateParsedQuery(this.prepareAndValidateQuery(connection, this.applyNamespaces(this.getNativeQuery(parameters), namespaces)), connection, 0);
        return new ToolResponse(output, namespaces);
    }

    String applyNamespaces(String query, Map<String, String> namespaces) {
        String prefixes = namespaces.entrySet().stream().filter(e -> ((String)e.getKey()).startsWith(METADATA_PREFIX)).map(e -> "PREFIX " + ((String)e.getKey()).substring(METADATA_PREFIX.length()) + ": <" + (String)e.getValue() + ">\n").map(p -> "# TTYG: Automatically added missing prefix\n" + p).collect(Collectors.joining());
        if (!prefixes.isEmpty()) {
            return prefixes + "\n" + query;
        }
        return query;
    }

    Map<String, String> addMissingNamespaces(RepositoryConnection connection, String query) {
        if (!this.addMissingNamespaces) {
            return Map.of();
        }
        int numTries = 16;
        LinkedHashMap<String, String> namespaces = new LinkedHashMap<String, String>();
        try {
            Set<String> undeclaredPrefixes = this.getUndeclaredQNamePrefixes(query);
            for (String prefix : undeclaredPrefixes) {
                String namespace = connection.getNamespace(prefix);
                if (namespace == null || namespaces.containsKey(METADATA_PREFIX + prefix)) continue;
                namespaces.put(METADATA_PREFIX + prefix, namespace);
            }
        }
        catch (Exception e) {
            log.error(e.getMessage(), (Throwable)e);
        }
        while (true) {
            try {
                String fixedQuery = this.applyNamespaces(query, namespaces);
                connection.prepareQuery(fixedQuery);
                return namespaces;
            }
            catch (MalformedQueryException e) {
                String prefix;
                String namespace;
                int indexOfQ;
                String msg;
                Throwable cause = e.getCause();
                if (cause instanceof VisitorException && (msg = cause.getMessage()) != null && msg.startsWith("QName '") && (indexOfQ = (msg = msg.substring("QName '".length())).indexOf(58)) != -1 && (namespace = connection.getNamespace(prefix = msg.substring(0, indexOfQ))) != null) {
                    if (namespaces.containsKey(METADATA_PREFIX + prefix)) continue;
                    namespaces.put(METADATA_PREFIX + prefix, namespace);
                    if (--numTries > 0) continue;
                }
                throw e;
            }
            break;
        }
    }

    private Set<String> getUndeclaredQNamePrefixes(String query) {
        HashSet<String> declaredPrefixes = new HashSet<String>();
        Pattern declaredPattern = Pattern.compile("PREFIX\\s+([a-zA-Z0-9_\\-]+):\\s+<[^>]+>", 34);
        Matcher declaredMatcher = declaredPattern.matcher(query);
        while (declaredMatcher.find()) {
            declaredPrefixes.add(declaredMatcher.group(1));
        }
        HashSet<String> usedPrefixes = new HashSet<String>();
        Pattern qnamePattern = Pattern.compile("([a-zA-Z0-9_\\-]+):([a-zA-Z0-9_\\-]+)", 2);
        Matcher qnameMatcher = qnamePattern.matcher(query);
        while (qnameMatcher.find()) {
            String prefix = qnameMatcher.group(1);
            if (prefix.equalsIgnoreCase("a")) continue;
            usedPrefixes.add(prefix);
        }
        usedPrefixes.removeAll(declaredPrefixes);
        return usedPrefixes;
    }

    public void applyOntologyInstructions(AgentInstructions agentInstructions, RepositoryConnection connection) {
        String ontology = this.extractOntology(connection);
        if (StringUtils.isNotBlank((CharSequence)ontology)) {
            if (this.isEnabled()) {
                agentInstructions.setSparqlInstructions(SystemConfig.getSparqlInstructions());
                if (ontology.length() > 256 || ontology.startsWith("@")) {
                    agentInstructions.setOntologyIntroduction(SystemConfig.getOntologyIntroduction());
                    agentInstructions.setOntology(ontology);
                } else {
                    agentInstructions.setOntologyIntroduction(SystemConfig.getOntologyIntroduction() + " " + ontology);
                    agentInstructions.setOntology("");
                }
            } else if (this.ontologyGraph.startsWith(SPARQLWellKnownOntology.DISABLED.iri)) {
                agentInstructions.setOntologyIntroduction("");
                agentInstructions.setOntology(ontology);
            }
        }
    }

    public String extractOntology(RepositoryConnection connection) {
        String ontology = this.extractWellKnownOntology();
        if (ontology == null) {
            ontology = this.extractOntologyFromRepository(connection);
        }
        if (ontology != null) {
            ontology = ontology.trim();
        }
        return ontology;
    }

    public String extractOntologyFromRepository(RepositoryConnection connection) {
        boolean isGraph = false;
        if (this.ontology == null) {
            int serializeMaxSize = 256000;
            if (StringUtils.isNotBlank((CharSequence)this.ontologyGraph)) {
                isGraph = true;
                this.ontology = RDFUtil.serializeRdf((RepositoryResult<Statement>)connection.getStatements(null, null, null, new Resource[]{connection.getValueFactory().createIRI(this.ontologyGraph)}), (Iterable<Namespace>)connection.getNamespaces(), false, 0, serializeMaxSize);
            } else {
                this.ontology = RDFUtil.serializeRdf((QueryResult<Statement>)connection.prepareGraphQuery(this.ontologyQuery).evaluate(), (Iterable<Namespace>)connection.getNamespaces(), false, 0, serializeMaxSize);
            }
        }
        if (this.isEnabled() && StringUtils.isBlank((CharSequence)this.ontology)) {
            throw new ToolConfigException(isGraph ? "There is no data in the ontology named graph." : "The ontology query did not return any data.");
        }
        return this.ontology;
    }

    String extractWellKnownOntology() {
        SPARQLWellKnownOntology wko = SPARQLWellKnownOntology.fromIri(this.ontologyGraph);
        if (wko != null) {
            if (wko == SPARQLWellKnownOntology.DISABLED) {
                this.setEnabled(false);
            }
            return wko.getInstruction(this.ontologyGraph);
        }
        return null;
    }

    @VisibleForTesting
    String fixSparqlQuery(String query) {
        return query;
    }

    private Query prepareAndValidateQuery(RepositoryConnection connection, String query) {
        Query parsedQuery = connection.prepareQuery(query);
        if (parsedQuery instanceof SailQuery) {
            TupleExpr tupleExpr = ((SailQuery)parsedQuery).getParsedQuery().getTupleExpr();
            final HashSet uniqueIRIs = new HashSet();
            final AtomicBoolean serviceFound = new AtomicBoolean(false);
            try {
                tupleExpr.visit((QueryModelVisitor)new AbstractQueryModelVisitor<RuntimeException>(this){

                    public void meet(Var varNode) {
                        Value val = varNode.getValue();
                        if (val instanceof IRI && !val.stringValue().startsWith("http://www.w3.org/2001/XMLSchema#") && !val.stringValue().startsWith("http://www.ontotext.com/")) {
                            uniqueIRIs.add((IRI)val);
                        }
                        super.meet(varNode);
                    }

                    public void meet(Service service) {
                        serviceFound.compareAndSet(false, true);
                        super.meet(service);
                    }
                });
                LinkedHashSet<String> incorrectIris = new LinkedHashSet<String>();
                for (IRI iri : uniqueIRIs) {
                    if (connection.hasStatement((Resource)iri, null, null, true, new Resource[0]) || connection.hasStatement(null, null, (Value)iri, true, new Resource[0]) || connection.hasStatement(null, iri, null, true, new Resource[0]) || connection.hasStatement(null, null, null, true, new Resource[]{iri})) continue;
                    incorrectIris.add(iri.toString());
                }
                if (!incorrectIris.isEmpty() && !serviceFound.get()) {
                    throw new IllegalArgumentException("The following IRIs are not used in the data stored in GraphDB: " + String.join((CharSequence)", ", incorrectIris));
                }
            }
            catch (MalformedQueryException e) {
                throw new IllegalArgumentException(e);
            }
        }
        return parsedQuery;
    }

    @Override
    public final boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof SPARQLQueryTool)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        SPARQLQueryTool tool = (SPARQLQueryTool)o;
        return Objects.equals(this.ontologyGraph, tool.ontologyGraph) && Objects.equals(this.ontologyQuery, tool.ontologyQuery);
    }

    @Override
    public int hashCode() {
        int result = super.hashCode();
        result = 31 * result + Objects.hashCode(this.ontologyGraph);
        result = 31 * result + Objects.hashCode(this.ontologyQuery);
        return result;
    }
}

