/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.client.solrj.io.stream;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.solr.client.solrj.SolrRequest;
import org.apache.solr.client.solrj.impl.CloudSolrClient;
import org.apache.solr.client.solrj.impl.HttpSolrClient;
import org.apache.solr.client.solrj.io.ClassificationEvaluation;
import org.apache.solr.client.solrj.io.SolrClientCache;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.CloudSolrStream;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.client.solrj.request.QueryRequest;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.common.cloud.ClusterState;
import org.apache.solr.common.cloud.Replica;
import org.apache.solr.common.cloud.Slice;
import org.apache.solr.common.cloud.ZkCoreNodeProps;
import org.apache.solr.common.cloud.ZkStateReader;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.util.ExecutorUtil;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.common.util.SolrNamedThreadFactory;

/**
 * Trains a text logistic-regression model over the shards of a SolrCloud collection, emitting one
 * tuple per training iteration.
 *
 * <p>Each {@link #read()} sends a request carrying {@code fq={!tlogit}} to one randomly chosen
 * active replica (on a live node) of every shard. From the {@code logit} section of each response
 * it collects the shard's weights, error and evaluation, then averages the weights across shards
 * and returns a tuple describing the model at that iteration. Once {@code maxIterations} is
 * exceeded the stream returns EOF.
 *
 * <p>The vocabulary (terms and their idfs) is read once from the {@code termsStream} child, whose
 * tuples are read via the {@code term_s} and {@code idf_d} fields.
 */
public class TextLogitStream extends TupleStream implements Expressible {
    private static final long serialVersionUID = 1L;
    protected String zkHost;
    protected String collection;
    /** Extra named parameters, forwarded verbatim to every shard request. */
    protected Map<String, String> params;
    protected String field;
    protected String name;
    protected String outcome;
    protected int positiveLabel;
    protected double threshold;
    /** Current model weights; {@code null} until supplied or until the first iteration completes. */
    protected List<Double> weights;
    protected int maxIterations;
    /** Number of the iteration last started; incremented at the top of every {@link #read()}. */
    protected int iteration;
    /** Sum of the per-shard errors reported for the current iteration. */
    protected double error;
    protected List<Double> idfs;
    protected ClassificationEvaluation evaluation;
    protected transient SolrClientCache cache;
    /** True when {@link #open()} created {@link #cache} itself and must close it. */
    protected transient boolean isCloseCache;
    protected transient CloudSolrClient cloudSolrClient;
    protected transient StreamContext streamContext;
    protected ExecutorService executorService;
    protected TupleStream termsStream;
    private List<String> terms;
    private double learningRate = 0.01;
    private double lastError = 0.0;
    /** Whether the supplied weights have been checked against the loaded vocabulary size. */
    private boolean weightsValidated;

    /**
     * Creates the stream programmatically.
     *
     * @param zkHost ZooKeeper connect string of the cluster holding the collection
     * @param collectionName collection to train against
     * @param params extra parameters forwarded to every shard request (expected String to String)
     * @param name model name, used as the prefix of each emitted tuple's {@code id}
     * @param field field whose text is used as the feature
     * @param termsStream stream supplying the vocabulary ({@code term_s}/{@code idf_d} tuples)
     * @param weights initial weights, or {@code null}
     * @param outcome outcome field name sent to the shards
     * @param positiveLabel label value sent to the shards as {@code positiveLabel}
     * @param threshold threshold sent to the shards
     * @param maxIterations number of iterations after which the stream ends
     * @throws IOException declared for consistency with the expression constructor
     */
    public TextLogitStream(String zkHost, String collectionName, Map params, String name, String field, TupleStream termsStream, List<Double> weights, String outcome, int positiveLabel, double threshold, int maxIterations) throws IOException {
        this.init(collectionName, zkHost, params, name, field, termsStream, weights, outcome, positiveLabel, threshold, maxIterations, 0);
    }

    /**
     * Creates the stream from a streaming expression.
     *
     * <p>Required: the collection as first operand, one stream operand (the terms stream) and the
     * named parameters {@code name}, {@code field}, {@code maxIterations} and {@code outcome}.
     * Optional: {@code positiveLabel} (default 1), {@code threshold} (default 0.5),
     * {@code iteration} (default 0), {@code weights} (comma separated) and {@code zkHost}. Any
     * other named parameter is forwarded to the shard requests.
     *
     * @throws IOException if the expression is malformed or a required parameter is missing
     */
    public TextLogitStream(StreamExpression expression, StreamFactory factory) throws IOException {
        String collectionName = factory.getValueOperand(expression, 0);
        List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
        StreamExpressionNamedParameter zkHostExpression = factory.getNamedOperand(expression, "zkHost");
        List<StreamExpression> streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);

        if (expression.getParameters().size() != 1 + namedParams.size() + streamExpressions.size()) {
            throw new IOException(String.format(Locale.ROOT, "invalid expression %s - unknown operands found", expression));
        }
        if (null == collectionName) {
            throw new IOException(String.format(Locale.ROOT, "invalid expression %s - collectionName expected as first operand", expression));
        }
        if (namedParams.isEmpty()) {
            throw new IOException(String.format(Locale.ROOT, "invalid expression %s - at least one named parameter expected. eg. 'q=*:*'", expression));
        }

        Map<String, String> params = new HashMap<>();
        for (StreamExpressionNamedParameter namedParam : namedParams) {
            if (!namedParam.getName().equals("zkHost")) {
                params.put(namedParam.getName(), namedParam.getParameter().toString().trim());
            }
        }

        // Each recognised parameter is removed so that only pass-through parameters remain in params.
        String name = params.remove("name");
        if (name == null) {
            throw new IOException("name param cannot be null for TextLogitStream");
        }
        String feature = params.remove("field");
        if (feature == null) {
            throw new IOException("field param cannot be null for TextLogitStream");
        }
        if (streamExpressions.isEmpty()) {
            throw new IOException("features must be present for TextLogitStream");
        }
        TupleStream stream = factory.constructStream(streamExpressions.get(0));

        String maxIterationsParam = params.remove("maxIterations");
        if (maxIterationsParam == null) {
            throw new IOException("maxIterations param cannot be null for TextLogitStream");
        }
        int maxIterations = Integer.parseInt(maxIterationsParam);

        String outcomeParam = params.remove("outcome");
        if (outcomeParam == null) {
            throw new IOException("outcome param cannot be null for TextLogitStream");
        }

        String positiveLabelParam = params.remove("positiveLabel");
        int positiveLabel = positiveLabelParam == null ? 1 : Integer.parseInt(positiveLabelParam);

        String thresholdParam = params.remove("threshold");
        double threshold = thresholdParam == null ? 0.5 : Double.parseDouble(thresholdParam);

        String iterationParam = params.remove("iteration");
        int iteration = iterationParam == null ? 0 : Integer.parseInt(iterationParam);

        List<Double> weights = null;
        String weightsParam = params.remove("weights");
        if (weightsParam != null) {
            weights = new ArrayList<>();
            for (String weightString : weightsParam.split(",")) {
                weights.add(Double.parseDouble(weightString));
            }
        }

        String zkHost = null;
        if (null == zkHostExpression) {
            zkHost = factory.getCollectionZkHost(collectionName);
        } else if (zkHostExpression.getParameter() instanceof StreamExpressionValue) {
            zkHost = ((StreamExpressionValue) zkHostExpression.getParameter()).getValue();
        }
        if (null == zkHost) {
            throw new IOException(String.format(Locale.ROOT, "invalid expression %s - zkHost not found for collection '%s'", expression, collectionName));
        }
        this.init(collectionName, zkHost, params, name, feature, stream, weights, outcomeParam, positiveLabel, threshold, maxIterations, iteration);
    }

    @Override
    public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
        return this.toExpression(factory, true);
    }

    /**
     * Serialises this stream back to an expression.
     *
     * <p>An inline {@link TermsStream} is not expressible; its terms are written as a
     * {@code terms} named parameter instead (this loads the terms if necessary).
     * NOTE(review): the expression constructor has no handling for {@code terms} and requires a
     * stream operand, so such an expression does not round-trip.
     */
    private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException {
        StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
        expression.addParameter(this.collection);
        boolean inlineTerms = this.termsStream instanceof TermsStream;
        if (includeStreams && !inlineTerms) {
            if (!(this.termsStream instanceof Expressible)) {
                throw new IOException("This TextLogitStream contains a non-expressible TupleStream - it cannot be converted to an expression");
            }
            expression.addParameter(((Expressible) this.termsStream).toExpression(factory));
        }
        for (Map.Entry<String, String> param : this.params.entrySet()) {
            expression.addParameter(new StreamExpressionNamedParameter(param.getKey(), param.getValue()));
        }
        expression.addParameter(new StreamExpressionNamedParameter("field", this.field));
        expression.addParameter(new StreamExpressionNamedParameter("name", this.name));
        if (inlineTerms) {
            this.loadTerms();
            expression.addParameter(new StreamExpressionNamedParameter("terms", TextLogitStream.toString(this.terms)));
        }
        expression.addParameter(new StreamExpressionNamedParameter("outcome", this.outcome));
        if (this.weights != null) {
            expression.addParameter(new StreamExpressionNamedParameter("weights", TextLogitStream.toString(this.weights)));
        }
        expression.addParameter(new StreamExpressionNamedParameter("maxIterations", Integer.toString(this.maxIterations)));
        if (this.iteration > 0) {
            expression.addParameter(new StreamExpressionNamedParameter("iteration", Integer.toString(this.iteration)));
        }
        expression.addParameter(new StreamExpressionNamedParameter("positiveLabel", Integer.toString(this.positiveLabel)));
        expression.addParameter(new StreamExpressionNamedParameter("threshold", Double.toString(this.threshold)));
        expression.addParameter(new StreamExpressionNamedParameter("zkHost", this.zkHost));
        return expression;
    }

    /** Stores the configuration shared by both constructors. */
    private void init(String collectionName, String zkHost, Map params, String name, String feature, TupleStream termsStream, List<Double> weights, String outcome, int positiveLabel, double threshold, int maxIterations, int iteration) throws IOException {
        this.zkHost = zkHost;
        this.collection = collectionName;
        this.params = params;
        this.name = name;
        this.field = feature;
        this.termsStream = termsStream;
        this.outcome = outcome;
        this.positiveLabel = positiveLabel;
        this.threshold = threshold;
        this.weights = weights;
        this.maxIterations = maxIterations;
        this.iteration = iteration;
    }

    @Override
    public void setStreamContext(StreamContext context) {
        this.cache = context.getSolrClientCache();
        this.streamContext = context;
        this.termsStream.setStreamContext(context);
    }

    /**
     * Obtains the cloud client (creating a private {@link SolrClientCache} when no stream context
     * supplied one) and starts the executor used to query shards in parallel.
     */
    @Override
    public void open() throws IOException {
        if (this.cache == null) {
            this.isCloseCache = true;
            this.cache = new SolrClientCache();
        } else {
            this.isCloseCache = false;
        }
        this.cloudSolrClient = this.cache.getCloudSolrClient(this.zkHost);
        this.executorService = ExecutorUtil.newMDCAwareCachedThreadPool(new SolrNamedThreadFactory("TextLogitSolrStream"));
    }

    @Override
    public List<TupleStream> children() {
        List<TupleStream> children = new ArrayList<>();
        children.add(this.termsStream);
        return children;
    }

    /**
     * Picks one random active replica on a live node for every shard of the collection.
     *
     * @return the core URL of the chosen replica for each shard
     * @throws IOException if cluster state cannot be read or a shard has no usable replica
     */
    protected List<String> getShardUrls() throws IOException {
        try {
            ZkStateReader zkStateReader = this.cloudSolrClient.getZkStateReader();
            Slice[] slices = CloudSolrStream.getSlices(this.collection, zkStateReader, false);
            ClusterState clusterState = zkStateReader.getClusterState();
            Set<String> liveNodes = clusterState.getLiveNodes();
            Random random = new Random();
            List<String> baseUrls = new ArrayList<>();
            for (Slice slice : slices) {
                Collection<Replica> replicas = slice.getReplicas();
                List<Replica> candidates = new ArrayList<>();
                for (Replica replica : replicas) {
                    if (replica.getState() == Replica.State.ACTIVE && liveNodes.contains(replica.getNodeName())) {
                        candidates.add(replica);
                    }
                }
                if (candidates.isEmpty()) {
                    throw new IOException(String.format(Locale.ROOT, "no active replica on a live node for shard '%s' of collection '%s'", slice.getName(), this.collection));
                }
                Collections.shuffle(candidates, random);
                baseUrls.add(new ZkCoreNodeProps(candidates.get(0)).getCoreUrl());
            }
            return baseUrls;
        } catch (IOException e) {
            throw e;
        } catch (Exception e) {
            throw new IOException(e);
        }
    }

    /** Submits one {@link LogitCall} per shard URL with the current model state. */
    private List<Future<Tuple>> callShards(List<String> baseUrls) throws IOException {
        List<Future<Tuple>> futures = new ArrayList<>();
        for (String baseUrl : baseUrls) {
            LogitCall call = new LogitCall(baseUrl, this.params, this.field, this.terms, this.weights, this.outcome, this.positiveLabel, this.learningRate, this.iteration);
            futures.add(this.executorService.submit(call));
        }
        return futures;
    }

    /**
     * Shuts down the executor, closes the client cache if this stream created it, and closes the
     * terms stream (even if an earlier step fails).
     */
    @Override
    public void close() throws IOException {
        try {
            if (this.executorService != null) {
                this.executorService.shutdown();
            }
            if (this.isCloseCache && this.cache != null) {
                this.cache.close();
                // Drop the closed cache so a later open() creates a fresh one instead of reusing it.
                this.cache = null;
            }
        } finally {
            this.termsStream.close();
        }
    }

    @Override
    public StreamComparator getStreamSort() {
        return null;
    }

    @Override
    public Explanation toExplanation(StreamFactory factory) throws IOException {
        StreamExplanation explanation = new StreamExplanation(this.getStreamNodeId().toString());
        explanation.setFunctionName(factory.getFunctionName(this.getClass()));
        explanation.setImplementingClass(this.getClass().getName());
        explanation.setExpressionType("ml-model");
        explanation.setExpression(this.toExpression(factory).toString());
        explanation.addChild(this.termsStream.toExplanation(factory));
        return explanation;
    }

    /**
     * Reads the vocabulary ({@code term_s}) and idf values ({@code idf_d}) from the terms stream.
     * Does nothing if the terms are already loaded. The fields are only assigned after the whole
     * stream has been read, so a failed load can be retried.
     */
    public void loadTerms() throws IOException {
        if (this.terms != null) {
            return;
        }
        List<String> loadedTerms = new ArrayList<>();
        List<Double> loadedIdfs = new ArrayList<>();
        this.termsStream.open();
        try {
            for (Tuple termTuple = this.termsStream.read(); !termTuple.EOF; termTuple = this.termsStream.read()) {
                loadedTerms.add(termTuple.getString("term_s"));
                loadedIdfs.add(termTuple.getDouble("idf_d"));
            }
        } finally {
            this.termsStream.close();
        }
        this.terms = loadedTerms;
        this.idfs = loadedIdfs;
    }

    /**
     * Runs one training iteration across all shards and returns the resulting model tuple, or EOF
     * once {@code maxIterations} is exceeded.
     *
     * @throws IOException if the weights do not match the vocabulary, no shard could be queried,
     *     or a shard request fails
     */
    @Override
    public Tuple read() throws IOException {
        try {
            if (++this.iteration > this.maxIterations) {
                return Tuple.EOF();
            }
            this.loadTerms();
            if (!this.weightsValidated) {
                // Presumably one weight per term plus a bias term - matches the check the shards expect.
                if (this.weights != null && this.terms.size() + 1 != this.weights.size()) {
                    throw new IOException(String.format(Locale.ROOT, "invalid weights for TextLogitStream '%s' - the number of weights must be %d, found %d", this.name, this.terms.size() + 1, this.weights.size()));
                }
                this.weightsValidated = true;
            }

            List<List<Double>> allWeights = new ArrayList<>();
            this.evaluation = new ClassificationEvaluation();
            this.error = 0.0;
            for (Future<Tuple> logitCall : this.callShards(this.getShardUrls())) {
                Tuple tuple = logitCall.get();
                @SuppressWarnings("unchecked")
                List<Double> shardWeights = (List<Double>) tuple.get("weights");
                allWeights.add(shardWeights);
                this.error += tuple.getDouble("error").doubleValue();
                Map shardEvaluation = (Map) tuple.get("evaluation");
                this.evaluation.addEvaluation(shardEvaluation);
            }
            if (allWeights.isEmpty()) {
                throw new IOException(String.format(Locale.ROOT, "no shards found for collection '%s'", this.collection));
            }
            this.weights = this.averageWeights(allWeights);

            Map<String, Object> map = new HashMap<>();
            map.put("id", this.name + "_" + this.iteration);
            map.put("name_s", this.name);
            map.put("field_s", this.field);
            map.put("terms_ss", this.terms);
            map.put("iteration_i", this.iteration);
            if (this.weights != null) {
                map.put("weights_ds", this.weights);
            }
            map.put("error_d", this.error);
            this.evaluation.putToMap(map);
            map.put("alpha_d", this.learningRate);
            map.put("idfs_ds", this.idfs);

            // Adapt the learning rate for the next iteration: back off when the error did not drop.
            // NOTE(review): when resuming from a non-zero "iteration" param, lastError is still 0.0
            // on the first read, so the rate is halved - confirm whether that is intended.
            if (this.iteration != 1) {
                if (this.lastError <= this.error) {
                    this.learningRate *= 0.5;
                } else {
                    this.learningRate *= 1.05;
                }
            }
            this.lastError = this.error;
            return new Tuple(map);
        } catch (IOException e) {
            throw e;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IOException(e);
        } catch (Exception e) {
            throw new IOException(e);
        }
    }

    /**
     * Element-wise mean of the per-shard weight vectors. The length is taken from the first
     * vector; {@code allWeights} must be non-empty.
     */
    private List<Double> averageWeights(List<List<Double>> allWeights) {
        double[] working = new double[allWeights.get(0).size()];
        for (List<Double> shardWeights : allWeights) {
            for (int i = 0; i < working.length; ++i) {
                working[i] += shardWeights.get(i);
            }
        }
        List<Double> average = new ArrayList<>(working.length);
        for (double sum : working) {
            average.add(sum / allWeights.size());
        }
        return average;
    }

    /** Joins the {@code toString()} of every item with commas (no spaces). */
    static String toString(List<?> items) {
        StringBuilder buf = new StringBuilder();
        for (Object item : items) {
            if (buf.length() > 0) {
                buf.append(",");
            }
            buf.append(item.toString());
        }
        return buf.toString();
    }

    /**
     * Sends one non-distributed {@code {!tlogit}} request to a single shard replica and returns
     * the shard's {@code error}, {@code weights} and {@code evaluation} as a tuple. Also reads the
     * enclosing stream's {@code cache}, {@code idfs} and {@code threshold}.
     */
    protected class LogitCall implements Callable<Tuple> {
        private String baseUrl;
        private String feature;
        private List<String> terms;
        private List<Double> weights;
        private int iteration;
        private String outcome;
        private int positiveLabel;
        private double learningRate;
        private Map<String, String> paramsMap;

        public LogitCall(String baseUrl, Map<String, String> paramsMap, String feature, List<String> terms, List<Double> weights, String outcome, int positiveLabel, double learningRate, int iteration) {
            this.baseUrl = baseUrl;
            this.feature = feature;
            this.terms = terms;
            this.weights = weights;
            this.iteration = iteration;
            this.outcome = outcome;
            this.positiveLabel = positiveLabel;
            this.learningRate = learningRate;
            this.paramsMap = paramsMap;
        }

        @Override
        public Tuple call() throws Exception {
            ModifiableSolrParams params = new ModifiableSolrParams();
            HttpSolrClient solrClient = TextLogitStream.this.cache.getHttpSolrClient(this.baseUrl);
            params.add("distrib", "false");
            params.add("fq", "{!tlogit}");
            params.add("feature", this.feature);
            params.add("terms", TextLogitStream.toString(this.terms));
            params.add("idfs", TextLogitStream.toString(TextLogitStream.this.idfs));
            for (Map.Entry<String, String> entry : this.paramsMap.entrySet()) {
                params.add(entry.getKey(), entry.getValue());
            }
            if (this.weights != null) {
                params.add("weights", TextLogitStream.toString(this.weights));
            }
            params.add("iteration", Integer.toString(this.iteration));
            params.add("outcome", this.outcome);
            params.add("positiveLabel", Integer.toString(this.positiveLabel));
            params.add("threshold", Double.toString(TextLogitStream.this.threshold));
            params.add("alpha", Double.toString(this.learningRate));

            QueryRequest request = new QueryRequest(params, SolrRequest.METHOD.POST);
            QueryResponse response = (QueryResponse) request.process(solrClient);
            NamedList<Object> res = response.getResponse();
            NamedList<?> logit = (NamedList<?>) res.get("logit");
            if (logit == null) {
                throw new IOException("no 'logit' section in response from " + this.baseUrl);
            }
            List<?> shardWeights = (List<?>) logit.get("weights");
            double shardError = (Double) logit.get("error");
            Tuple tuple = new Tuple();
            tuple.put("error", shardError);
            tuple.put("weights", shardWeights);
            tuple.put("evaluation", logit.get("evaluation"));
            return tuple;
        }
    }

    /**
     * Non-expressible in-memory stream over a fixed list of terms; emits one tuple per term with
     * {@code term_s} and {@code score_f = 1.0}.
     */
    protected static class TermsStream extends TupleStream {
        private List<String> terms;
        private Iterator<String> it;

        public TermsStream(List<String> terms) {
            this.terms = terms;
        }

        @Override
        public void setStreamContext(StreamContext context) {
        }

        @Override
        public List<TupleStream> children() {
            return new ArrayList<TupleStream>();
        }

        @Override
        public void open() throws IOException {
            this.it = this.terms.iterator();
        }

        @Override
        public void close() throws IOException {
        }

        @Override
        public Tuple read() throws IOException {
            if (this.it.hasNext()) {
                Tuple tuple = new Tuple();
                tuple.put("term_s", this.it.next());
                tuple.put("score_f", 1.0);
                return tuple;
            }
            return Tuple.EOF();
        }

        @Override
        public StreamComparator getStreamSort() {
            return null;
        }

        @Override
        public Explanation toExplanation(StreamFactory factory) throws IOException {
            return new StreamExplanation(this.getStreamNodeId().toString()).withFunctionName("non-expressible").withImplementingClass(this.getClass().getName()).withExpressionType("stream-source").withExpression("non-expressible");
        }
    }
}

