package org.genemania.engine.core.integration;

import java.util.List;
import java.util.Map;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Vector;
import no.uib.cipr.matrix.sparse.FlexCompColMatrix;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.config.Config;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.CoAnnotationSet;
import org.genemania.engine.core.data.Data;
import org.genemania.engine.exception.CancellationException;
import org.genemania.engine.matricks.Matrix;
import org.genemania.engine.matricks.MatrixCursor;
import org.genemania.engine.matricks.SymMatrix;
import org.genemania.exception.ApplicationException;
import org.genemania.util.ProgressReporter;

/* loaded from: input_file:org/genemania/engine/core/integration/CalculateFastWeights2.class */
/**
 * Static helpers that assemble the pieces of a linear system and hand it to {@code Solver} to
 * obtain a {@code Map<Long, Double>} of weights (presumably one weight per network id; confirm).
 *
 * <ul>
 *   <li>{@link #computeBasicKtK} fills a network-by-network matrix of element sums and pairwise
 *       element-product sums.</li>
 *   <li>{@link #computeKtT} fills a right-hand-side column from a {@link CoAnnotationSet}.</li>
 *   <li>{@link #simpleComputeCoAnnoationSet} derives a {@link CoAnnotationSet} from an annotation
 *       matrix.</li>
 *   <li>{@link #SolveForWeights} passes the assembled matrices to {@code Solver.solve}.</li>
 * </ul>
 */
public class CalculateFastWeights2 {
    public static final int MAX_VECTOR_LENGTH = 500000;
    public static Logger logger = Logger.getLogger(CalculateFastWeights2.class);

    // NOTE(review): neither constant is referenced anywhere in this class; confirm whether they
    // can be removed. Marked final since nothing reassigns them.
    private static final double EPSILON = Math.pow(2.0d, -52.0d);
    private static final double DELTA = 1.0E-6d;

    /**
     * Placeholder with no implementation: always returns {@code null}.
     *
     * <p>NOTE(review): this private method is not called from this class, and its body appears to
     * have been lost (the file was recovered from a .class file). Confirm whether it can be
     * deleted.
     */
    private static Map<Long, Double> CalculateWeights(Matrix matrix, List<SymMatrix> list, DenseMatrix denseMatrix, DenseMatrix denseMatrix2, CoAnnotationSet coAnnotationSet, Map<Integer, Long> map, ProgressReporter progressReporter) throws Exception {
        return null;
    }

    /**
     * Fills a symmetric {@code (jArr.length + 1) x (jArr.length + 1)} block of {@code denseMatrix}.
     *
     * <p>Entry layout:
     * <ul>
     *   <li>Cell {@code [0][0]} is {@code i * i}.</li>
     *   <li>Cells {@code [k+1][0]} and {@code [0][k+1]} hold the element sum of network {@code k}.</li>
     *   <li>Cell {@code [k+1][m+1]} holds the element-wise product sum of networks {@code k} and
     *       {@code m}.</li>
     * </ul>
     * Row/column 0 are consistent with an all-ones {@code i x i} bias term.
     *
     * @param dataCache source of network matrices, fetched with {@code Data.CORE}
     * @param j identifier passed to {@code getNetwork} together with each network id (presumably
     *     the organism id — confirm)
     * @param jArr ids of the networks to process, in output order
     * @param denseMatrix destination; must be at least {@code (jArr.length + 1)} square
     * @param i dimension used for the bias cell {@code [0][0]} (presumably the number of genes)
     * @param progressReporter polled before every pairwise product
     * @throws CancellationException if {@code progressReporter} reports cancellation
     * @throws ApplicationException propagated from {@code dataCache.getNetwork}
     */
    public static void computeBasicKtK(DataCache dataCache, long j, long[] jArr, DenseMatrix denseMatrix, int i, ProgressReporter progressReporter) throws ApplicationException {
        int length = jArr.length;
        // Widen before multiplying: i * i in int arithmetic overflows once i exceeds 46340.
        denseMatrix.set(0, 0, (double) i * i);
        for (int k = 0; k < length; k++) {
            long networkId = jArr[k];
            SymMatrix network = dataCache.getNetwork(Data.CORE, j, networkId).getData();
            logger.info("Currently Processing the " + k + "th network, id " + networkId);
            double elementSum = network.elementSum();
            denseMatrix.set(k + 1, 0, elementSum);
            denseMatrix.set(0, k + 1, elementSum);
            // Only pairs m <= k are computed; each value is mirrored to keep the block symmetric.
            for (int m = 0; m <= k; m++) {
                if (progressReporter.isCanceled()) {
                    throw new CancellationException();
                }
                double productSum = network.elementMultiplySum(dataCache.getNetwork(Data.CORE, j, jArr[m]).getData());
                denseMatrix.set(k + 1, m + 1, productSum);
                denseMatrix.set(m + 1, k + 1, productSum);
            }
        }
    }

    /**
     * Fills column 0 (rows {@code 0..jArr.length}) of {@code denseMatrix} from a co-annotation set.
     *
     * <p>Row 0 is the bias term:
     * {@code sum(bHalf) * numRows + elementSum(coAnnotation) + constant * numRows^2}.
     * Row {@code k+1} is {@link #computeKttElement} for network {@code k}.
     *
     * <p>Side effect: calls {@code setDiag(Constants.DISCRIMINANT_THRESHOLD)} on the matrix
     * returned by {@code coAnnotationSet.GetCoAnnotationMatrix()}. If that getter returns the
     * set's own matrix (not a copy), the caller's {@code coAnnotationSet} is modified — confirm.
     *
     * @param dataCache source of network matrices, fetched with {@code Data.CORE}
     * @param j identifier passed to {@code getNetwork} (presumably the organism id — confirm)
     * @param matrix annotation matrix; only its dimensions are read here
     * @param jArr ids of the networks to process, in output order
     * @param denseMatrix destination; needs at least {@code jArr.length + 1} rows
     * @param coAnnotationSet supplies the co-annotation matrix, {@code BHalf} vector and constant
     * @param progressReporter polled once per network
     * @throws CancellationException if {@code progressReporter} reports cancellation
     * @throws ApplicationException propagated from {@code dataCache.getNetwork}
     */
    public static void computeKtT(DataCache dataCache, long j, Matrix matrix, long[] jArr, DenseMatrix denseMatrix, CoAnnotationSet coAnnotationSet, ProgressReporter progressReporter) throws ApplicationException {
        int numCols = matrix.numCols();
        int numRows = matrix.numRows();
        SymMatrix coAnnotation = coAnnotationSet.GetCoAnnotationMatrix();
        DenseVector bHalf = coAnnotationSet.GetBHalf();
        double constant = coAnnotationSet.GetConstant().doubleValue();
        coAnnotation.setDiag(Constants.DISCRIMINANT_THRESHOLD);
        int length = jArr.length;
        logger.info("Number of Genes " + numRows + ", Number of Categories " + numCols + ", Number of networks: " + length);
        // Long arithmetic: the int product numRows * numRows * numCols overflows at realistic sizes.
        logger.info("biasValue: " + ((long) numRows * numRows * numCols));
        denseMatrix.set(0, 0, (MatrixUtils.sum((Vector) bHalf) * numRows) + coAnnotation.elementSum() + (constant * numRows * numRows));
        logger.info("Ktt bias value is " + denseMatrix.get(0, 0));
        for (int k = 0; k < length; k++) {
            if (progressReporter.isCanceled()) {
                throw new CancellationException();
            }
            logger.info("Currently Processing the " + k + "th network");
            denseMatrix.set(k + 1, 0, computeKttElement(numRows, dataCache.getNetwork(Data.CORE, j, jArr[k]).getData(), coAnnotation, bHalf, constant));
        }
    }

    /**
     * Computes one right-hand-side entry for a single network.
     *
     * <p>The result is
     * {@code elementMultiplySum(network, coAnnotation) + sum(network * bHalf) + elementSum(network) * d}.
     *
     * @param i length of the temporary product vector (the network dimension)
     * @param symMatrix the network matrix
     * @param symMatrix2 the co-annotation matrix
     * @param denseVector the {@code BHalf} vector multiplied by the network
     * @param d the co-annotation constant, scaled by the network's element sum
     * @return the combined scalar
     */
    public static double computeKttElement(int i, SymMatrix symMatrix, SymMatrix symMatrix2, DenseVector denseVector, double d) {
        double elementSum = symMatrix.elementSum();
        DenseVector product = new DenseVector(i);
        symMatrix.mult(denseVector.getData(), product.getData());
        return symMatrix.elementMultiplySum(symMatrix2) + MatrixUtils.sum((Vector) product) + (elementSum * d);
    }

    /**
     * Solves for the weights by delegating to {@code Solver.solve}.
     *
     * @param denseMatrix the system matrix (as filled by {@link #computeBasicKtK})
     * @param denseMatrix2 matrix whose column 0 is used as the right-hand side (as filled by
     *     {@link #computeKtT})
     * @param map passed through to the solver (presumably row index to network id — confirm)
     * @param progressReporter passed through to the solver
     * @return whatever {@code Solver.solve} returns
     * @throws Exception propagated from the solver
     */
    public static Map<Long, Double> SolveForWeights(DenseMatrix denseMatrix, DenseMatrix denseMatrix2, Map<Integer, Long> map, ProgressReporter progressReporter) throws Exception {
        return Solver.solve(denseMatrix, MatrixUtils.extractColumnToVector(denseMatrix2, 0), map, progressReporter);
    }

    /**
     * Builds an {@code (i*(i-1)/2) x 2} matrix listing every index pair {@code (a, b)} with
     * {@code 0 <= b < a < i}. Rows are ordered by {@code a}, then {@code b}.
     *
     * @param i number of items
     * @return the pair matrix
     * @throws IllegalArgumentException if the pair count does not fit in an int row index
     */
    protected static no.uib.cipr.matrix.Matrix AllPairs(int i) {
        // Compute the pair count in long arithmetic: i * (i - 1) overflows int for i > 46341.
        long pairCount = (long) i * (i - 1) / 2;
        if (pairCount > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Too many pairs for " + i + " items: " + pairCount);
        }
        FlexCompColMatrix pairs = new FlexCompColMatrix((int) pairCount, 2);
        int row = 0;
        for (int first = 0; first < i; first++) {
            for (int second = 0; second < first; second++) {
                pairs.set(row, 0, first);
                pairs.set(row, 1, second);
                row++;
            }
        }
        return pairs;
    }

    /**
     * Builds an {@code i x 1} matrix whose row {@code k} holds the value {@code k}.
     *
     * @param i number of rows
     * @return the index column
     */
    protected static no.uib.cipr.matrix.Matrix ListToN(int i) {
        FlexCompColMatrix indices = new FlexCompColMatrix(i, 1);
        for (int k = 0; k < i; k++) {
            indices.set(k, 0, k);
        }
        return indices;
    }

    /**
     * Derives a {@link CoAnnotationSet} from an annotation matrix.
     *
     * <p>Steps:
     * <ol>
     *   <li>Column sums scaled by {@code 1/numRows}.</li>
     *   <li>Their squared norm (the constant).</li>
     *   <li>{@code -2 * matrix * ratios} (YHat).</li>
     *   <li>The lower-triangular co-occurrence products (AHat).</li>
     * </ol>
     * The method name's spelling ("CoAnnoation") is kept for compatibility with existing callers.
     *
     * @param j first constructor argument of the resulting {@link CoAnnotationSet}
     * @param str second constructor argument of the resulting {@link CoAnnotationSet}
     * @param matrix annotation matrix (rows presumably genes, columns categories — see the log
     *     text in {@link #computeKtT})
     * @return the new co-annotation set
     */
    public static CoAnnotationSet simpleComputeCoAnnoationSet(long j, String str, Matrix matrix) {
        int numRows = matrix.numRows();
        int numCols = matrix.numCols();
        double[] sumPosRatios = simpleComputeSumPosRatios(numRows, matrix);
        double constant = simpleComputeConstant(numRows, new DenseVector(sumPosRatios, false));
        logger.info("constant: " + constant);
        DenseVector yHat = simpleComputeYHat(numRows, sumPosRatios, matrix);
        logger.info("computed YHat");
        SymMatrix aHat = simpleComputeAHatLessMem(numRows, numCols, matrix);
        logger.info("computed AHat");
        return new CoAnnotationSet(j, str, aHat, yHat, constant);
    }

    /**
     * Returns the squared Euclidean norm of {@code vector}.
     *
     * @param i unused; kept for signature compatibility
     * @param vector the vector
     * @return {@code vector . vector}
     */
    public static double simpleComputeConstant(int i, Vector vector) {
        return vector.dot(vector);
    }

    /**
     * Returns {@code -2 * (matrix applied to dArr)} as a new length-{@code i} vector.
     *
     * <p>Assumes {@code Matrix.multAdd(x, y)} accumulates {@code matrix * x} into {@code y}; since
     * {@code y} starts at zero, the result is {@code -2 * matrix * x} — confirm against matricks.
     *
     * @param i number of rows of {@code matrix}
     * @param dArr input vector, one entry per column of {@code matrix}
     * @param matrix the annotation matrix
     * @return the scaled product
     */
    public static DenseVector simpleComputeYHat(int i, double[] dArr, Matrix matrix) {
        double[] product = new double[i];
        matrix.multAdd(dArr, product);
        DenseVector yHat = new DenseVector(product, false);
        yHat.scale(-2.0d);
        return yHat;
    }

    /**
     * Builds a sparse symmetric {@code i x i} matrix.
     *
     * <p>For every column {@code c} of {@code matrix} and every pair of non-zero entries in rows
     * {@code r1 > r2}, adds {@code matrix[r1][c] * matrix[r2][c]} at {@code (r1, r2)}. Products
     * equal to {@code Constants.DISCRIMINANT_THRESHOLD} are skipped.
     *
     * @param i number of rows of {@code matrix}
     * @param i2 number of columns of {@code matrix} to process
     * @param matrix the annotation matrix
     * @return the accumulated lower-triangular co-occurrence matrix
     */
    public static SymMatrix simpleComputeAHatLessMem(int i, int i2, Matrix matrix) {
        SymMatrix result = Config.instance().getMatrixFactory().symSparseMatrix(i);
        int[] allRows = new int[i];
        for (int r = 0; r < allRows.length; r++) {
            allRows[r] = r;
        }
        for (int col = 0; col < i2; col++) {
            Matrix column = matrix.subMatrix(allRows, new int[]{col});
            MatrixCursor outer = column.cursor();
            while (outer.next()) {
                // Loop-invariant for the inner scan: read the outer entry once.
                int row = outer.row();
                double outerVal = outer.val();
                MatrixCursor inner = column.cursor();
                while (inner.next()) {
                    int row2 = inner.row();
                    if (row > row2) {
                        double val = outerVal * inner.val();
                        if (val != Constants.DISCRIMINANT_THRESHOLD) {
                            result.add(row, row2, val);
                        }
                    }
                }
            }
        }
        return result;
    }

    /**
     * Returns the column sums of {@code matrix}, each divided by {@code i}.
     *
     * @param i divisor (the number of rows at the call site in this class)
     * @param matrix the annotation matrix
     * @return a new array of length {@code matrix.numCols()}
     */
    public static double[] simpleComputeSumPosRatios(int i, Matrix matrix) {
        double[] ratios = new double[matrix.numCols()];
        matrix.columnSums(ratios);
        // The vector wraps the array without copying, so scaling updates it in place.
        new DenseVector(ratios, false).scale(1.0d / i);
        return ratios;
    }

    /**
     * Computes the bias entry:
     * {@code sum(BHalf) * i + elementSum(coAnnotation) + constant * i * i}.
     *
     * @param coAnnotationSet supplies {@code BHalf}, the co-annotation matrix and the constant
     * @param i dimension (presumably the number of genes)
     * @return the bias entry
     */
    public static double simpleComputeKtT0(CoAnnotationSet coAnnotationSet, int i) {
        return (MatrixUtils.sum((Vector) coAnnotationSet.GetBHalf()) * i) + coAnnotationSet.GetCoAnnotationMatrix().elementSum() + (coAnnotationSet.GetConstant().doubleValue() * i * i);
    }

    /**
     * Computes one entry for a single network.
     *
     * <p>The result is
     * {@code elementMultiplySum(matrix, coAnnotation) + sum_{r,c} matrix[r][c] * BHalf[r] + d * constant}.
     *
     * @param coAnnotationSet supplies {@code BHalf}, the co-annotation matrix and the constant
     * @param i number of columns of {@code matrix}
     * @param matrix the network matrix
     * @param d multiplier for the constant (presumably the network's element sum — confirm)
     * @return the combined scalar
     */
    public static double simpleComputeKtTi(CoAnnotationSet coAnnotationSet, int i, Matrix matrix, double d) {
        // Fetch once instead of once per non-zero entry.
        DenseVector bHalf = coAnnotationSet.GetBHalf();
        FlexCompColMatrix weightedColSums = new FlexCompColMatrix(1, i);
        MatrixCursor cursor = matrix.cursor();
        while (cursor.next()) {
            int col = cursor.col();
            weightedColSums.set(0, col, weightedColSums.get(0, col) + (cursor.val() * bHalf.get(cursor.row())));
        }
        return matrix.elementMultiplySum(coAnnotationSet.GetCoAnnotationMatrix()) + MatrixUtils.sum((no.uib.cipr.matrix.Matrix) weightedColSums) + (d * coAnnotationSet.GetConstant().doubleValue());
    }
}
