/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.sparse.query;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import lombok.Generated;
import lombok.NonNull;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.FilteredDocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.opensearch.neuralsearch.sparse.cache.ForwardIndexCache;
import org.opensearch.neuralsearch.sparse.data.SparseVector;
import org.opensearch.neuralsearch.sparse.query.SparseQueryContext;
import org.opensearch.neuralsearch.sparse.query.SparseQueryWeight;

/**
 * Lucene query that scores documents of a sparse-vector field against {@link #queryVector}.
 *
 * <p>When a {@link #filter} is supplied, {@link #rewrite(IndexSearcher)} evaluates it eagerly.
 * The filter is combined with a {@link FieldExistsQuery} on {@link #fieldName}. The result is one
 * {@link BitSet} of matching live documents per segment, stored in {@link #filterResults} and
 * keyed by {@link LeafReaderContext#id()}. Scoring itself is delegated to {@link SparseQueryWeight}.
 *
 * <p>NOTE(review): {@code rewrite} mutates this instance ({@link #filterResults}) and returns
 * {@code this}. The same instance should therefore not be rewritten concurrently against
 * different searchers.
 */
public class SparseVectorQuery
extends Query {
    /** The query-side sparse vector to score against. */
    @NonNull
    private final SparseVector queryVector;
    /** Query-time parameters passed through to the weight. */
    @NonNull
    private final SparseQueryContext queryContext;
    /** Name of the sparse-vector field being searched. */
    @NonNull
    private final String fieldName;
    /** Alternative query exposed to the weight; how it is used is decided by {@link SparseQueryWeight}. */
    @NonNull
    private final Query fallbackQuery;
    /** Optional restricting filter; {@code null} means no filtering. */
    private final Query filter;
    /**
     * Per-segment filter matches computed by {@link #rewrite(IndexSearcher)}, keyed by
     * {@link LeafReaderContext#id()}. Segments where the filter matched nothing have no entry.
     * This is {@code null} until a rewrite with a non-null filter has run, unless it was supplied
     * through the constructor or builder.
     */
    private Map<Object, BitSet> filterResults;

    /**
     * Returns the supplied field string unchanged. This does not render the query itself.
     */
    @Override
    public String toString(String field) {
        return field;
    }

    /** Reports this query as a leaf to visitors that accept {@link #getFieldName()}. */
    @Override
    public void visit(QueryVisitor visitor) {
        if (visitor.acceptField(this.getFieldName())) {
            visitor.visitLeaf(this);
        }
    }

    /**
     * Compares the query vector, context, field name, fallback query and (nullable) filter.
     * {@link #filterResults} is derived state and is intentionally excluded.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!sameClassAs(obj)) {
            return false;
        }
        SparseVectorQuery other = (SparseVectorQuery) obj;
        return this.queryContext.equals(other.queryContext)
            && this.fieldName.equals(other.fieldName)
            && this.fallbackQuery.equals(other.getFallbackQuery())
            && Objects.equals(this.filter, other.getFilter())
            && this.queryVector.equals(other.getQueryVector());
    }

    /** Hashes the same components that {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        return Objects.hash(this.queryVector, this.queryContext, this.fieldName, this.fallbackQuery, this.filter);
    }

    /**
     * Pre-computes the filter matches for every segment of the searcher's reader.
     *
     * <p>If no filter is set this is a no-op. Otherwise one task per leaf is submitted to the
     * searcher's {@link TaskExecutor}, and the resulting bitsets are stored in
     * {@link #filterResults}. The query always returns itself, so {@link IndexSearcher#rewrite}
     * stops after one iteration.
     *
     * @param indexSearcher searcher whose reader and executor are used
     * @return {@code this}
     * @throws IOException if evaluating the filter fails
     */
    @Override
    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        Weight filterWeight = this.createFilterWeight(indexSearcher);
        if (filterWeight == null) {
            return this;
        }
        IndexReader reader = indexSearcher.getIndexReader();
        List<LeafReaderContext> leaves = reader.leaves();
        List<Callable<Map.Entry<Object, BitSet>>> tasks = new ArrayList<>(leaves.size());
        for (LeafReaderContext context : leaves) {
            tasks.add(() -> this.runFilter(context, filterWeight));
        }
        TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
        Map<Object, BitSet> results = new HashMap<>();
        for (Map.Entry<Object, BitSet> entry : taskExecutor.invokeAll(tasks)) {
            // runFilter returns null for segments with no filter matches
            if (entry != null) {
                results.put(entry.getKey(), entry.getValue());
            }
        }
        // Assign only once the map is fully populated, so no reader sees a partial result.
        this.filterResults = results;
        return this;
    }

    /**
     * Evaluates the filter weight on one segment.
     *
     * @return an entry of ({@link LeafReaderContext#id()}, matching live docs), or {@code null}
     *     when the weight produced no scorer for this segment
     */
    private Map.Entry<Object, BitSet> runFilter(LeafReaderContext ctx, Weight filterWeight) throws IOException {
        LeafReader reader = ctx.reader();
        Scorer scorer = filterWeight.scorer(ctx);
        if (scorer == null) {
            return null;
        }
        return Map.entry(ctx.id(), this.createBitSet(scorer.iterator(), reader.getLiveDocs(), reader.maxDoc()));
    }

    /**
     * Materialises {@code iterator} into a {@link BitSet}, dropping deleted documents.
     *
     * <p>When there are no deletions and the iterator is already backed by a bitset, that bitset
     * is returned directly (shared, not copied).
     *
     * @param iterator doc-id iterator to materialise
     * @param liveDocs live-docs bits, or {@code null} when the segment has no deletions
     * @param maxDoc size of the resulting bitset
     */
    @VisibleForTesting
    BitSet createBitSet(DocIdSetIterator iterator, final Bits liveDocs, int maxDoc) throws IOException {
        if (liveDocs == null) {
            if (iterator instanceof BitSetIterator) {
                return ((BitSetIterator) iterator).getBitSet();
            }
            return BitSet.of(iterator, maxDoc);
        }
        DocIdSetIterator liveOnly = new FilteredDocIdSetIterator(iterator) {
            @Override
            protected boolean match(int doc) {
                return liveDocs.get(doc);
            }
        };
        return BitSet.of(liveOnly, maxDoc);
    }

    /**
     * Builds a non-scoring weight for {@code filter AND exists(fieldName)}.
     *
     * @return the weight, or {@code null} when no filter is configured
     */
    private Weight createFilterWeight(IndexSearcher indexSearcher) throws IOException {
        if (this.filter == null) {
            return null;
        }
        BooleanQuery booleanQuery = new BooleanQuery.Builder()
            .add(this.filter, BooleanClause.Occur.FILTER)
            .add(new FieldExistsQuery(this.fieldName), BooleanClause.Occur.FILTER)
            .build();
        Query rewritten = indexSearcher.rewrite(booleanQuery);
        return indexSearcher.createWeight(rewritten, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
    }

    /** Creates the scoring weight, backed by the shared {@link ForwardIndexCache} instance. */
    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        return new SparseQueryWeight(this, searcher, scoreMode, boost, ForwardIndexCache.getInstance());
    }

    @Generated
    public static SparseVectorQueryBuilder builder() {
        return new SparseVectorQueryBuilder();
    }

    @NonNull
    @Generated
    public SparseVector getQueryVector() {
        return this.queryVector;
    }

    @NonNull
    @Generated
    public SparseQueryContext getQueryContext() {
        return this.queryContext;
    }

    @NonNull
    @Generated
    public String getFieldName() {
        return this.fieldName;
    }

    @NonNull
    @Generated
    public Query getFallbackQuery() {
        return this.fallbackQuery;
    }

    @Generated
    public Query getFilter() {
        return this.filter;
    }

    @Generated
    public Map<Object, BitSet> getFilterResults() {
        return this.filterResults;
    }

    /**
     * All-arguments constructor.
     *
     * @throws NullPointerException if any of the {@code @NonNull} arguments is {@code null}
     */
    @Generated
    public SparseVectorQuery(@NonNull SparseVector queryVector, @NonNull SparseQueryContext queryContext, @NonNull String fieldName, @NonNull Query fallbackQuery, Query filter, Map<Object, BitSet> filterResults) {
        Objects.requireNonNull(queryVector, "queryVector is marked non-null but is null");
        Objects.requireNonNull(queryContext, "queryContext is marked non-null but is null");
        Objects.requireNonNull(fieldName, "fieldName is marked non-null but is null");
        Objects.requireNonNull(fallbackQuery, "fallbackQuery is marked non-null but is null");
        this.queryVector = queryVector;
        this.queryContext = queryContext;
        this.fieldName = fieldName;
        this.fallbackQuery = fallbackQuery;
        this.filter = filter;
        this.filterResults = filterResults;
    }

    /** Fluent builder for {@link SparseVectorQuery}; null checks mirror the constructor. */
    @Generated
    public static class SparseVectorQueryBuilder {
        @Generated
        private SparseVector queryVector;
        @Generated
        private SparseQueryContext queryContext;
        @Generated
        private String fieldName;
        @Generated
        private Query fallbackQuery;
        @Generated
        private Query filter;
        @Generated
        private Map<Object, BitSet> filterResults;

        @Generated
        SparseVectorQueryBuilder() {
        }

        @Generated
        public SparseVectorQueryBuilder queryVector(@NonNull SparseVector queryVector) {
            Objects.requireNonNull(queryVector, "queryVector is marked non-null but is null");
            this.queryVector = queryVector;
            return this;
        }

        @Generated
        public SparseVectorQueryBuilder queryContext(@NonNull SparseQueryContext queryContext) {
            Objects.requireNonNull(queryContext, "queryContext is marked non-null but is null");
            this.queryContext = queryContext;
            return this;
        }

        @Generated
        public SparseVectorQueryBuilder fieldName(@NonNull String fieldName) {
            Objects.requireNonNull(fieldName, "fieldName is marked non-null but is null");
            this.fieldName = fieldName;
            return this;
        }

        @Generated
        public SparseVectorQueryBuilder fallbackQuery(@NonNull Query fallbackQuery) {
            Objects.requireNonNull(fallbackQuery, "fallbackQuery is marked non-null but is null");
            this.fallbackQuery = fallbackQuery;
            return this;
        }

        @Generated
        public SparseVectorQueryBuilder filter(Query filter) {
            this.filter = filter;
            return this;
        }

        @Generated
        public SparseVectorQueryBuilder filterResults(Map<Object, BitSet> filterResults) {
            this.filterResults = filterResults;
            return this;
        }

        @Generated
        public SparseVectorQuery build() {
            return new SparseVectorQuery(this.queryVector, this.queryContext, this.fieldName, this.fallbackQuery, this.filter, this.filterResults);
        }

        @Generated
        @Override
        public String toString() {
            return "SparseVectorQuery.SparseVectorQueryBuilder(queryVector=" + String.valueOf(this.queryVector) + ", queryContext=" + String.valueOf(this.queryContext) + ", fieldName=" + this.fieldName + ", fallbackQuery=" + String.valueOf(this.fallbackQuery) + ", filter=" + String.valueOf(this.filter) + ", filterResults=" + String.valueOf(this.filterResults) + ")";
        }
    }
}

