package org.apache.solr.handler;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.common.SolrException;
import org.apache.solr.core.SolrCore;

/* loaded from: input_file:libs/solr-core-6.6.5-patched.11.jar:org/apache/solr/handler/ClassifyStream.class */
/**
 * Streaming-expression decorator that scores every tuple of a document stream against a
 * model read from a second stream.
 *
 * <p>The first tuple of the model stream must carry three fields: {@code terms_ss} (the
 * model vocabulary), {@code idfs_ds} (one value per term) and {@code weights_ds} (one
 * weight applied to a constant {@code 1.0} input, followed by one weight per term). Each
 * document tuple has the text of {@link #field} analyzed with the index analyzer of
 * {@link #analyzerField}; the resulting term counts are combined with the model values and
 * the outcome is written back to the tuple as {@code probability_d} and {@code score_d}.
 */
public class ClassifyStream extends TupleStream implements Expressible {
    private TupleStream docStream;
    private TupleStream modelStream;
    private String field;
    private String analyzerField;
    private Tuple modelTuple;
    Analyzer analyzer;
    private Map<CharSequence, Integer> termToIndex;
    private List<Double> idfs;
    private List<Double> modelWeights;

    /**
     * Builds the stream from a parsed expression.
     *
     * <p>The expression must contain exactly two stream operands (model stream first,
     * document stream second) and a {@code field} named parameter. The optional
     * {@code analyzerField} parameter defaults to the value of {@code field}.
     *
     * @param streamExpression the parsed expression
     * @param streamFactory factory used to construct the two inner streams
     * @throws IOException if the operand count is not two or {@code field} is missing
     */
    public ClassifyStream(StreamExpression streamExpression, StreamFactory streamFactory) throws IOException {
        List<StreamExpression> streamOperands =
                streamFactory.getExpressionOperandsRepresentingTypes(streamExpression, Expressible.class, TupleStream.class);
        if (streamOperands.size() != 2) {
            throw new IOException(String.format(Locale.ROOT,
                    "Invalid expression %s - expecting two stream but found %d", streamExpression, streamOperands.size()));
        }
        this.modelStream = streamFactory.constructStream(streamOperands.get(0));
        this.docStream = streamFactory.constructStream(streamOperands.get(1));

        StreamExpressionNamedParameter fieldParam = streamFactory.getNamedOperand(streamExpression, "field");
        if (fieldParam == null) {
            throw new IOException(String.format(Locale.ROOT,
                    "Invalid expression %s - field parameter must be specified", streamExpression));
        }
        String fieldName = fieldParam.getParameter().toString();
        this.field = fieldName;
        // The analyzer field falls back to the text field unless explicitly overridden.
        this.analyzerField = fieldName;

        StreamExpressionNamedParameter analyzerFieldParam = streamFactory.getNamedOperand(streamExpression, "analyzerField");
        if (analyzerFieldParam != null) {
            this.analyzerField = analyzerFieldParam.getParameter().toString();
        }
    }

    /**
     * Resolves the index analyzer of {@link #analyzerField} from the {@link SolrCore} stored
     * under the {@code solr-core} key and forwards the context to both inner streams.
     *
     * @param streamContext context that must hold a {@link SolrCore} under {@code solr-core}
     * @throws SolrException if the context has no {@link SolrCore} under that key
     */
    @Override
    public void setStreamContext(StreamContext streamContext) {
        Object core = streamContext.get("solr-core");
        // instanceof is false for null, so no separate null check is needed.
        if (!(core instanceof SolrCore)) {
            throw new SolrException(SolrException.ErrorCode.INVALID_STATE, "StreamContext must have SolrCore in solr-core key");
        }
        this.analyzer = ((SolrCore) core).getLatestSchema().getFieldType(this.analyzerField).getIndexAnalyzer();
        this.docStream.setStreamContext(streamContext);
        this.modelStream.setStreamContext(streamContext);
    }

    /**
     * @return a new list holding the document stream followed by the model stream
     */
    @Override
    public List<TupleStream> children() {
        List<TupleStream> children = new ArrayList<>();
        children.add(this.docStream);
        children.add(this.modelStream);
        return children;
    }

    /**
     * Opens the document stream and then the model stream.
     *
     * @throws IOException if either inner stream fails to open
     */
    @Override
    public void open() throws IOException {
        this.docStream.open();
        this.modelStream.open();
    }

    /**
     * Closes both inner streams. The model stream is closed even when closing the document
     * stream throws.
     *
     * @throws IOException if closing either inner stream fails
     */
    @Override
    public void close() throws IOException {
        try {
            this.docStream.close();
        } finally {
            this.modelStream.close();
        }
    }

    /**
     * Reads the next document tuple and annotates it with {@code probability_d} and
     * {@code score_d}. The model is loaded from the model stream on the first call.
     *
     * @return the annotated document tuple, or the document stream's EOF tuple unchanged
     * @throws IOException if the model cannot be loaded, the document tuple lacks
     *     {@link #field}, or reading/analysis fails
     */
    @Override
    public Tuple read() throws IOException {
        if (this.modelTuple == null) {
            loadModel();
        }

        Tuple docTuple = this.docStream.read();
        if (docTuple.EOF) {
            return docTuple;
        }

        Object rawText = docTuple.get(this.field);
        if (rawText == null) {
            throw new IOException(String.format(Locale.ROOT,
                    "Field %s not found in tuple read by classify stream", this.field));
        }

        double[] termFreqs = new double[this.termToIndex.size()];
        int tokenCount = 0;
        // try-with-resources guarantees the analyzer's token stream is closed even when
        // analysis fails part-way through.
        try (TokenStream tokenStream = this.analyzer.tokenStream(this.analyzerField, rawText.toString())) {
            CharTermAttribute termAttr = tokenStream.getAttribute(CharTermAttribute.class);
            tokenStream.reset();
            while (tokenStream.incrementToken()) {
                tokenCount++;
                Integer index = this.termToIndex.get(termAttr.toString());
                if (index != null) {
                    termFreqs[index] += 1.0d;
                }
            }
            tokenStream.end();
        }

        // weights[0] is applied to a constant 1.0 input; weights[i + 1] pairs with term i.
        double total = 0.0d;
        total += 1.0d * this.modelWeights.get(0);
        for (int i = 0; i < termFreqs.length; i++) {
            double tf = termFreqs[i];
            if (tf != 0.0d) {
                tf = 1.0d + Math.log(tf); // sub-linear term-frequency damping
            }
            total += (this.idfs.get(i) * tf) * this.modelWeights.get(i + 1);
        }

        docTuple.put("probability_d", sigmoid(total));
        // The factor is deliberately narrowed to float before multiplying.
        // NOTE(review): when the text yields no tokens the factor is infinite, so score_d
        // becomes +/-Infinity or NaN - confirm what downstream consumers expect.
        docTuple.put("score_d", total * ((float) (1.0d / Math.sqrt(tokenCount))));
        return docTuple;
    }

    /**
     * Reads the model tuple and caches the term index, idf values and weights. The cached
     * state is published only after the whole model has been validated, so a failed load
     * never leaves this stream half-initialised.
     *
     * @throws IOException if the model stream yields no tuple, or the tuple lacks the
     *     expected fields or has too few idf/weight values for its terms
     */
    private void loadModel() throws IOException {
        Tuple model = this.modelStream.read();
        if (model == null || model.EOF) {
            throw new IOException("Model tuple not found for classify stream!");
        }

        List<String> terms = model.getStrings("terms_ss");
        List<Double> idfValues = model.getDoubles("idfs_ds");
        List<Double> weights = model.getDoubles("weights_ds");
        if (terms == null || idfValues == null || weights == null) {
            throw new IOException("Model tuple for classify stream must contain terms_ss, idfs_ds and weights_ds");
        }
        if (idfValues.size() < terms.size() || weights.size() < terms.size() + 1) {
            throw new IOException(String.format(Locale.ROOT,
                    "Model tuple for classify stream is inconsistent - %d terms, %d idfs, %d weights",
                    terms.size(), idfValues.size(), weights.size()));
        }

        Map<CharSequence, Integer> index = new HashMap<>();
        for (int i = 0; i < terms.size(); i++) {
            index.put(terms.get(i), i);
        }

        this.termToIndex = index;
        this.idfs = idfValues;
        this.modelWeights = weights;
        this.modelTuple = model;
    }

    /**
     * Logistic function {@code 1 / (1 + e^-d)}.
     *
     * @param d input value
     * @return the logistic function of {@code d}
     */
    private double sigmoid(double d) {
        return 1.0d / (1.0d + Math.exp(-d));
    }

    /**
     * @return always {@code null}; this stream declares no sort order
     */
    @Override
    public StreamComparator getStreamSort() {
        return null;
    }

    /**
     * Converts this stream, including both inner streams, back into an expression.
     *
     * @param streamFactory factory used to look up function names
     * @return the expression form of this stream
     * @throws IOException if an inner stream is not {@link Expressible}
     */
    @Override
    public StreamExpressionParameter toExpression(StreamFactory streamFactory) throws IOException {
        return toExpression(streamFactory, true);
    }

    /**
     * Builds the expression for this stream.
     *
     * @param streamFactory factory used to look up function names
     * @param includeStreams whether to embed the model and document stream expressions
     * @return the expression, always carrying the {@code field} and {@code analyzerField}
     *     named parameters
     * @throws IOException if {@code includeStreams} is set and an inner stream is not
     *     {@link Expressible}
     */
    private StreamExpression toExpression(StreamFactory streamFactory, boolean includeStreams) throws IOException {
        StreamExpression expression = new StreamExpression(streamFactory.getFunctionName(getClass()));
        if (includeStreams) {
            if (!(this.docStream instanceof Expressible) || !(this.modelStream instanceof Expressible)) {
                throw new IOException("This ClassifyStream contains a non-expressible TupleStream - it cannot be converted to an expression");
            }
            // Same operand order the constructor expects: model first, documents second.
            expression.addParameter(((Expressible) this.modelStream).toExpression(streamFactory));
            expression.addParameter(((Expressible) this.docStream).toExpression(streamFactory));
        }
        expression.addParameter(new StreamExpressionNamedParameter("field", this.field));
        expression.addParameter(new StreamExpressionNamedParameter("analyzerField", this.analyzerField));
        return expression;
    }

    /**
     * Describes this stream as a stream decorator whose children are the explanations of
     * the document stream and the model stream.
     *
     * @param streamFactory factory used to look up function names
     * @return the explanation of this stream
     * @throws IOException if building the expression or a child explanation fails
     */
    @Override
    public Explanation toExplanation(StreamFactory streamFactory) throws IOException {
        StreamExplanation explanation = new StreamExplanation(getStreamNodeId().toString());
        explanation.setFunctionName(streamFactory.getFunctionName(getClass()));
        explanation.setImplementingClass(getClass().getName());
        explanation.setExpressionType(Explanation.ExpressionType.STREAM_DECORATOR);
        explanation.setExpression(toExpression(streamFactory, false).toString());
        explanation.addChild(this.docStream.toExplanation(streamFactory));
        explanation.addChild(this.modelStream.toExplanation(streamFactory));
        return explanation;
    }
}
