/*
 * Decompiled with CFR 0.152.
 */
package com.ontotext.sparql;

import com.google.common.collect.ImmutableSet;
import com.ontotext.graphql.parser.OperationPostProcessor;
import com.ontotext.models.Operation;
import com.ontotext.models.OperationType;
import com.ontotext.models.extensions.ConfigurationResolver;
import com.ontotext.models.query.Arguments;
import com.ontotext.rbac.SecurityContext;
import com.ontotext.soaas.common.http.HttpServletRequest;
import com.ontotext.soaas.common.http.HttpServletResponse;
import com.ontotext.soaas.common.http.RequestFilterExtension;
import com.ontotext.soaas.plugin.Inject;
import com.ontotext.sparql.SparqlEndpointRequestContext;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.http.HttpHeaders;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Copies HTTP request headers into {@link SparqlEndpointRequestContext} and uses them to fill in
 * missing arguments of GraphQL query/mutation operations.
 *
 * <p>As a request filter, {@link #doFilter} stores the values of the control headers
 * {@code X-GraphDB-SplitQueryLimit}, {@code X-GraphDB-SplitQueryFullResult},
 * {@code X-Expand-Over-Owl-SameAs}, {@code X-Include-Inferred} and {@code X-Repository}, plus any
 * configured pass-through headers, in the request context; {@link #doAfterFilter} clears it.
 *
 * <p>As an operation post-processor, {@link #postProcess} fills {@code includeInferred} and
 * {@code expandOverOwlSameAs} when the operation did not specify them, preferring the request
 * header value and falling back to the schema configuration.
 */
public class SparqlEndpointHeaderCopier
implements RequestFilterExtension,
OperationPostProcessor {
    private static final Logger LOGGER = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

    /**
     * Lower-cased names of the standard HTTP headers declared in {@link HttpHeaders} that must not
     * be passed through, except Authorization, Proxy-Authorization, Location and Host, which are
     * exempted. Stored lower-cased because HTTP header names are case-insensitive; query it only
     * through {@link #isForbidden(String)}.
     */
    private static final Set<String> FORBIDDEN_HEADERS;

    /** Configured (validated) pass-through header names; {@code null} until configured. */
    private Set<String> passthroughHeaders;

    /**
     * Optional prefix which, when a configured header name starts with it, is stripped to obtain
     * the header name that is forwarded; {@code null} when not configured.
     */
    private String prefix;

    /**
     * Stores the SPARQL control headers and the configured pass-through headers of
     * {@code request} in {@link SparqlEndpointRequestContext}.
     *
     * @param request  the incoming request whose headers are read
     * @param response the response (unused)
     */
    public void doFilter(HttpServletRequest request, HttpServletResponse response) {
        SparqlEndpointRequestContext.setSubQueryResultLimit(request.getHeader("X-GraphDB-SplitQueryLimit"));
        SparqlEndpointRequestContext.setSubQueryFullResultRequest(request.getHeader("X-GraphDB-SplitQueryFullResult"));
        SparqlEndpointRequestContext.setExpandOverSameAs(request.getHeader("X-Expand-Over-Owl-SameAs"));
        SparqlEndpointRequestContext.setIncludeInferred(request.getHeader("X-Include-Inferred"));
        SparqlEndpointRequestContext.setRepository(request.getHeader("X-Repository"));
        Set<String> configured = this.passthroughHeaders;
        if (configured == null || configured.isEmpty()) {
            return;
        }
        List<Pair<String, String>> headers = new ArrayList<>(configured.size());
        for (String header : configured) {
            this.copyHeader(request, header, headers);
        }
        // Only publish a list when at least one pass-through header actually had a value.
        if (!headers.isEmpty()) {
            SparqlEndpointRequestContext.setPassthroughHeaders(headers);
        }
    }

    /**
     * Appends {@code (forwardedName, value)} to {@code headers} when {@code request} carries a
     * non-blank value for {@code header}. The forwarded name is {@code header} with the configured
     * prefix stripped; empty or forbidden forwarded names are skipped.
     */
    private void copyHeader(HttpServletRequest request, String header, List<Pair<String, String>> headers) {
        String headerValue = request.getHeader(header);
        if (StringUtils.isBlank(headerValue)) {
            return;
        }
        String actualHeader = this.stripPrefix(header);
        // Re-checked here as a safety net; the configuration is already validated on injection.
        if (actualHeader.isEmpty() || isForbidden(actualHeader)) {
            return;
        }
        headers.add(Pair.of(actualHeader, headerValue));
    }

    /**
     * Clears all state stored in {@link SparqlEndpointRequestContext} by {@link #doFilter}.
     *
     * @param request  the request (unused)
     * @param response the response (unused)
     */
    public void doAfterFilter(HttpServletRequest request, HttpServletResponse response) {
        SparqlEndpointRequestContext.clearAll();
    }

    /**
     * Fills in {@code includeInferred} and {@code expandOverOwlSameAs} on the arguments of query
     * and mutation operations when they are not already set. The value comes from the request
     * context (populated from headers) or, if absent there, from the schema configuration; it may
     * still be {@code null} when neither provides one.
     *
     * @param operation       the operation to complete
     * @param securityContext the security context (unused)
     */
    public void postProcess(Operation operation, SecurityContext securityContext) {
        OperationType operationType = operation.getOperationType();
        if (operationType != OperationType.QUERY && !operationType.isMutation()) {
            return;
        }
        Arguments arguments = operation.getArguments();
        if (arguments == null) {
            return;
        }
        if (arguments.getIncludeInferred().isEmpty()) {
            Boolean headerValue = SparqlEndpointRequestContext.getIncludeInferred();
            if (headerValue == null) {
                headerValue = operation.getSchema().getConfig().getIncludeInferred().orElse(null);
            }
            arguments.setIncludeInferred(headerValue);
        }
        if (arguments.getExpandOverOwlSameAs().isEmpty()) {
            Boolean headerValue = SparqlEndpointRequestContext.getExpandOverSameAs();
            if (headerValue == null) {
                headerValue = operation.getSchema().getConfig().getExpandOverOwlSameAs().orElse(null);
            }
            arguments.setExpandOverOwlSameAs(headerValue);
        }
    }

    /**
     * Reads the pass-through header list ({@code sparql.endpoint.httpHeadersPassthrough}) and its
     * optional prefix ({@code sparql.endpoint.httpHeadersPassthrough.prefix}). Blank entries are
     * dropped. Entries whose forwarded name (after stripping the prefix) is empty or a forbidden
     * standard header are rejected with a warning. A {@code null} resolver leaves the current
     * configuration unchanged.
     *
     * @param configurationResolver source of configuration values; may be {@code null}
     */
    @Inject
    public void setConfigurationResolver(ConfigurationResolver configurationResolver) {
        if (configurationResolver == null) {
            return;
        }
        String[] config = (String[])configurationResolver.resolve("sparql.endpoint.httpHeadersPassthrough", String[].class, (Object)new String[0]);
        this.prefix = StringUtils.trimToNull((String)configurationResolver.resolve("sparql.endpoint.httpHeadersPassthrough.prefix", String.class, null));
        // Collect into an explicitly mutable, ordered set: Collectors.toSet() guarantees neither.
        Set<String> configured = config == null
                ? new LinkedHashSet<>()
                : Arrays.stream(config)
                        .map(StringUtils::trimToNull)
                        .filter(Objects::nonNull)
                        .collect(Collectors.toCollection(LinkedHashSet::new));
        Set<String> accepted = new LinkedHashSet<>(configured.size());
        for (String header : configured) {
            String actualHeader = this.stripPrefix(header);
            if (actualHeader.isEmpty() || isForbidden(actualHeader)) {
                LOGGER.warn("The header {} cannot be set as pass-through to a remote SPARQL endpoint. The header is removed from the configuration", header);
            } else {
                accepted.add(header);
            }
        }
        // Publish only the validated set.
        this.passthroughHeaders = accepted;
    }

    /** Returns {@code header} without the configured prefix, or unchanged if it lacks the prefix. */
    private String stripPrefix(String header) {
        if (this.prefix != null && header.startsWith(this.prefix)) {
            return header.substring(this.prefix.length());
        }
        return header;
    }

    /** Case-insensitively checks whether {@code header} is a forbidden standard HTTP header. */
    private static boolean isForbidden(String header) {
        return FORBIDDEN_HEADERS.contains(header.toLowerCase(Locale.ROOT));
    }

    static {
        LinkedHashSet<String> headers = new LinkedHashSet<>();
        for (Field field : HttpHeaders.class.getFields()) {
            // Field.get(null) is only valid for static fields.
            if (!Modifier.isStatic(field.getModifiers())) {
                continue;
            }
            try {
                Object value = field.get(null);
                if (value instanceof String) {
                    headers.add(((String) value).toLowerCase(Locale.ROOT));
                }
            }
            catch (IllegalAccessException ex) {
                throw new IllegalStateException(ex);
            }
        }
        // These standard headers are exempted from the forbidden set, i.e. they may be passed through.
        headers.remove("authorization");
        headers.remove("proxy-authorization");
        headers.remove("location");
        headers.remove("host");
        FORBIDDEN_HEADERS = ImmutableSet.copyOf(headers);
    }
}

