package org.genemania.engine.core;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.MatrixSingularException;
import no.uib.cipr.matrix.QR;
import no.uib.cipr.matrix.Vector;
import no.uib.cipr.matrix.sparse.FlexCompColMatrix;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;

/* loaded from: input_file:org/genemania/engine/core/CombineNetworks.class */
/**
 * Combines several networks, each represented as an MTJ {@link Matrix}, into a single weighted
 * network.
 *
 * <p>Two weighting strategies are supported (see {@link Constants.CombiningMethod}):
 * <ul>
 *   <li>{@code AVERAGE}: every network receives the same weight, {@code 1 / list.size()}.</li>
 *   <li>{@code AUTOMATIC}: weights are fitted by solving a small linear system built from the
 *       entries of a label vector equal to {@code 1.0} (positives) and {@code -1.0} (negatives).
 *       Networks whose fitted weight does not clear a threshold are dropped and the system is
 *       re-solved. If fitting is impossible, averaging is used instead.</li>
 * </ul>
 */
public class CombineNetworks {
    private static final Logger logger = Logger.getLogger(CombineNetworks.class);

    /** Machine epsilon for doubles (2^-52); relative tolerance used to prune near-zero rows of the system. */
    public static double EPSILON = Math.pow(2.0d, -52.0d);

    /** Margin added to {@code Constants.DISCRIMINANT_THRESHOLD} when deciding which fitted weights to keep. */
    public static double DELTA = 1.0E-6d;

    /**
     * Computes network weights with the requested method and returns the weighted sum of the networks.
     *
     * @param list the networks to combine (assumed square and of equal dimension — see {@link #combine(Map)})
     * @param vector label vector; only read by {@code AUTOMATIC}
     * @param combiningMethod {@code AVERAGE} or {@code AUTOMATIC}
     * @return the combined network, or {@code null} if the computed weight map is empty
     *     (for example when {@code list} is empty)
     * @throws IllegalArgumentException if {@code combiningMethod} is not supported
     * @throws Exception propagated from weight fitting
     */
    public static Matrix combine(List<Matrix> list, Vector vector, Constants.CombiningMethod combiningMethod) throws Exception {
        // Single dispatch point: computeWeights already selects the strategy and rejects unknown ones.
        return combine(computeWeights(list, vector, combiningMethod));
    }

    /**
     * Computes the weight assigned to each network by the requested method.
     *
     * @param list the candidate networks
     * @param vector label vector; only read by {@code AUTOMATIC}
     * @param combiningMethod {@code AVERAGE} or {@code AUTOMATIC}
     * @return a map from network to weight; networks pruned by {@code AUTOMATIC} are absent
     * @throws IllegalArgumentException if {@code combiningMethod} is not supported
     * @throws Exception propagated from weight fitting
     */
    public static Map<Matrix, Double> computeWeights(List<Matrix> list, Vector vector, Constants.CombiningMethod combiningMethod) throws Exception {
        if (combiningMethod == Constants.CombiningMethod.AVERAGE) {
            return average(list);
        }
        if (combiningMethod == Constants.CombiningMethod.AUTOMATIC) {
            return automatic(list, vector);
        }
        // IllegalArgumentException is still an Exception, so existing callers keep working.
        throw new IllegalArgumentException("Unsupported network combination method: " + combiningMethod);
    }

    /**
     * Assigns every network the same weight, {@code 1 / list.size()}.
     *
     * @param list the networks to weight
     * @return a map from each network to the uniform weight (empty if {@code list} is empty)
     */
    private static Map<Matrix, Double> average(List<Matrix> list) {
        Double weight = 1.0d / list.size();
        Map<Matrix, Double> weights = new HashMap<Matrix, Double>();
        for (Matrix network : list) {
            weights.put(network, weight);
        }
        return weights;
    }

    /**
     * Fits per-network weights from the labels in {@code vector}.
     *
     * <p>The method builds a {@code (list.size() + 1) x (list.size() + 1)} linear system. Row/column
     * 0 is an extra unknown that does not correspond to any network. Row/column {@code k + 1}
     * corresponds to {@code list.get(k)}. The entries come from each network's positive/positive and
     * positive/negative sub-blocks.
     *
     * <p>Fitting proceeds in three steps:
     * <ol>
     *   <li>Rows whose absolute sum is negligible are pruned.</li>
     *   <li>The system is solved via QR.</li>
     *   <li>Networks whose weight fails the {@code MatrixUtils.findGE(…, DISCRIMINANT_THRESHOLD + DELTA)}
     *       test are removed, and the system is re-solved until every remaining network passes.</li>
     * </ol>
     *
     * <p>Falls back to uniform averaging, with a warning, in three cases: when there are no labelled
     * pairs to fit against, when every network is eliminated, or when the system is singular.
     *
     * @param list the candidate networks
     * @param vector label vector; entries equal to {@code 1.0} are treated as positives and
     *     {@code -1.0} as negatives
     * @return a map from each retained network to its fitted weight
     * @throws Exception propagated from the matrix helpers
     */
    protected static Map<Matrix, Double> automatic(List<Matrix> list, Vector vector) throws Exception {
        logger.debug("building system to solve for weights");
        int[] posIndices = MatrixUtils.find(vector, 1.0d);
        int[] negIndices = MatrixUtils.find(vector, -1.0d);
        int numPos = posIndices.length;
        int numNeg = negIndices.length;
        int numNetworks = list.size();

        double posConst = (2.0d * numNeg) / (numPos + numNeg);
        double negConst = ((-2.0d) * numPos) / (numPos + numNeg);
        // Pair counts in long arithmetic: the int products overflow silently for large label sets.
        // For non-overflowing sizes the values (and therefore all results) are identical.
        long numPosPosPairs = (long) numPos * (numPos - 1);
        long numPosNegPairs = 2L * numPos * numNeg;
        double posPosTarget = posConst * posConst;
        double posNegTarget = posConst * negConst;
        logger.debug(String.format("numPos: %d, numNeg: %d", numPos, numNeg));
        logger.debug(String.format("posPosTarget: %f, posNegTarget: %f", posPosTarget, posNegTarget));

        // No positive/positive or positive/negative pairs (numPos == 0, or a single positive with no
        // negatives): the scale below would be 1/0 and poison the system with Infinity/NaN.
        if (numPosPosPairs + numPosNegPairs == 0) {
            logger.warn("no labelled pairs to fit weights against, using averaging");
            return average(list);
        }

        DenseMatrix system = new DenseMatrix(numNetworks + 1, numNetworks + 1);
        DenseVector target = new DenseVector(numNetworks + 1);
        double scale = 1.0d / (numPosPosPairs + numPosNegPairs);
        target.set(0, scale * ((posPosTarget * numPosPosPairs) + (posNegTarget * numPosNegPairs)));
        system.set(0, 0, scale);

        Matrix[] posPosBlocks = new Matrix[numNetworks];
        Matrix[] posNegBlocks = new Matrix[numNetworks];
        for (int n = 0; n < numNetworks; n++) {
            Matrix network = list.get(n);
            posPosBlocks[n] = new FlexCompColMatrix(Matrices.getSubMatrix(network, posIndices, posIndices));
            // Self-interactions are excluded, consistent with counting numPos * (numPos - 1) pairs.
            MatrixUtils.setDiagonalZero(posPosBlocks[n]);
            posNegBlocks[n] = new FlexCompColMatrix(Matrices.getSubMatrix(network, posIndices, negIndices));
            double posPosSum = MatrixUtils.sum(posPosBlocks[n]);
            double posNegSum = MatrixUtils.sum(posNegBlocks[n]);
            target.set(n + 1, (posPosTarget * posPosSum) + (2.0d * posNegTarget * posNegSum));
            system.set(n + 1, 0, scale * (posPosSum + (2.0d * posNegSum)));
            system.set(0, n + 1, system.get(n + 1, 0));
            // Fill the symmetric network-by-network part; only the lower triangle is computed.
            for (int m = 0; m <= n; m++) {
                // NOTE(review): adding Constants.DISCRIMINANT_THRESHOLD to an inner-product entry is
                // unusual; it may be a decompiler substitution for a literal of equal value — confirm.
                double entry = Constants.DISCRIMINANT_THRESHOLD
                        + MatrixUtils.elementMultiplySum(posPosBlocks[n], posPosBlocks[m])
                        + (2.0d * MatrixUtils.elementMultiplySum(posNegBlocks[n], posNegBlocks[m]));
                system.set(n + 1, m + 1, entry);
                system.set(m + 1, n + 1, entry);
            }
        }

        // Keep only rows whose absolute sum is non-negligible relative to the largest row sum.
        Vector absRowSums = MatrixUtils.absRowSums(system);
        int[] kept = MatrixUtils.findGT(absRowSums, absRowSums.norm(Vector.Norm.Infinity) * EPSILON);
        logger.debug("full orig target vector: \n" + target);
        Matrix reducedSystem = Matrices.getSubMatrix(system, kept, kept).copy();
        Vector reducedTarget = Matrices.getSubVector(target, kept).copy();
        Vector alpha = null;
        while (true) {
            logger.debug("solving for weights");
            alpha = new DenseVector(reducedTarget.size());
            DenseVector qtTarget = new DenseVector(reducedTarget.size());
            try {
                // Solve R * alpha = Q^T * target using a QR factorization of the system.
                QR qr = QR.factorize(reducedSystem);
                qr.getQ().transMult(reducedTarget, qtTarget);
                qr.getR().solve(qtTarget, alpha);
                logger.debug("alpha: \n" + alpha);
                // Residual (system * alpha - target), logged for diagnostics only.
                DenseVector residual = new DenseVector(alpha.size());
                reducedSystem.mult(alpha, residual);
                residual.add(-1.0d, reducedTarget);
                logger.debug("check: \n" + residual);
                logger.debug("KtT after previous solve: \n" + reducedTarget);
                int numSolved = alpha.size();
                // filter(..., 0) appears to drop position 0, which arrayJoin re-adds below.
                int[] keep = MatrixUtils.filter(MatrixUtils.findGE(alpha, Constants.DISCRIMINANT_THRESHOLD + DELTA), 0);
                if (keep.length == 0) {
                    logger.warn("all networks eliminated, using averaging");
                    return average(list);
                }
                if (keep.length == numSolved - 1) {
                    // Every remaining network passed the threshold: done.
                    break;
                }
                int[] next = MatrixUtils.arrayJoin(new int[]{0}, keep);
                reducedSystem = Matrices.getSubMatrix(reducedSystem, next, next).copy();
                reducedTarget = Matrices.getSubVector(reducedTarget, next).copy();
                // Track which original system rows survive, so weights can be mapped back to networks.
                kept = MatrixUtils.subArray(kept, next);
            } catch (MatrixSingularException e) {
                logger.warn("automatic weighting failed (matrix singular), using average weighting");
                return average(list);
            }
        }

        Map<Matrix, Double> weights = new HashMap<Matrix, Double>();
        for (int k = 0; k < kept.length; k++) {
            // Original row 0 is the extra unknown, not a network; row r > 0 is list.get(r - 1).
            if (kept[k] != 0) {
                weights.put(list.get(kept[k] - 1), alpha.get(k));
            }
        }
        return weights;
    }

    /**
     * Returns the weighted sum of the given networks.
     *
     * <p>The result is a {@code numRows x numRows} matrix sized from the first network visited. All
     * networks are therefore assumed to be square and of equal dimension.
     *
     * @param map network-to-weight map
     * @return the weighted sum, or {@code null} if {@code map} is empty
     * @throws Exception declared for backward compatibility
     */
    public static Matrix combine(Map<Matrix, Double> map) throws Exception {
        FlexCompColMatrix combined = null;
        for (Map.Entry<Matrix, Double> entry : map.entrySet()) {
            Matrix network = entry.getKey();
            double weight = entry.getValue().doubleValue();
            if (combined == null) {
                int dim = network.numRows();
                combined = new FlexCompColMatrix(dim, dim);
            }
            combined.add(weight, network);
        }
        return combined;
    }
}
