package org.carrot2.text.vsm;

import org.carrot2.core.attribute.Processing;
import org.carrot2.mahout.math.matrix.DoubleMatrix2D;
import org.carrot2.mahout.math.matrix.impl.DenseDoubleMatrix2D;
import org.carrot2.matrix.MatrixUtils;
import org.carrot2.matrix.factorization.IMatrixFactorization;
import org.carrot2.matrix.factorization.IMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.IterationNumberGuesser;
import org.carrot2.matrix.factorization.IterativeMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.KMeansMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.LocalNonnegativeMatrixFactorizationFactory;
import org.carrot2.matrix.factorization.NonnegativeMatrixFactorizationEDFactory;
import org.carrot2.matrix.factorization.NonnegativeMatrixFactorizationKLFactory;
import org.carrot2.matrix.factorization.PartialSingularValueDecompositionFactory;
import org.carrot2.util.attribute.Attribute;
import org.carrot2.util.attribute.AttributeLevel;
import org.carrot2.util.attribute.Bindable;
import org.carrot2.util.attribute.Group;
import org.carrot2.util.attribute.Input;
import org.carrot2.util.attribute.Label;
import org.carrot2.util.attribute.Level;
import org.carrot2.util.attribute.Required;
import org.carrot2.util.attribute.constraint.ImplementingClasses;

/**
 * Reduces the dimensionality of a term-document matrix by factorizing it with a configurable
 * {@link IMatrixFactorizationFactory}. The factorization's U matrix becomes the base matrix and its
 * V matrix the coefficient matrix of a {@link ReducedVectorSpaceModelContext}.
 */
@Bindable(prefix = "TermDocumentMatrixReducer")
public class TermDocumentMatrixReducer {

    /** Factory producing the matrix factorization used to reduce the term-document matrix. */
    @Level(AttributeLevel.ADVANCED)
    @Input
    @Attribute
    @ImplementingClasses(classes = {PartialSingularValueDecompositionFactory.class, NonnegativeMatrixFactorizationEDFactory.class, NonnegativeMatrixFactorizationKLFactory.class, LocalNonnegativeMatrixFactorizationFactory.class, KMeansMatrixFactorizationFactory.class}, strict = false)
    @Group(TermDocumentMatrixBuilder.MATRIX_MODEL)
    @Processing
    @Required
    @Label("Factorization method")
    public IMatrixFactorizationFactory factorizationFactory = new NonnegativeMatrixFactorizationEDFactory();

    /**
     * Quality level passed to {@link IterationNumberGuesser} when the factory is iterative; it is
     * not used for non-iterative factorizations.
     */
    @Level(AttributeLevel.ADVANCED)
    @Input
    @Attribute
    @Group(TermDocumentMatrixBuilder.MATRIX_MODEL)
    @Processing
    @Required
    @Label("Factorization quality")
    public IterationNumberGuesser.FactorizationQuality factorizationQuality = IterationNumberGuesser.FactorizationQuality.HIGH;

    /**
     * Factorizes the term-document matrix of {@code context.vsmContext} and stores the resulting
     * base (U) and coefficient (V) matrices in {@code context}.
     *
     * <p>If the term-document matrix has no rows or no columns, {@code context.baseMatrix} is set
     * to an empty dense matrix of the same shape. In that case {@code context.coefficientMatrix}
     * is left untouched.
     *
     * <p>For iterative factorizations the target rank is set via {@code setK(dimensions)}, and the
     * iteration count is estimated from {@link #factorizationQuality}. For other factorizations
     * the results are truncated afterwards to at most {@code dimensions} columns.
     *
     * <p>NOTE(review): the term-document matrix is passed to
     * {@link MatrixUtils#normalizeColumnL2} and the return value is discarded. Presumably it is
     * normalized in place, so callers should not rely on its original values afterwards. Confirm
     * this against {@code MatrixUtils}.
     *
     * @param context the context to read the term-document matrix from and write results into
     * @param dimensions the target number of dimensions for the reduced model
     */
    public void reduce(ReducedVectorSpaceModelContext context, int dimensions) {
        final VectorSpaceModelContext vsmContext = context.vsmContext;
        final DoubleMatrix2D tdMatrix = vsmContext.termDocumentMatrix;

        // Nothing to factorize: publish an empty base matrix with the same shape.
        if (tdMatrix.columns() == 0 || tdMatrix.rows() == 0) {
            context.baseMatrix = new DenseDoubleMatrix2D(tdMatrix.rows(), tdMatrix.columns());
            return;
        }

        if (factorizationFactory instanceof IterativeMatrixFactorizationFactory) {
            final IterativeMatrixFactorizationFactory iterativeFactory =
                (IterativeMatrixFactorizationFactory) factorizationFactory;
            iterativeFactory.setK(dimensions);
            IterationNumberGuesser.setEstimatedIterationsNumber(iterativeFactory, tdMatrix, factorizationQuality);
        }

        MatrixUtils.normalizeColumnL2(tdMatrix, null);
        final IMatrixFactorization factorization = factorizationFactory.factorize(tdMatrix);

        // Fetch U and V once each; earlier code assigned them untrimmed and then re-fetched them.
        context.baseMatrix = trim(factorization.getU(), dimensions);
        context.coefficientMatrix = trim(factorization.getV(), dimensions);
    }

    /**
     * Limits a factorization result to at most {@code dimensions} leading columns.
     *
     * <p>Results of iterative factorizations are returned unchanged, because their rank was
     * already fixed through {@code setK} in {@link #reduce}. Matrices that already have
     * {@code dimensions} or fewer columns are also returned unchanged.
     *
     * @param matrix the factorization output (U or V)
     * @param dimensions the maximum number of columns to keep
     * @return {@code matrix} itself, or {@code matrix.viewPart(...)} restricted to the first
     *     {@code dimensions} columns (presumably a view sharing storage with {@code matrix})
     */
    private DoubleMatrix2D trim(DoubleMatrix2D matrix, int dimensions) {
        if (factorizationFactory instanceof IterativeMatrixFactorizationFactory
            || matrix.columns() <= dimensions) {
            return matrix;
        }
        return matrix.viewPart(0, 0, matrix.rows(), dimensions);
    }
}
