package org.apache.lucene.classification.document;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.SimpleNaiveBayesClassifier;
import org.apache.lucene.document.Document;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.MultiFields;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.util.BytesRef;

/* loaded from: input_file:libs/lucene-classification-6.6.5-patched.11.jar:org/apache/lucene/classification/document/SimpleNaiveBayesDocumentClassifier.class */
/**
 * Naive Bayes classifier that scores every value of the class field against the text fields of an
 * input {@link Document}.
 *
 * <p>Text field names may carry a boost using the {@code name^boost} syntax (for example
 * {@code "title^2"}). Each input field is tokenized with the analyzer registered for that field in
 * {@link #field2analyzer}.
 */
public class SimpleNaiveBayesDocumentClassifier extends SimpleNaiveBayesClassifier implements DocumentClassifier<BytesRef> {
    /** Analyzer to use per text field name when tokenizing the input document. */
    protected Map<String, Analyzer> field2analyzer;

    /**
     * Creates a document classifier.
     *
     * @param indexReader    reader over the training index
     * @param query          optional query restricting the training documents; may be {@code null}
     * @param classFieldName name of the field holding the class values
     * @param field2analyzer analyzer to use for each text field of the input document
     * @param textFieldNames text field names, optionally suffixed with {@code ^boost}
     */
    public SimpleNaiveBayesDocumentClassifier(IndexReader indexReader, Query query, String classFieldName, Map<String, Analyzer> field2analyzer, String... textFieldNames) {
        // The parent's second constructor argument is passed as null (presumably its single analyzer;
        // per-field analyzers from field2analyzer are used here instead -- confirm against parent).
        super(indexReader, null, query, classFieldName, textFieldNames);
        this.field2analyzer = field2analyzer;
    }

    /**
     * Returns the highest-scoring class for the given document.
     *
     * @param document the document to classify
     * @return the best result, or {@code null} if no class scored above {@code -Double.MAX_VALUE}
     *     (for example when the index holds no class values)
     * @throws IOException if reading the index fails
     */
    @Override
    public ClassificationResult<BytesRef> assignClass(Document document) throws IOException {
        ClassificationResult<BytesRef> best = null;
        double bestScore = -Double.MAX_VALUE;
        for (ClassificationResult<BytesRef> candidate : assignNormClasses(document)) {
            if (candidate.getScore() > bestScore) {
                best = candidate;
                bestScore = candidate.getScore();
            }
        }
        return best;
    }

    /**
     * Returns every class with its normalized score, sorted by {@link ClassificationResult}'s natural
     * ordering.
     *
     * @param document the document to classify
     * @return all scored classes, possibly empty
     * @throws IOException if reading the index fails
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(Document document) throws IOException {
        List<ClassificationResult<BytesRef>> results = assignNormClasses(document);
        Collections.sort(results);
        return results;
    }

    /**
     * Returns at most {@code max} classes, sorted by {@link ClassificationResult}'s natural ordering.
     *
     * @param document the document to classify
     * @param max      maximum number of results; clamped to the number of available classes
     * @return a view of the first {@code min(max, classCount)} sorted results
     * @throws IOException if reading the index fails
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(Document document, int max) throws IOException {
        List<ClassificationResult<BytesRef>> results = assignNormClasses(document);
        Collections.sort(results);
        // Clamp so asking for more classes than exist does not throw IndexOutOfBoundsException.
        return results.subList(0, Math.min(max, results.size()));
    }

    /**
     * Scores every class value found in the class field and normalizes the scores with the parent's
     * {@code normClassificationResults}.
     */
    private List<ClassificationResult<BytesRef>> assignNormClasses(Document document) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<>();
        Terms classes = MultiFields.getTerms(this.indexReader, this.classFieldName);
        if (classes == null) {
            // No indexed document has the class field: there is nothing to score.
            return normClassificationResults(assignedClasses);
        }

        Map<String, List<String[]>> fieldName2tokensArray = new LinkedHashMap<>();
        Map<String, Float> fieldName2boost = new LinkedHashMap<>();
        List<String> fieldNames = analyzeSeedDocument(document, fieldName2tokensArray, fieldName2boost);
        int docsWithClassSize = countDocsWithClass();

        TermsEnum classesEnum = classes.iterator();
        BytesRef next;
        while ((next = classesEnum.next()) != null) {
            // TermsEnum may reuse the returned BytesRef between next() calls; copy it so each result
            // keeps its own class value (harmless if Term already copies).
            Term term = new Term(this.classFieldName, BytesRef.deepCopyOf(next));
            // The prior depends only on the class, so compute it once per class.
            double logPrior = calculateLogPrior(term, docsWithClassSize);
            double classScore = 0d;
            for (String fieldName : fieldNames) {
                float boost = fieldName2boost.get(fieldName);
                double fieldScore = 0d;
                for (String[] tokens : fieldName2tokensArray.get(fieldName)) {
                    fieldScore += logPrior + calculateLogLikelihood(tokens, fieldName, term, docsWithClassSize) * boost;
                }
                classScore += fieldScore;
            }
            assignedClasses.add(new ClassificationResult<>(term.bytes(), classScore));
        }
        return normClassificationResults(assignedClasses);
    }

    /**
     * Tokenizes each configured text field of {@code document}.
     *
     * <p>Each entry of {@code textFieldNames} may be {@code name^boost}; the boost suffix is stripped
     * for lookup. The shared {@code textFieldNames} array is deliberately NOT modified. Overwriting it
     * would discard the boosts on every later classification.
     *
     * @param document              the input document
     * @param fieldName2tokensArray filled with one token array per value of each field
     * @param fieldName2boost       filled with the boost for each field (1 when none is given)
     * @return the stripped field names, in the same order as {@code textFieldNames}
     */
    private List<String> analyzeSeedDocument(Document document, Map<String, List<String[]>> fieldName2tokensArray, Map<String, Float> fieldName2boost) throws IOException {
        List<String> fieldNames = new ArrayList<>(this.textFieldNames.length);
        for (String fieldSpec : this.textFieldNames) {
            String fieldName = fieldSpec;
            float boost = 1f;
            if (fieldSpec.contains("^")) {
                String[] field2boost = fieldSpec.split("\\^");
                fieldName = field2boost[0];
                boost = Float.parseFloat(field2boost[1]);
            }
            List<String[]> tokenizedValues = new ArrayList<>();
            for (IndexableField fieldValue : document.getFields(fieldName)) {
                tokenizedValues.add(getTokenArray(fieldValue.tokenStream(this.field2analyzer.get(fieldName), null)));
            }
            fieldName2tokensArray.put(fieldName, tokenizedValues);
            fieldName2boost.put(fieldName, boost);
            fieldNames.add(fieldName);
        }
        return fieldNames;
    }

    /**
     * Consumes {@code tokenStream} and returns its term texts in order.
     *
     * <p>The stream is always closed, even when analysis throws.
     *
     * @param tokenStream a fresh, not-yet-reset token stream
     * @return the tokens produced by the stream, possibly empty
     * @throws IOException if the stream fails
     */
    protected String[] getTokenArray(TokenStream tokenStream) throws IOException {
        List<String> tokens = new ArrayList<>();
        try (TokenStream stream = tokenStream) {
            CharTermAttribute charTermAttribute = stream.addAttribute(CharTermAttribute.class);
            stream.reset();
            while (stream.incrementToken()) {
                tokens.add(charTermAttribute.toString());
            }
            stream.end();
        }
        return tokens.toArray(new String[0]);
    }

    /**
     * Computes the log-likelihood of one tokenized field value given the class. The result is averaged
     * over the number of tokens, so long values do not dominate short ones. Each word probability
     * uses add-one smoothing.
     *
     * @return the averaged log-likelihood, or 0 when there are no tokens (avoids 0/0 = NaN)
     */
    private double calculateLogLikelihood(String[] tokenizedText, String fieldName, Term term, int docsWithClass) throws IOException {
        if (tokenizedText.length == 0) {
            return 0d;
        }
        // The denominator does not depend on the word, so query the index for it once.
        double den = getTextTermFreqForClass(term, fieldName) + docsWithClass;
        double result = 0d;
        for (String word : tokenizedText) {
            double num = getWordFreqForClass(word, fieldName, term) + 1;
            result += Math.log(num / den);
        }
        return result / tokenizedText.length;
    }

    /**
     * Estimates the number of unique terms of {@code fieldName} across documents of the given class.
     * The estimate is the average unique terms per document for the field times the number of
     * documents in the class.
     *
     * @return the estimate, or 0 when the field has no terms or its doc count is unavailable
     */
    private double getTextTermFreqForClass(Term term, String fieldName) throws IOException {
        Terms terms = MultiFields.getTerms(this.indexReader, fieldName);
        if (terms == null) {
            return 0d;
        }
        long sumDocFreq = terms.getSumDocFreq();
        int docCount = terms.getDocCount();
        // Floating-point division: integer division would truncate the average. The counts can be -1
        // when the codec does not store them; guard against that and against dividing by zero.
        double avgNumberOfUniqueTerms = (docCount > 0 && sumDocFreq >= 0) ? sumDocFreq / (double) docCount : 0d;
        return avgNumberOfUniqueTerms * this.indexReader.docFreq(term);
    }

    /**
     * Counts documents that contain {@code word} in {@code fieldName} and belong to the class
     * {@code term}, further restricted by {@link #query} when it is non-null.
     */
    private int getWordFreqForClass(String word, String fieldName, Term term) throws IOException {
        BooleanQuery.Builder builder = new BooleanQuery.Builder();
        builder.add(new BooleanClause(new TermQuery(new Term(fieldName, word)), BooleanClause.Occur.MUST));
        builder.add(new BooleanClause(new TermQuery(term), BooleanClause.Occur.MUST));
        if (this.query != null) {
            builder.add(this.query, BooleanClause.Occur.MUST);
        }
        TotalHitCountCollector collector = new TotalHitCountCollector();
        this.indexSearcher.search(builder.build(), collector);
        return collector.getTotalHits();
    }

    /** Returns {@code log(docFreq(term)) - log(docsWithClass)}, the log prior of the class. */
    private double calculateLogPrior(Term term, int docsWithClass) throws IOException {
        return Math.log(docCount(term)) - Math.log(docsWithClass);
    }

    /** Returns the document frequency of {@code term} in the index reader. */
    private int docCount(Term term) throws IOException {
        return this.indexReader.docFreq(term);
    }
}
