package org.genemania.engine.core.integration.calculators;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MakeKtT;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.CombinedNetwork;
import org.genemania.engine.core.data.Data;
import org.genemania.engine.core.integration.CombineNetworksOnly;
import org.genemania.engine.core.integration.MakeKtK2;
import org.genemania.engine.core.integration.PreCalculatedWeightSelection;
import org.genemania.engine.core.integration.Solver;
import org.genemania.engine.exception.CancellationException;
import org.genemania.engine.matricks.SymMatrix;
import org.genemania.exception.ApplicationException;
import org.genemania.util.ProgressReporter;

/* loaded from: input_file:org/genemania/engine/core/integration/calculators/AutomaticRelevanceCalculator.class */
/**
 * Network weight calculator that builds one weighting per branch ("BP", "CC", "MF") and keeps
 * the branch chosen by {@link PreCalculatedWeightSelection#selectBranch}.
 * <p>
 * Cached per-branch combined networks are used when {@link #process()} deems them applicable;
 * otherwise weights are solved on the fly by {@link #computeNewResult(boolean)}.
 */
public class AutomaticRelevanceCalculator extends AbstractNetworkWeightCalculator {
    // Log under this class's own category; previously the superclass's class was used, which
    // made these messages indistinguishable from other calculators' output.
    private static final Logger logger = Logger.getLogger(AutomaticRelevanceCalculator.class);

    /**
     * Branch keys. A position in this array identifies the same branch in both the cached and
     * the on-the-fly path, and is what selectBranch's returned index refers to.
     */
    private static final String[] BRANCHES = {"BP", "CC", "MF"};

    /**
     * Creates a calculator; every argument is passed unchanged to the superclass constructor.
     */
    public AutomaticRelevanceCalculator(DataCache dataCache, Collection<Collection<Long>> collection, int i, Vector vector, ProgressReporter progressReporter) throws ApplicationException {
        super(dataCache, collection, i, vector, progressReporter);
    }

    /**
     * Creates a calculator with an extra leading string (presumably the namespace read back as
     * {@code this.namespace} - confirm against the superclass); every argument is passed
     * unchanged to the superclass constructor.
     */
    public AutomaticRelevanceCalculator(String str, DataCache dataCache, Collection<Collection<Long>> collection, int i, Vector vector, ProgressReporter progressReporter) throws ApplicationException {
        super(str, dataCache, collection, i, vector, progressReporter);
    }

    /**
     * Computes {@code weights} and {@code combinedMatrix}.
     * <p>
     * The cached result is tried first, but only when either no namespace is set or the query
     * has no user networks, AND the cached "BP" KtK matrix has {@code networkIds.size() + 1}
     * columns. If that is not the case, or loading the cached result fails with an
     * {@link ApplicationException} other than cancellation, weights are computed from scratch.
     *
     * @throws CancellationException if the progress reporter reports cancellation
     * @throws ApplicationException if the on-the-fly computation cannot load its inputs
     */
    @Override // org.genemania.engine.core.integration.INetworkWeightCalculator
    public void process() throws ApplicationException {
        boolean hasUserNetworks = queryHasUserNetworks();
        logger.info("1 automatic relevance");
        DenseMatrix ktK = getKtK("BP", hasUserNetworks);
        logger.info("2 retrieved KtK_BP");
        this.progress.setStatus(Constants.PROGRESS_WEIGHTING_MESSAGE);
        this.progress.setProgress(1);

        // NOTE(review): networkIds is iterated as a collection of collections in
        // indexRequestedNetworks, so size() counts network groups rather than individual
        // networks - confirm this is the intended comparison against the KtK width.
        boolean precomputedApplicable = (this.namespace == null || !hasUserNetworks)
                && ktK.numColumns() == this.networkIds.size() + 1;
        if (precomputedApplicable) {
            try {
                getPrecomputedResultAllNetworks();
                return;
            } catch (CancellationException e) {
                // A user cancel is not a cache miss: stop here instead of loading every matrix
                // computeNewResult needs only to cancel there.
                throw e;
            } catch (ApplicationException e) {
                logger.debug("unable to fetch precomputed result", e);
            }
        }
        computeNewResult(hasUserNetworks);
    }

    /**
     * Solves for network weights on each branch and keeps the branch chosen by
     * {@link PreCalculatedWeightSelection#selectBranch}, falling back to averaged weights if
     * selection fails.
     * <p>
     * Side effects: fills {@code IndexToNetworkIdMap} with consecutive indices starting at 0
     * for every network in {@code networkIds} (in iteration order), and sets
     * {@code combinedMatrix} and {@code weights}.
     *
     * @param hasUserNetworks forwarded to getKtT, getKtK and getColumnId
     * @throws CancellationException if the progress reporter reports cancellation
     * @throws ApplicationException if loading matrices or column ids fails
     */
    void computeNewResult(boolean hasUserNetworks) throws ApplicationException {
        int branchCount = BRANCHES.length;
        // Plain Matrix[] arrays: the element types come only from method return types, and a
        // DenseMatrix[] array would throw ArrayStoreException for any other Matrix implementation.
        Matrix[] ktT = new Matrix[branchCount];
        Matrix[] ktK = new Matrix[branchCount];
        for (int b = 0; b < branchCount; b++) {
            ktT[b] = getKtT(BRANCHES[b], hasUserNetworks);
        }
        for (int b = 0; b < branchCount; b++) {
            ktK[b] = getKtK(BRANCHES[b], hasUserNetworks);
        }
        ArrayList<Integer> columns = indexRequestedNetworks(hasUserNetworks);

        CombinedNetwork[] branchNetworks = new CombinedNetwork[branchCount];
        for (int b = 0; b < branchCount; b++) {
            if (this.progress.isCanceled()) {
                throw new CancellationException();
            }
            branchNetworks[b] = solveBranch(b, ktK[b], ktT[b], columns);
        }

        try {
            int selected = PreCalculatedWeightSelection.selectBranch(branchNetworks, this.label);
            this.combinedMatrix = branchNetworks[selected].getData();
            this.weights = branchNetworks[selected].getWeightMap();
        } catch (ApplicationException e) {
            fallBackToAverageWeights(e);
        }
    }

    /**
     * Looks up each requested network's column via getColumnId and records its position in
     * {@code IndexToNetworkIdMap}.
     *
     * @return the looked-up column ids, in network iteration order
     */
    private ArrayList<Integer> indexRequestedNetworks(boolean hasUserNetworks) throws ApplicationException {
        Map<Long, Integer> columnId = getColumnId(hasUserNetworks);
        ArrayList<Integer> columns = new ArrayList<Integer>();
        int index = 0;
        for (Collection<Long> group : this.networkIds) {
            for (long networkId : group) {
                // NOTE(review): get() yields null for a network missing from columnId; that
                // null is passed on to RemoveNetwork unchanged - confirm it cannot happen.
                columns.add(columnId.get(networkId));
                this.IndexToNetworkIdMap.put(index, networkId);
                index++;
            }
        }
        return columns;
    }

    /**
     * Builds the on-the-fly combined network for one branch. It applies the RemoveNetwork
     * helpers with {@code columns}, solves for weights (averaging them if the solver fails)
     * and combines the networks with those weights.
     */
    private CombinedNetwork solveBranch(int branchIndex, Matrix ktK, Matrix ktT, ArrayList<Integer> columns) throws ApplicationException {
        CombinedNetwork network = new CombinedNetwork(Data.CORE, this.organismId, "on_the_fly:" + branchIndex);
        Matrix reducedKtK = MakeKtK2.RemoveNetwork(ktK, columns);
        Matrix reducedKtT = MakeKtT.RemoveNetwork(ktT, columns);

        Map<Long, Double> branchWeights;
        try {
            branchWeights = Solver.solve(reducedKtK, MatrixUtils.extractColumnToVector(reducedKtT, 0), this.IndexToNetworkIdMap, this.progress);
        } catch (CancellationException e) {
            // Cancellation is not a solver failure; don't mask it with the average fallback.
            throw e;
        } catch (ApplicationException e) {
            logger.error("weighting calculation failed, falling back to average: " + e.getMessage(), e);
            branchWeights = AverageByNetworkCalculator.average(this.IndexToNetworkIdMap);
        }

        SymMatrix combined = CombineNetworksOnly.combine(branchWeights, this.namespace, this.organismId, this.cache, this.progress);
        double wtw = combined.elementMultiplySum(combined);
        network.setData(combined);
        network.setWtW(wtw);
        network.setWeightMap(branchWeights);
        return network;
    }

    /**
     * Fallback used when branch selection fails. It sets {@code weights} to averaged weights
     * over {@code IndexToNetworkIdMap} and {@code combinedMatrix} to the networks combined
     * with them.
     */
    private void fallBackToAverageWeights(ApplicationException cause) throws ApplicationException {
        logger.info("failed to select branch, falling back to average: " + cause.getMessage());
        this.weights = AverageByNetworkCalculator.average(this.IndexToNetworkIdMap);
        this.progress.setStatus(Constants.PROGRESS_COMBINING_MESSAGE);
        this.progress.setProgress(2);
        this.combinedMatrix = CombineNetworksOnly.combine(this.weights, this.namespace, this.organismId, this.cache, this.progress);
    }

    /**
     * Loads the cached "BP", "CC" and "MF" combined networks and adopts the data and weight
     * map of the one chosen by {@link PreCalculatedWeightSelection#selectBranch}. If
     * selection fails, it falls back to averaged weights.
     * <p>
     * NOTE(review): unlike computeNewResult, this path does not populate
     * {@code IndexToNetworkIdMap} itself, so the averaging fallback relies on whatever state
     * the superclass set up - confirm.
     *
     * @throws CancellationException if the progress reporter reports cancellation
     * @throws ApplicationException if a cached combined network cannot be loaded
     */
    void getPrecomputedResultAllNetworks() throws ApplicationException {
        CombinedNetwork[] branchNetworks = new CombinedNetwork[BRANCHES.length];
        for (int b = 0; b < BRANCHES.length; b++) {
            branchNetworks[b] = this.cache.getCombinedNetwork(Data.CORE, this.organismId, BRANCHES[b]);
        }
        if (this.progress.isCanceled()) {
            throw new CancellationException();
        }
        logger.info("4 retrived DotProduct_BP/CC/MF, calculated wtw");
        try {
            int selected = PreCalculatedWeightSelection.selectBranch(branchNetworks, this.label);
            logger.info("5 selected branch " + selected);
            this.combinedMatrix = branchNetworks[selected].getData();
            this.weights = branchNetworks[selected].getWeightMap();
            logger.info("6 retrieved weight map");
        } catch (ApplicationException e) {
            fallBackToAverageWeights(e);
        }
    }
}
