package org.genemania.engine.core;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.MatrixEntry;
import no.uib.cipr.matrix.Vector;
import no.uib.cipr.matrix.sparse.FlexCompColMatrix;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.exception.CancellationException;
import org.genemania.exception.ApplicationException;
import org.genemania.util.ProgressReporter;

/* loaded from: input_file:org/genemania/engine/core/CalculateFastWeights.class */
/**
 * Static routines that assemble and solve the linear system used to weight a list of networks
 * against a gene-by-category annotation matrix.
 *
 * <p>The system has two parts:
 * <ul>
 *   <li>a symmetric {@code KtK} matrix ({@link #computeBasicKtK}). Row/column 0 holds a bias term,
 *       and entry {@code (a+1, b+1)} is the element-wise product sum of networks {@code a} and
 *       {@code b};
 *   <li>a right-hand side {@code KtT} ({@link #computeKtT}). It combines each network with a
 *       {@link CoAnnotationSet}.
 * </ul>
 * Two builders for the co-annotation set are provided: {@link #FastCoAnnotation} and
 * {@link #simpleComputeCoAnnoationSet}.
 */
public class CalculateFastWeights {
    /**
     * Chunk length that {@link #FastCoAnnotation} once used for intermediate buffers. It is no
     * longer referenced in this class but is kept because it is public.
     */
    public static final int MAX_VECTOR_LENGTH = 500000;

    /** Logger shared by the routines in this class. */
    public static Logger logger = Logger.getLogger(CalculateFastWeights.class);

    /**
     * Builds {@code KtK} and {@code KtT} for {@code networks} and solves for the network weights.
     *
     * @param matrix annotation matrix. Its row count is used as the number of genes and its
     *     column count as the number of categories.
     * @param networks networks to weight
     * @param ktk output; receives {@code KtK} scaled by the number of categories. It must have at
     *     least {@code networks.size() + 1} rows and columns.
     * @param ktt output; column 0 receives {@code KtT}
     * @param coAnnotationSet co-annotation data for {@code matrix}. Its co-annotation matrix has
     *     its diagonal zeroed in place.
     * @param map passed through unchanged to {@code Solver.solve}
     * @param progressReporter polled for cancellation
     * @return the result of {@code Solver.solve}
     * @throws CancellationException if {@code progressReporter} reports cancellation
     * @throws Exception propagated from the solver
     */
    public static Map<Integer, Double> CalculateWeights(Matrix matrix, List<Matrix> networks,
            DenseMatrix ktk, DenseMatrix ktt, CoAnnotationSet coAnnotationSet,
            Map<Integer, Integer> map, ProgressReporter progressReporter) throws Exception {
        int numCategories = matrix.numColumns();
        computeBasicKtK(networks, ktk, matrix.numRows(), progressReporter);
        // NOTE(review): the whole KtK is scaled by the category count, presumably because every
        // category contributes the same network Gram matrix. Confirm against the model derivation.
        ktk.scale(numCategories);
        computeKtT(matrix, networks, ktt, coAnnotationSet, progressReporter);
        return SolveForWeights(ktk, ktt, map, progressReporter);
    }

    /**
     * Fills the symmetric {@code KtK} matrix.
     *
     * <p>Entry (0,0) is {@code numGenes^2}. Entry (a+1,0) and entry (0,a+1) are the sum of network
     * {@code a}. Entry (a+1,b+1) is {@code MatrixUtils.elementMultiplySum} of networks {@code a}
     * and {@code b}.
     *
     * @param networks networks to compare
     * @param ktk output matrix with at least {@code networks.size() + 1} rows and columns
     * @param numGenes number of genes (rows of the annotation matrix)
     * @param progressReporter polled for cancellation before each pairwise product
     * @throws ApplicationException if cancelled ({@link CancellationException})
     */
    public static void computeBasicKtK(List<Matrix> networks, DenseMatrix ktk, int numGenes,
            ProgressReporter progressReporter) throws ApplicationException {
        // Compute in double: an int product overflows once numGenes exceeds 46340.
        ktk.set(0, 0, (double) numGenes * numGenes);
        for (int a = 0; a < networks.size(); a++) {
            logger.info("Currently Processing the " + a + "th network");
            Matrix networkA = networks.get(a);
            double total = MatrixUtils.sum(networkA);
            ktk.set(a + 1, 0, total);
            ktk.set(0, a + 1, total);
            // Compute the lower triangle including the diagonal, then mirror it.
            for (int b = 0; b <= a; b++) {
                if (progressReporter.isCanceled()) {
                    throw new CancellationException();
                }
                double product = MatrixUtils.elementMultiplySum(networkA, networks.get(b));
                ktk.set(a + 1, b + 1, product);
                ktk.set(b + 1, a + 1, product);
            }
        }
    }

    /**
     * Fills column 0 of {@code ktt}.
     *
     * <p>Row 0 receives the bias term. Row {@code n+1} receives the term for network {@code n}.
     * The diagonal of the co-annotation matrix held by {@code coAnnotationSet} is zeroed in place
     * before use.
     *
     * @param matrix annotation matrix; rows are genes and columns are categories
     * @param networks networks to score
     * @param ktt output matrix with at least {@code networks.size() + 1} rows
     * @param coAnnotationSet co-annotation matrix, B-half column and constant
     * @param progressReporter polled for cancellation before each network
     * @throws ApplicationException if cancelled ({@link CancellationException})
     */
    public static void computeKtT(Matrix matrix, List<Matrix> networks, DenseMatrix ktt,
            CoAnnotationSet coAnnotationSet, ProgressReporter progressReporter)
            throws ApplicationException {
        int numCategories = matrix.numColumns();
        int numGenes = matrix.numRows();
        Matrix coAnnotation = coAnnotationSet.GetCoAnnotationMatrix();
        Matrix bHalf = coAnnotationSet.GetBHalf();
        double constant = coAnnotationSet.GetConstant().doubleValue();

        // Mutates the caller's co-annotation matrix; callers may observe the zeroed diagonal.
        MatrixUtils.setDiagonalZero(coAnnotation);
        logger.info("CoAnnotation with diagonal values removed: " + MatrixUtils.sum(coAnnotation));
        logger.info("Number of Genes " + numGenes + ", Number of Categories " + numCategories);
        // Use long arithmetic: the int product overflowed for large inputs.
        logger.info("biasValue: " + ((long) numGenes * numGenes * numCategories));

        ktt.set(0, 0, ktTBias(coAnnotation, bHalf, constant, numGenes));
        logger.info("Ktt bias value is " + ktt.get(0, 0));

        for (int n = 0; n < networks.size(); n++) {
            if (progressReporter.isCanceled()) {
                throw new CancellationException();
            }
            logger.info("Currently Processing the " + n + "th network");
            Matrix network = networks.get(n);
            double networkSum = MatrixUtils.sum(network);
            ktt.set(n + 1, 0,
                    ktTForNetwork(network, networkSum, coAnnotation, bHalf, constant, numGenes));
        }
    }

    /**
     * Solves {@code ktk * w = ktt[:, 0]} by delegating to {@code Solver.solve}.
     *
     * @param ktk system matrix
     * @param ktt right-hand side; only column 0 is used
     * @param map passed through unchanged to the solver
     * @param progressReporter passed through to the solver
     * @return the solver's result
     * @throws Exception propagated from the solver
     */
    public static Map<Integer, Double> SolveForWeights(DenseMatrix ktk, DenseMatrix ktt,
            Map<Integer, Integer> map, ProgressReporter progressReporter) throws Exception {
        return Solver.solve(ktk, MatrixUtils.extractColumnToVector(ktt, 0), map, progressReporter);
    }

    /**
     * Lists every index pair {@code (a, b)} with {@code 0 <= b < a < n}.
     *
     * @param n number of indices
     * @return an {@code n(n-1)/2 x 2} matrix. Rows are ordered by {@code a}, then {@code b}.
     */
    protected static Matrix AllPairs(int n) {
        FlexCompColMatrix pairs = new FlexCompColMatrix((n * (n - 1)) / 2, 2);
        int row = 0;
        for (int a = 0; a < n; a++) {
            for (int b = 0; b < a; b++) {
                pairs.set(row, 0, a);
                pairs.set(row, 1, b);
                row++;
            }
        }
        return pairs;
    }

    /**
     * Builds the column vector {@code [0, 1, ..., n-1]}.
     *
     * @param n length of the vector
     * @return an {@code n x 1} matrix whose row {@code r} holds {@code r}
     */
    protected static Matrix ListToN(int n) {
        FlexCompColMatrix indices = new FlexCompColMatrix(n, 1);
        for (int r = 0; r < n; r++) {
            indices.set(r, 0, r);
        }
        return indices;
    }

    /**
     * Builds a {@link CoAnnotationSet} by walking each category's member genes.
     *
     * <p>A column is processed only if its sum exceeds {@code Constants.DISCRIMINANT_THRESHOLD}.
     * For each such column, let {@code k = (int) columnSum} and let {@code members} be the result
     * of {@code MatrixUtils.findAllNoneZero} on that column. Then:
     * <ul>
     *   <li>each pair {@code (members[a], members[b])} with {@code b < a < k} adds 1 to a
     *       lower-half count matrix. That matrix is then mirrored into a symmetric co-annotation
     *       matrix.
     *   <li>each of the first {@code k} members receives {@code -2k / numGenes} in the B-half
     *       column.
     *   <li>the constant accumulates {@code k^2 / numGenes^2}.
     * </ul>
     * These are the same quantities {@link #simpleComputeCoAnnoationSet} derives. The two should
     * agree for a 0/1 annotation matrix when {@code findAllNoneZero} returns indices in ascending
     * order (TODO confirm).
     *
     * <p>Progress dots are printed to standard output every 100th processed column.
     *
     * @param matrix annotation matrix; rows are genes and columns are categories
     * @return the co-annotation matrix, B-half column and constant
     */
    public static CoAnnotationSet FastCoAnnotation(Matrix matrix) {
        int numGenes = matrix.numRows();
        double numGenesSquared = (double) numGenes * numGenes;
        Vector columnSums = MatrixUtils.columnSums(matrix);

        // Accumulate directly. This adds the same values in the same order as the former
        // buffer-then-replay scheme, without the large intermediate buffers.
        Matrix lowerHalf = new FlexCompColMatrix(numGenes, numGenes);
        Matrix bHalf = new FlexCompColMatrix(numGenes, 1);
        double constant = 0.0d;

        for (int category = 0; category < columnSums.size(); category++) {
            double columnSum = columnSums.get(category);
            if (columnSum > Constants.DISCRIMINANT_THRESHOLD) {
                int size = (int) columnSum;
                // Presumably the row indices of the column's non-zero entries; confirm ordering.
                Vector members = MatrixUtils.findAllNoneZero(
                        MatrixUtils.extractColumnToVector(matrix, category));
                accumulateCoAnnotatedPairs(lowerHalf, members, size);
                accumulateBHalf(bHalf, members, size, numGenes);
                // Divide in double. The old int expression truncated every term with size < numGenes
                // to 0, which disagreed with simpleComputeConstant.
                constant += ((double) size * size) / numGenesSquared;
                if (category % 100 == 0) {
                    System.out.print(".");
                }
            }
        }
        System.out.println("");
        logger.info("done length");

        Matrix coAnnotation = new FlexCompColMatrix(numGenes, numGenes);
        for (MatrixEntry entry : lowerHalf) {
            coAnnotation.set(entry.row(), entry.column(), entry.get());
            coAnnotation.set(entry.column(), entry.row(), entry.get());
        }
        logger.info("done lowerhalf the total amount of CoAnnoations are: "
                + MatrixUtils.sum(coAnnotation));
        logger.info("done BHalf, sum is: " + MatrixUtils.sum(bHalf));
        logger.info("constant is : " + constant);
        return new CoAnnotationSet(coAnnotation, bHalf, constant);
    }

    /** Adds 1 to {@code lowerHalf(members[a], members[b])} for every {@code b < a < size}. */
    private static void accumulateCoAnnotatedPairs(Matrix lowerHalf, Vector members, int size) {
        for (int a = 1; a < size; a++) {
            int geneA = (int) members.get(a);
            for (int b = 0; b < a; b++) {
                accumulate(lowerHalf, geneA, (int) members.get(b), 1.0d);
            }
        }
    }

    /** Adds {@code -2 * size / numGenes} to the B-half entry of each of the first {@code size} members. */
    private static void accumulateBHalf(Matrix bHalf, Vector members, int size, int numGenes) {
        double contribution = (-2.0d * size) / numGenes;
        for (int m = 0; m < size; m++) {
            accumulate(bHalf, (int) members.get(m), 0, contribution);
        }
    }

    /**
     * Adds {@code value} to {@code target(row, col)}. If the cell currently equals
     * {@code Constants.DISCRIMINANT_THRESHOLD}, the cell is overwritten with {@code value} instead.
     */
    private static void accumulate(Matrix target, int row, int col, double value) {
        double current = target.get(row, col);
        // NOTE(review): if DISCRIMINANT_THRESHOLD is 0 (presumably), both branches are a plain add.
        target.set(row, col, current == Constants.DISCRIMINANT_THRESHOLD ? value : current + value);
    }

    /**
     * Builds a {@link CoAnnotationSet} from a direct formulation of each term.
     *
     * <p>The constant is the squared norm of the per-category positive ratios. The B-half is
     * {@code -2 * matrix * ratios}. The co-annotation matrix comes from
     * {@link #simpleComputeAHatLessMem}.
     *
     * @param matrix annotation matrix; rows are genes and columns are categories
     * @return the co-annotation matrix, B-half column and constant
     */
    public static CoAnnotationSet simpleComputeCoAnnoationSet(Matrix matrix) {
        int numGenes = matrix.numRows();
        int numCategories = matrix.numColumns();
        Vector posRatios = simpleComputeSumPosRatios(numGenes, matrix);
        double constant = simpleComputeConstant(numGenes, posRatios);
        logger.info("constant: " + constant);
        Matrix yHat = simpleComputeYHat(numGenes, posRatios, matrix);
        logger.info("computed YHat");
        Matrix aHat = simpleComputeAHatLessMem(numGenes, numCategories, matrix);
        logger.info("computed AHat");
        return new CoAnnotationSet(aHat, yHat, constant);
    }

    /**
     * Returns the dot product of {@code posRatios} with itself.
     *
     * @param numGenes unused; kept for signature compatibility
     * @param posRatios per-category ratios
     * @return {@code posRatios . posRatios}
     */
    public static double simpleComputeConstant(int numGenes, Vector posRatios) {
        return posRatios.dot(posRatios);
    }

    /**
     * Computes {@code -2 * matrix * posRatios} as an {@code numGenes x 1} dense matrix.
     *
     * @param numGenes number of rows of {@code matrix}
     * @param posRatios per-category ratios, one per column of {@code matrix}
     * @param matrix annotation matrix
     * @return the scaled product as a column matrix
     */
    public static Matrix simpleComputeYHat(int numGenes, Vector posRatios, Matrix matrix) {
        DenseVector weighted = new DenseVector(numGenes);
        matrix.multAdd(posRatios, weighted);
        weighted.scale(-2.0d);
        DenseMatrix yHat = new DenseMatrix(numGenes, 1);
        for (int g = 0; g < numGenes; g++) {
            yHat.set(g, 0, weighted.get(g));
        }
        return yHat;
    }

    /**
     * Sums, over every category, the outer product of that category's column with itself. Before
     * each outer product is added, its diagonal is set to {@code Constants.DISCRIMINANT_THRESHOLD}.
     *
     * <p>Allocates a fresh {@code numGenes x numGenes} matrix per category; see
     * {@link #simpleComputeAHatLessMem} for a lower-memory equivalent.
     *
     * @param numGenes number of rows of {@code matrix}
     * @param numCategories number of leading columns of {@code matrix} to process
     * @param matrix annotation matrix
     * @return the accumulated {@code numGenes x numGenes} matrix
     */
    public static Matrix simpleComputeAHat(int numGenes, int numCategories, Matrix matrix) {
        FlexCompColMatrix aHat = new FlexCompColMatrix(numGenes, numGenes);
        int[] allRows = indexRange(numGenes);
        for (int category = 0; category < numCategories; category++) {
            logger.debug("processing category " + category);
            Matrix column = Matrices.getSubMatrix(matrix, allRows, new int[] {category});
            FlexCompColMatrix outer = new FlexCompColMatrix(numGenes, numGenes);
            column.transBmultAdd(1.0d, column, outer);
            // NOTE(review): presumably DISCRIMINANT_THRESHOLD is 0, so this clears the diagonal.
            for (int g = 0; g < numGenes; g++) {
                outer.set(g, g, Constants.DISCRIMINANT_THRESHOLD);
            }
            aHat.add(outer);
        }
        return aHat;
    }

    /**
     * Sums, over every category, the off-diagonal products {@code x[r] * x[s]} of that category's
     * column {@code x}. The loops visit only the entries the column iterator yields, and skip any
     * product equal to {@code Constants.DISCRIMINANT_THRESHOLD}.
     *
     * @param numGenes number of rows of {@code matrix}
     * @param numCategories number of leading columns of {@code matrix} to process
     * @param matrix annotation matrix
     * @return the accumulated {@code numGenes x numGenes} matrix
     */
    public static Matrix simpleComputeAHatLessMem(int numGenes, int numCategories, Matrix matrix) {
        FlexCompColMatrix aHat = new FlexCompColMatrix(numGenes, numGenes);
        int[] allRows = indexRange(numGenes);
        for (int category = 0; category < numCategories; category++) {
            logger.debug("processing category " + category);
            Matrix column = MatrixUtils.getSubMatrix(matrix, allRows, new int[] {category});
            for (MatrixEntry outer : column) {
                for (MatrixEntry inner : column) {
                    double product = outer.get() * inner.get();
                    if (product == Constants.DISCRIMINANT_THRESHOLD) {
                        continue;
                    }
                    int row = outer.row();
                    int col = inner.row();
                    if (row != col) {
                        aHat.add(row, col, product);
                    }
                }
            }
        }
        return aHat;
    }

    /** Returns {@code [0, 1, ..., n-1]}. */
    private static int[] indexRange(int n) {
        int[] indices = new int[n];
        for (int k = 0; k < n; k++) {
            indices[k] = k;
        }
        return indices;
    }

    /**
     * Returns the column sums of {@code matrix}, each divided by {@code numGenes}.
     *
     * @param numGenes divisor, normally the row count of {@code matrix}
     * @param matrix annotation matrix
     * @return per-category ratios
     */
    public static Vector simpleComputeSumPosRatios(int numGenes, Matrix matrix) {
        Vector ratios = MatrixUtils.columnSums(matrix);
        ratios.scale(1.0d / numGenes);
        return ratios;
    }

    /**
     * Computes the {@code KtT} bias term for {@code coAnnotationSet}. This is the same value that
     * {@link #computeKtT} stores at (0, 0).
     *
     * @param coAnnotationSet co-annotation data
     * @param numGenes number of genes
     * @return {@code sum(BHalf) * n + sum(coAnnotation) + constant * n^2}
     */
    public static double simpleComputeKtT0(CoAnnotationSet coAnnotationSet, int numGenes) {
        return ktTBias(coAnnotationSet.GetCoAnnotationMatrix(), coAnnotationSet.GetBHalf(),
                coAnnotationSet.GetConstant().doubleValue(), numGenes);
    }

    /**
     * Computes the {@code KtT} term for one network. This is the same value that
     * {@link #computeKtT} stores for each network.
     *
     * @param coAnnotationSet co-annotation data
     * @param numGenes number of genes (row/column count of {@code network})
     * @param network the network to score
     * @param networkSum sum of {@code network}'s entries
     * @return the network's {@code KtT} entry
     */
    public static double simpleComputeKtTi(CoAnnotationSet coAnnotationSet, int numGenes,
            Matrix network, double networkSum) {
        return ktTForNetwork(network, networkSum, coAnnotationSet.GetCoAnnotationMatrix(),
                coAnnotationSet.GetBHalf(), coAnnotationSet.GetConstant().doubleValue(), numGenes);
    }

    /** Returns {@code sum(bHalf) * numGenes + sum(coAnnotation) + constant * numGenes^2}. */
    private static double ktTBias(Matrix coAnnotation, Matrix bHalf, double constant,
            int numGenes) {
        // The leading double operand keeps constant * numGenes * numGenes in floating point.
        return (MatrixUtils.sum(bHalf) * numGenes) + MatrixUtils.sum(coAnnotation)
                + (constant * numGenes * numGenes);
    }

    /**
     * Returns the sum of three terms: {@code elementMultiplySum(network, coAnnotation)}, the sum
     * of {@code network(r,c) * bHalf(r,0)} over all entries, and {@code networkSum * constant}.
     */
    private static double ktTForNetwork(Matrix network, double networkSum, Matrix coAnnotation,
            Matrix bHalf, double constant, int numGenes) {
        // Accumulate per column first and then sum. This keeps the original summation order.
        Matrix columnContributions = new FlexCompColMatrix(1, numGenes);
        for (MatrixEntry entry : network) {
            int col = entry.column();
            columnContributions.set(0, col,
                    columnContributions.get(0, col) + entry.get() * bHalf.get(entry.row(), 0));
        }
        return MatrixUtils.elementMultiplySum(network, coAnnotation)
                + MatrixUtils.sum(columnContributions)
                + (networkSum * constant);
    }
}
