/*
 * Decompiled with CFR 0.152.
 */
package pitt.search.semanticvectors;

import com.ontotext.trree.sdk.PluginException;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import java.util.WeakHashMap;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.util.BytesRef;
import org.eclipse.rdf4j.query.QueryInterruptedException;
import pitt.search.semanticvectors.BlockingExecutor;
import pitt.search.semanticvectors.CloseableVectorStore;
import pitt.search.semanticvectors.FlagConfig;
import pitt.search.semanticvectors.LuceneUtils;
import pitt.search.semanticvectors.ObjectVector;
import pitt.search.semanticvectors.VectorStore;
import pitt.search.semanticvectors.VectorStoreReader;
import pitt.search.semanticvectors.VectorStoreUtils;
import pitt.search.semanticvectors.VectorStoreWriter;
import pitt.search.semanticvectors.collections.ModifiableVectorStore;
import pitt.search.semanticvectors.collections.VectorStoreFactory;
import pitt.search.semanticvectors.utils.VerbatimLogger;
import pitt.search.semanticvectors.vectors.PermutationUtils;
import pitt.search.semanticvectors.vectors.Vector;
import pitt.search.semanticvectors.vectors.VectorFactory;
import pitt.search.semanticvectors.vectors.VectorType;

/**
 * Builds PSI vector stores from a Lucene index whose documents each carry a
 * {@code subject}, {@code predicate}, {@code object} and {@code predication} field.
 *
 * <p>The entry point is {@link #createIncrementalPSIVectors(FlagConfig)}: it creates elemental
 * vectors for every item and predicate term, then superposes bound elemental vectors into the
 * semantic vectors of the items (and, when training cycles are configured, of the predicates),
 * and finally writes the resulting stores to the files named by the {@link FlagConfig}.
 *
 * <p>Predications are processed by a small worker pool. A JVM shutdown hook and the
 * user-controlled {@link AtomicBoolean} passed to the constructor can both stop the build early.
 */
public class PSI {
    private static final Logger logger = Logger.getLogger(PSI.class.getCanonicalName());
    private static final String SUBJECT_FIELD = "subject";
    private static final String PREDICATE_FIELD = "predicate";
    private static final String OBJECT_FIELD = "object";
    private static final String PREDICATION_FIELD = "predication";
    /** Suffix appended to a predicate name to address its inverse vector. */
    private static final String INVERSE_SUFFIX = "-INV";
    /** Number of striped locks used to serialise updates that touch the same item. */
    private static final int LOCK_STRIPES = 1024;

    private FlagConfig flagConfig;
    private VectorStore elementalPredicateVectors;
    private ModifiableVectorStore semanticItemVectors;
    private ModifiableVectorStore semanticPredicateVectors;
    private ModifiableVectorStore elementalItemVectors;
    private String[] itemFields = new String[]{SUBJECT_FIELD, OBJECT_FIELD};
    private LuceneUtils luceneUtils;
    private int[] predicatePermutation;
    private BlockingExecutor es;
    private Thread shutdownHook;
    /** Set by the JVM shutdown hook; read by worker tasks and by the training loop. */
    private volatile boolean interrupted = false;
    private AtomicBoolean isCreationInterruptedByUser;
    /** Queue size handed to the executor; overridable through a system property. */
    private final int BLOCKING_QUEUE_SIZE = Integer.parseInt(System.getProperty("graphdb.predication.max.generated.tasks", "20"));

    /**
     * Creates a builder that can only be interrupted by JVM shutdown.
     *
     * @param flagConfig configuration supplying vector type and dimension
     */
    public PSI(FlagConfig flagConfig) {
        this(flagConfig, new AtomicBoolean(false));
    }

    /**
     * Creates a builder that additionally stops when the given flag becomes {@code true}.
     *
     * @param flagConfig configuration supplying vector type and dimension
     * @param isCreationInterruptedByUser flag polled while predications are being submitted and processed
     */
    public PSI(FlagConfig flagConfig, AtomicBoolean isCreationInterruptedByUser) {
        this.flagConfig = flagConfig;
        this.predicatePermutation = PermutationUtils.getShiftPermutation(flagConfig.vectortype(), flagConfig.dimension(), 1);
        this.isCreationInterruptedByUser = isCreationInterruptedByUser;
    }

    /**
     * Initialises the vector stores, runs the first training round plus one further round per
     * configured training cycle, writes the results and releases all resources.
     *
     * @param flagConfig configuration used for this build (replaces the one given to the constructor)
     * @return {@code true} when the build finished without the shutdown hook having fired,
     *         {@code false} otherwise
     * @throws IOException if reading the index or writing the vector files fails
     */
    public boolean createIncrementalPSIVectors(FlagConfig flagConfig) throws IOException {
        this.flagConfig = flagConfig;
        this.initialize();
        VectorStoreWriter.writeVectors(flagConfig.elementalvectorfile(), flagConfig, this.elementalItemVectors);
        VectorStoreWriter.writeVectors(flagConfig.elementalpredicatevectorfile(), flagConfig, this.elementalPredicateVectors);
        VerbatimLogger.info("Performing first round of PSI training ...");
        this.trainIncrementalPSIVectors("");
        int trainingCycles = flagConfig.trainingcycles();
        for (int cycle = 0; cycle < trainingCycles; ++cycle) {
            VerbatimLogger.info("Performing next round of PSI training ...");
            // The semantic vectors of the previous round act as elemental vectors for the next one.
            this.elementalItemVectors = this.semanticItemVectors;
            this.elementalPredicateVectors = this.semanticPredicateVectors;
            // NOTE(review): every additional round is tagged with the configured cycle count,
            // not the round index - presumably what the rename helpers below expect; confirm.
            this.trainIncrementalPSIVectors(String.valueOf(trainingCycles));
        }
        if (trainingCycles > 0) {
            VectorStoreUtils.renameTrainedVectorsFile(flagConfig.semanticvectorfile(), flagConfig);
            VectorStoreUtils.renameTrainedVectorsFile(flagConfig.semanticpredicatevectorfile(), flagConfig);
            VectorStoreUtils.renameEntityMapVectorsFile(flagConfig.semanticvectorfile(), flagConfig);
            VectorStoreUtils.renameEntityMapVectorsFile(flagConfig.semanticpredicatevectorfile(), flagConfig);
        }
        this.closeVectorStores();
        this.luceneUtils.closeLuceneDir();
        if (this.interrupted) {
            // The JVM is shutting down; the hook cannot (and need not) be removed any more.
            return false;
        }
        logger.info("Done with createIncrementalPSIVectors.");
        Runtime.getRuntime().removeShutdownHook(this.shutdownHook);
        return true;
    }

    /**
     * Registers the shutdown hook, opens the Lucene index if necessary, creates the four vector
     * stores and fills them with elemental/zero vectors for every item and predicate term.
     *
     * @throws IOException if the index or the optional input vector store cannot be read
     */
    private void initialize() throws IOException {
        this.registerShutdownHook();
        Random random = new Random();
        if (this.luceneUtils == null) {
            this.luceneUtils = new LuceneUtils(this.flagConfig);
        }
        this.elementalItemVectors = VectorStoreFactory.getVectorStore(this.flagConfig);
        this.semanticItemVectors = VectorStoreFactory.getVectorStore(this.flagConfig);
        this.elementalPredicateVectors = VectorStoreFactory.getElementalVectorStore(this.flagConfig);
        this.semanticPredicateVectors = VectorStoreFactory.getVectorStore(this.flagConfig);
        ModifiableVectorStore inputStore = this.loadInputStore();
        try {
            this.flagConfig.setContentsfields(this.itemFields);
            this.initializeItemVectors(inputStore, random);
            this.initializePredicateVectors();
        }
        finally {
            // Close the temporary input store on failure as well as on success.
            if (inputStore != null) {
                inputStore.close();
            }
        }
    }

    /**
     * Loads the vectors of the index named by {@code input_index} into a fresh store.
     *
     * @return the populated store, or {@code null} when no input index is configured
     * @throws IOException if the docvector file cannot be read
     * @throws PluginException if the input index or a unique, supported docvector file is missing,
     *         or if the stored vectors are not REAL vectors
     */
    private ModifiableVectorStore loadInputStore() throws IOException {
        String inputIndexName = this.flagConfig.input_index();
        if (inputIndexName.isEmpty()) {
            return null;
        }
        File inputDir = new File(new File(this.flagConfig.luceneindexpath()).getParentFile().getParentFile(), inputIndexName);
        if (!inputDir.exists()) {
            throw new PluginException("Specified input index does not exist: " + inputIndexName);
        }
        File[] docvectors = inputDir.listFiles(pathname -> pathname.getName().startsWith("docvectors"));
        // listFiles returns null when the path is not a directory or cannot be listed.
        if (docvectors == null || docvectors.length == 0) {
            throw new PluginException("Could not find a docvector file in the specified path: " + inputDir.getAbsolutePath());
        }
        if (docvectors.length > 1) {
            throw new PluginException("Could not determine which docvector file to use for building the index because multiple docvector files exist" + Arrays.asList(docvectors).toString());
        }
        File docvector = docvectors[0];
        VectorStoreUtils.VectorStoreFormat format;
        if (docvector.getName().endsWith(".bin")) {
            format = VectorStoreUtils.VectorStoreFormat.LUCENE;
        } else if (docvector.getName().endsWith(".text")) {
            format = VectorStoreUtils.VectorStoreFormat.TEXT;
        } else {
            throw new PluginException("Unknown type of vectorstore" + docvector.getName());
        }
        FlagConfig config = FlagConfig.getFlagConfig(new String[]{"-indexfileformat", format.toString()});
        CloseableVectorStore tmpStore = null;
        try {
            tmpStore = VectorStoreReader.openVectorStore(docvector.getAbsolutePath(), config);
            Enumeration<ObjectVector> allVectors = tmpStore.getAllVectors();
            if (allVectors.hasMoreElements()) {
                // The first vector determines the dimension of the in-memory copy.
                Vector sample = allVectors.nextElement().getVector();
                if (sample.getVectorType() != VectorType.REAL) {
                    throw new PluginException("Please build the literal index with REAL vectors!");
                }
                config.setDimension(sample.getDimension());
            }
            ModifiableVectorStore inputStore = VectorStoreFactory.getVectorStore(config);
            boolean loaded = false;
            try {
                allVectors = tmpStore.getAllVectors();
                while (allVectors.hasMoreElements()) {
                    ObjectVector vector = allVectors.nextElement();
                    inputStore.putVector(vector.getObject(), vector.getVector());
                }
                loaded = true;
                return inputStore;
            }
            finally {
                // Do not leak the half-filled copy when reading fails midway.
                if (!loaded) {
                    inputStore.close();
                }
            }
        }
        finally {
            VectorStoreUtils.closeVectorStores(new CloseableVectorStore[]{tmpStore});
        }
    }

    /**
     * Adds a zero semantic vector and an elemental vector for every distinct, unfiltered term of
     * the item fields. The elemental vector is taken from {@code inputStore} when it has one for
     * the term, otherwise a random vector is generated.
     *
     * @param inputStore optional source of pre-built elemental vectors, may be {@code null}
     * @param random generator used for random elemental vectors
     * @throws IOException if the index terms cannot be read
     * @throws NullPointerException if an item field has no terms in the index
     */
    private void initializeItemVectors(ModifiableVectorStore inputStore, Random random) throws IOException {
        HashSet<String> addedConcepts = new HashSet<String>();
        int initializedCount = 0;
        for (String fieldName : this.itemFields) {
            Terms terms = this.luceneUtils.getTermsForField(fieldName);
            if (terms == null) {
                throw new NullPointerException(String.format("No terms for field '%s'. Please check that index at '%s' was built correctly for use with PSI.", fieldName, this.flagConfig.luceneindexpath()));
            }
            TermsEnum termsEnum = terms.iterator();
            BytesRef bytes;
            while ((bytes = termsEnum.next()) != null) {
                Term term = new Term(fieldName, bytes);
                if (!this.luceneUtils.termFilter(term)) {
                    VerbatimLogger.fine("Filtering out term: " + term + "\n");
                    continue;
                }
                String concept = term.text();
                // The same concept may occur as subject and as object; initialise it once.
                if (!addedConcepts.add(concept)) {
                    continue;
                }
                this.semanticItemVectors.putVector(concept, VectorFactory.createZeroVector(this.flagConfig.vectortype(), this.flagConfig.dimension()));
                Vector elementalVector = inputStore != null ? inputStore.getVector(concept) : null;
                if (elementalVector == null) {
                    elementalVector = VectorFactory.generateRandomVector(this.flagConfig.vectortype(), this.flagConfig.dimension(), this.flagConfig.seedlength(), random);
                }
                this.elementalItemVectors.putVector(concept, elementalVector);
                ++initializedCount;
                // Report every 1000 vectors up to 10000, then every 10000.
                boolean milestone = initializedCount % 10000 == 0 || initializedCount < 10000 && initializedCount % 1000 == 0;
                if (initializedCount > 0 && milestone) {
                    VerbatimLogger.info("Initialized " + initializedCount + " term vectors ... ");
                }
            }
        }
    }

    /**
     * Requests an elemental vector for every unfiltered predicate term and its inverse and, when
     * training cycles are configured, adds matching zero vectors to the semantic predicate store.
     *
     * @throws IOException if the index terms cannot be read
     * @throws NullPointerException if the predicate field has no terms in the index
     */
    private void initializePredicateVectors() throws IOException {
        Terms predicateTerms = this.luceneUtils.getTermsForField(PREDICATE_FIELD);
        if (predicateTerms == null) {
            throw new NullPointerException(String.format("No terms for field '%s'. Please check that index at '%s' was built correctly for use with PSI.", PREDICATE_FIELD, this.flagConfig.luceneindexpath()));
        }
        String[] predicateFieldOnly = new String[]{PREDICATE_FIELD};
        boolean trainPredicates = this.flagConfig.trainingcycles() > 0;
        TermsEnum termsEnum = predicateTerms.iterator();
        BytesRef bytes;
        while ((bytes = termsEnum.next()) != null) {
            Term term = new Term(PREDICATE_FIELD, bytes);
            if (!this.luceneUtils.termFilter(term, predicateFieldOnly, 0, Integer.MAX_VALUE, Integer.MAX_VALUE, 1)) {
                continue;
            }
            String predicate = term.text().trim();
            String inversePredicate = predicate + INVERSE_SUFFIX;
            // The results are discarded: presumably the elemental store creates and caches a
            // vector on first lookup - TODO confirm against VectorStoreFactory.
            this.elementalPredicateVectors.getVector(predicate);
            this.elementalPredicateVectors.getVector(inversePredicate);
            if (trainPredicates) {
                this.semanticPredicateVectors.putVector(predicate, VectorFactory.createZeroVector(this.flagConfig.vectortype(), this.flagConfig.dimension()));
                this.semanticPredicateVectors.putVector(inversePredicate, VectorFactory.createZeroVector(this.flagConfig.vectortype(), this.flagConfig.dimension()));
            }
        }
    }

    /**
     * Runs one training round: submits one task per predication term to the worker pool, then
     * normalises and writes the semantic stores unless the shutdown hook has fired.
     *
     * @param iterationTag suffix appended to the semantic vector file names for this round
     * @throws IOException if the index cannot be read or the vector files cannot be written
     * @throws NullPointerException if the predication field has no terms in the index
     * @throws QueryInterruptedException if the user-controlled interruption flag is set
     */
    private void trainIncrementalPSIVectors(String iterationTag) throws IOException {
        Terms allTerms = this.luceneUtils.getTermsForField(PREDICATION_FIELD);
        if (allTerms == null) {
            throw new NullPointerException(String.format("No terms for field '%s'. Please check that index at '%s' was built correctly for use with PSI.", PREDICATION_FIELD, this.flagConfig.luceneindexpath()));
        }
        TermsEnum termsEnum = allTerms.iterator();
        AtomicInteger processedCount = new AtomicInteger(0);
        // getDocCount may report a negative value; a negative capacity would be rejected.
        int cacheCapacity = Math.max(0, (int)(0.75 * (double)allTerms.getDocCount()));
        this.es = new BlockingExecutor(this.BLOCKING_QUEUE_SIZE, 2, 2, 0L, TimeUnit.MILLISECONDS);
        // Fixed pool of locks selected by item hash. Unlike locks kept in a weak-keyed map, a
        // stripe can never be collected and re-created while another task still holds it.
        Lock[] itemLocks = new Lock[LOCK_STRIPES];
        for (int i = 0; i < itemLocks.length; ++i) {
            itemLocks[i] = new ReentrantLock();
        }
        Map<String, Vector> bindCache = Collections.synchronizedMap(new WeakHashMap<String, Vector>(cacheCapacity));
        Map<String, Vector> invBindCache = Collections.synchronizedMap(new WeakHashMap<String, Vector>(cacheCapacity));
        BytesRef bytes;
        while ((bytes = termsEnum.next()) != null) {
            // Copy the bytes: the term is used later on a worker thread, after the enum advanced.
            Term term = new Term(PREDICATION_FIELD, BytesRef.deepCopyOf(bytes));
            PostingsEnum termDocs = this.luceneUtils.getDocsForTerm(term);
            if (termDocs == null || termDocs.nextDoc() == PostingsEnum.NO_MORE_DOCS) {
                logger.fine("skipping predication term without documents: " + term);
            } else {
                Document document = this.luceneUtils.getDoc(termDocs.docID());
                String subject = document.get(SUBJECT_FIELD);
                String predicate = document.get(PREDICATE_FIELD);
                String object = document.get(OBJECT_FIELD);
                boolean known = subject != null && predicate != null && object != null
                        && this.elementalItemVectors.containsVector(object)
                        && this.elementalItemVectors.containsVector(subject)
                        && this.elementalPredicateVectors.containsVector(predicate);
                if (!known) {
                    logger.fine("skipping predication " + subject + " " + predicate + " " + object);
                } else {
                    try {
                        this.es.execute(() -> this.processPredication(term, subject, predicate, object, itemLocks, bindCache, invBindCache, processedCount));
                    }
                    catch (RejectedExecutionException e) {
                        // Expected once the executor is shutting down; the predication is dropped.
                        logger.fine("Predication task rejected: " + subject + " " + predicate + " " + object);
                    }
                }
            }
            if (this.isCreationInterruptedByUser.get()) {
                this.shutdown();
                throw new QueryInterruptedException("Transaction was aborted by the user");
            }
        }
        // NOTE(review): assumes BlockingExecutor.shutdown() returns only after the submitted tasks
        // have completed; otherwise the vectors below are read while workers still update them - confirm.
        this.es.shutdown();
        if (!this.interrupted) {
            Enumeration<ObjectVector> e = this.semanticItemVectors.getAllVectors();
            while (e.hasMoreElements()) {
                e.nextElement().getVector().normalize();
            }
            e = this.semanticPredicateVectors.getAllVectors();
            while (e.hasMoreElements()) {
                e.nextElement().getVector().normalize();
            }
            VectorStoreWriter.writeVectors(this.flagConfig.semanticvectorfile() + iterationTag, this.flagConfig, this.semanticItemVectors);
            if (this.flagConfig.trainingcycles() > 0) {
                VectorStoreWriter.writeVectors(this.flagConfig.semanticpredicatevectorfile() + iterationTag, this.flagConfig, this.semanticPredicateVectors);
            }
            VerbatimLogger.info("Finished writing this round of semantic item and predicate vectors.\n");
        }
    }

    /**
     * Worker task for a single predication: superposes the bound object vector into the subject's
     * semantic vector and vice versa and, when training cycles are configured, updates the
     * semantic vectors of the predicate and its inverse. Failures are logged, not propagated.
     *
     * @param term predication term, used for its global frequency
     * @param subject subject item name
     * @param predicate predicate name
     * @param object object item name
     * @param itemLocks striped locks guarding item vectors
     * @param bindCache cache of object vectors bound with the predicate vector
     * @param invBindCache cache of subject vectors bound with the inverse predicate vector
     * @param processedCount shared counter used for progress logging
     */
    private void processPredication(Term term, String subject, String predicate, String object, Lock[] itemLocks,
            Map<String, Vector> bindCache, Map<String, Vector> invBindCache, AtomicInteger processedCount) {
        if (this.interrupted || this.isCreationInterruptedByUser.get()) {
            return;
        }
        Thread.currentThread().setName("psi-index-builder");
        int subjectStripe = this.stripeIndex(subject, itemLocks.length);
        int objectStripe = this.stripeIndex(object, itemLocks.length);
        // Always take the lower stripe first so that two tasks can never wait on each other.
        // Both may be the same reentrant lock; it is then simply acquired and released twice.
        Lock firstLock = itemLocks[Math.min(subjectStripe, objectStripe)];
        Lock secondLock = itemLocks[Math.max(subjectStripe, objectStripe)];
        firstLock.lock();
        try {
            secondLock.lock();
            try {
                float predWeight = this.luceneUtils.getGlobalTermWeight(new Term(PREDICATE_FIELD, predicate));
                float sWeight = this.luceneUtils.getGlobalTermWeight(new Term(SUBJECT_FIELD, subject));
                float oWeight = this.luceneUtils.getGlobalTermWeight(new Term(OBJECT_FIELD, object));
                float pWeight = this.luceneUtils.getLocalTermWeight(this.luceneUtils.getGlobalTermFreq(term));
                if (this.flagConfig.termweight().equals(LuceneUtils.TermWeight.SQRT)) {
                    predWeight = 0.0f;
                }
                Vector subjectSemanticVector = this.semanticItemVectors.getVector(subject);
                Vector objectSemanticVector = this.semanticItemVectors.getVector(object);
                Vector subjectElementalVector = this.elementalItemVectors.getVector(subject);
                Vector objectElementalVector = this.elementalItemVectors.getVector(object);
                Vector predicateElementalVector = this.elementalPredicateVectors.getVector(predicate);
                Vector predicateElementalVectorInv = this.elementalPredicateVectors.getVector(predicate + INVERSE_SUFFIX);

                Vector objToAdd = bindCache.computeIfAbsent(this.generateKey(predicate, object), key -> {
                    Vector bound = objectElementalVector.copy();
                    bound.bind(predicateElementalVector);
                    return bound;
                });
                subjectSemanticVector.superpose(objToAdd, pWeight * (oWeight + predWeight), null);
                this.semanticItemVectors.updateVector(subject, subjectSemanticVector);

                Vector subjToAdd = invBindCache.computeIfAbsent(this.generateKey(predicate, subject), key -> {
                    Vector bound = subjectElementalVector.copy();
                    bound.bind(predicateElementalVectorInv);
                    return bound;
                });
                objectSemanticVector.superpose(subjToAdd, pWeight * (sWeight + predWeight), null);
                this.semanticItemVectors.updateVector(object, objectSemanticVector);

                if (this.flagConfig.trainingcycles() > 0) {
                    // NOTE(review): predicate vectors are guarded only by the item locks, so two
                    // predications sharing a predicate but no item may update them concurrently -
                    // confirm that the store tolerates this.
                    Vector predicateSemanticVector = this.semanticPredicateVectors.getVector(predicate);
                    Vector predicateSemanticVectorInv = this.semanticPredicateVectors.getVector(predicate + INVERSE_SUFFIX);
                    Vector permutedSubjectElementalVector = VectorFactory.createZeroVector(this.flagConfig.vectortype(), this.flagConfig.dimension());
                    Vector permutedObjectElementalVector = VectorFactory.createZeroVector(this.flagConfig.vectortype(), this.flagConfig.dimension());
                    permutedSubjectElementalVector.superpose(subjectElementalVector, 1.0, this.predicatePermutation);
                    permutedObjectElementalVector.superpose(objectElementalVector, 1.0, this.predicatePermutation);
                    permutedSubjectElementalVector.normalize();
                    permutedObjectElementalVector.normalize();

                    Vector predToAdd = subjectElementalVector.copy();
                    predToAdd.bind(permutedObjectElementalVector);
                    predicateSemanticVector.superpose(predToAdd, sWeight * oWeight, null);
                    this.semanticPredicateVectors.updateVector(predicate, predicateSemanticVector);

                    Vector predToAddInv = objectElementalVector.copy();
                    predToAddInv.bind(permutedSubjectElementalVector);
                    predicateSemanticVectorInv.superpose(predToAddInv, oWeight * sWeight, null);
                    this.semanticPredicateVectors.updateVector(predicate + INVERSE_SUFFIX, predicateSemanticVectorInv);
                }
            }
            finally {
                secondLock.unlock();
            }
        }
        catch (Throwable e) {
            // A failing predication must not kill the worker; keep the stack trace for diagnosis.
            logger.log(Level.WARNING, "Failed to process predication " + subject + " " + predicate + " " + object, e);
        }
        finally {
            firstLock.unlock();
        }
        int currCnt = processedCount.incrementAndGet();
        if (currCnt % 100000 == 0) {
            logger.info("Processed " + currCnt + " unique predications ...");
        }
    }

    /**
     * Maps an item name to a lock stripe.
     *
     * @param key item name, must not be {@code null}
     * @param stripeCount number of stripes, must be positive
     * @return an index in {@code [0, stripeCount)}
     */
    private int stripeIndex(String key, int stripeCount) {
        // Mask the sign bit so that negative hash codes still yield a valid index.
        return (key.hashCode() & Integer.MAX_VALUE) % stripeCount;
    }

    /**
     * Aborts a running build: closes the vector stores, stops the worker pool, closes the Lucene
     * directory and unregisters the shutdown hook. Does nothing when no hook was registered.
     *
     * @throws PluginException if stopping the worker pool is interrupted
     */
    protected void shutdown() {
        logger.info("Shutting down PSI");
        if (this.shutdownHook == null) {
            return;
        }
        try {
            this.closeVectorStores();
            if (this.es != null) {
                this.es.shutdownNow();
            }
        }
        catch (IllegalStateException ignored) {
            // Best effort: the remaining cleanup in the finally block still runs.
        }
        catch (InterruptedException e) {
            // Restore the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new PluginException("Couldn't terminate process");
        }
        finally {
            if (this.luceneUtils != null) {
                this.luceneUtils.closeLuceneDir();
            }
            try {
                Runtime.getRuntime().removeShutdownHook(this.shutdownHook);
            }
            catch (IllegalStateException ignored) {
                // The JVM is already shutting down; hooks can no longer be removed.
            }
        }
    }

    /**
     * Registers a JVM shutdown hook that marks the build as interrupted, closes the vector stores
     * and stops the worker pool.
     */
    protected void registerShutdownHook() {
        this.shutdownHook = new Thread(() -> {
            logger.info("Interrupting building index");
            this.interrupted = true;
            try {
                this.closeVectorStores();
                // The pool does not exist yet when the JVM stops during initialisation.
                if (this.es != null) {
                    this.es.shutdownNow();
                }
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new PluginException("Couldn't terminate process");
            }
        });
        Runtime.getRuntime().addShutdownHook(this.shutdownHook);
    }

    /**
     * Builds the cache key for a bound vector.
     *
     * @param first first key component
     * @param second second key component
     * @return {@code first + "_" + second}
     */
    private String generateKey(String first, String second) {
        return first + "_" + second;
    }

    /**
     * Closes the item stores and the semantic predicate store, skipping stores that are not set.
     * After an additional training round the elemental and semantic item stores are the same
     * instance; it is closed only once.
     */
    private void closeVectorStores() {
        if (this.elementalItemVectors != null) {
            this.elementalItemVectors.close();
        }
        if (this.semanticItemVectors != null && this.semanticItemVectors != this.elementalItemVectors) {
            this.semanticItemVectors.close();
        }
        if (this.semanticPredicateVectors != null) {
            this.semanticPredicateVectors.close();
        }
    }

    /**
     * Command-line entry point: parses the flags and builds the PSI vectors.
     *
     * @param args command-line flags understood by {@link FlagConfig}
     * @throws IllegalArgumentException if {@code -luceneindexpath} is not provided
     * @throws IOException if the index cannot be read or the vector files cannot be written
     */
    public static void main(String[] args) throws IllegalArgumentException, IOException {
        FlagConfig flagConfig = FlagConfig.getFlagConfig(args);
        if (flagConfig.luceneindexpath().isEmpty()) {
            throw new IllegalArgumentException("-luceneindexpath argument must be provided.");
        }
        VerbatimLogger.info("Building PSI model from index in: " + flagConfig.luceneindexpath() + "\n");
        VerbatimLogger.info("Minimum frequency = " + flagConfig.minfrequency() + "\n");
        VerbatimLogger.info("Maximum frequency = " + flagConfig.maxfrequency() + "\n");
        VerbatimLogger.info("Number non-alphabet characters = " + flagConfig.maxnonalphabetchars() + "\n");
        new PSI(flagConfig).createIncrementalPSIVectors(flagConfig);
    }
}

