package org.genemania.engine.core.integration.calculators;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.Vector;
import org.apache.log4j.Logger;
import org.genemania.engine.Constants;
import org.genemania.engine.cache.DataCache;
import org.genemania.engine.core.MakeKtT;
import org.genemania.engine.core.MatrixUtils;
import org.genemania.engine.core.data.CombinedNetwork;
import org.genemania.engine.core.data.Data;
import org.genemania.engine.core.integration.CombineNetworksOnly;
import org.genemania.engine.core.integration.MakeKtK2;
import org.genemania.engine.core.integration.Solver;
import org.genemania.engine.exception.WeightingFailedException;
import org.genemania.exception.ApplicationException;
import org.genemania.util.ProgressReporter;

/* loaded from: input_file:org/genemania/engine/core/integration/calculators/BranchSpecificCalculator.class */
/**
 * Network weight calculator driven by a fixed {@link Constants.CombiningMethod}.
 *
 * <p>When the eligibility check in {@link #process()} passes, the precomputed combined network
 * stored in the {@link DataCache} for this method is reused. Otherwise weights are solved from the
 * cached KtK/KtT matrices for the requested networks. If the solver fails, it falls back to
 * per-network averaging. In both paths the weights are then used to build the combined matrix.
 */
public class BranchSpecificCalculator extends AbstractNetworkWeightCalculator {
    private static final Logger logger = Logger.getLogger(BranchSpecificCalculator.class);

    /** Combining method whose cached matrices / precomputed results this calculator uses. */
    Constants.CombiningMethod method;

    /** Format of the key returned by {@link #getParameterKey()}: method, then network list. */
    public static final String PARAM_KEY_FORMAT = "%s-%s";

    /**
     * Creates a calculator using the superclass constructor that takes no leading string argument.
     *
     * @param dataCache cache passed to the superclass
     * @param collection groups of network ids to weight (a collection of id collections)
     * @param i passed through to the superclass unchanged (meaning defined there)
     * @param vector passed through to the superclass unchanged (presumably the query label
     *     vector — confirm)
     * @param combiningMethod combining method used for cache lookups and the parameter key
     * @param progressReporter passed to the superclass for progress updates
     * @throws ApplicationException if the superclass constructor fails
     */
    public BranchSpecificCalculator(DataCache dataCache, Collection<Collection<Long>> collection, int i, Vector vector, Constants.CombiningMethod combiningMethod, ProgressReporter progressReporter) throws ApplicationException {
        super(dataCache, collection, i, vector, progressReporter);
        this.method = combiningMethod;
    }

    /**
     * Creates a calculator using the superclass constructor that takes a leading string argument.
     *
     * @param str passed through to the superclass (presumably the user-network namespace — confirm)
     * @param dataCache cache passed to the superclass
     * @param collection groups of network ids to weight (a collection of id collections)
     * @param i passed through to the superclass unchanged (meaning defined there)
     * @param vector passed through to the superclass unchanged (presumably the query label
     *     vector — confirm)
     * @param combiningMethod combining method used for cache lookups and the parameter key
     * @param progressReporter passed to the superclass for progress updates
     * @throws ApplicationException if the superclass constructor fails
     */
    public BranchSpecificCalculator(String str, DataCache dataCache, Collection<Collection<Long>> collection, int i, Vector vector, Constants.CombiningMethod combiningMethod, ProgressReporter progressReporter) throws ApplicationException {
        super(str, dataCache, collection, i, vector, progressReporter);
        this.method = combiningMethod;
    }

    /**
     * Computes the network weights and the combined matrix.
     *
     * <p>It tries the precomputed combined network first when eligible. If that lookup throws an
     * {@link ApplicationException}, it logs at debug level and computes a fresh result instead.
     *
     * @throws ApplicationException if fetching cached matrices or computing a new result fails
     */
    @Override
    public void process() throws ApplicationException {
        this.progress.setStatus(Constants.PROGRESS_WEIGHTING_MESSAGE);
        this.progress.setProgress(1);

        boolean hasUserNetworks = queryHasUserNetworks();
        DenseMatrix ktK = getKtK(this.method.toString(), hasUserNetworks);

        if (canUsePrecomputedResult(hasUserNetworks, ktK)) {
            try {
                getPrecomputedResultAllNetworks();
                return;
            } catch (ApplicationException e) {
                // Best effort only: fall through and compute the result from scratch.
                logger.debug("unable to fetch precomputed result", e);
            }
        }
        computeNewResult(hasUserNetworks, ktK);
    }

    /**
     * Returns whether the precomputed combined network may be used instead of solving.
     *
     * @param hasUserNetworks whether the query includes user-supplied networks
     * @param ktK the KtK matrix fetched for this method
     * @return true when no namespace is set or there are no user networks, and {@code ktK} has
     *     exactly {@code networkIds.size() + 1} columns
     */
    private boolean canUsePrecomputedResult(boolean hasUserNetworks, DenseMatrix ktK) {
        // NOTE(review): networkIds is a collection of network groups, so size() counts groups rather
        // than individual networks; confirm this matches the column layout of ktK (the "+ 1"
        // presumably accounts for a bias column).
        return (this.namespace == null || !hasUserNetworks)
                && ktK.numColumns() == this.networkIds.size() + 1;
    }

    /**
     * Solves for the weights of the requested networks and combines them.
     *
     * <p>Every network id in every group is assigned a sequential solver index, recorded in
     * {@code IndexToNetworkIdMap}. Its column index in the cached matrices is collected and passed
     * to the {@code RemoveNetwork} helpers. On {@link WeightingFailedException} the weights fall
     * back to {@link AverageByNetworkCalculator#average}.
     *
     * @param hasUserNetworks whether the query includes user-supplied networks
     * @param ktK the KtK matrix fetched for this method
     * @throws ApplicationException if fetching cached data or combining the networks fails
     */
    void computeNewResult(boolean hasUserNetworks, DenseMatrix ktK) throws ApplicationException {
        DenseMatrix ktT = getKtT(this.method.toString(), hasUserNetworks);
        Map<Long, Integer> columnId = getColumnId(hasUserNetworks);

        // Declared as ArrayList (not List) because the RemoveNetwork parameter types are not
        // visible here.
        ArrayList<Integer> columnIndices = new ArrayList<Integer>();
        int solverIndex = 0;
        for (Collection<Long> group : this.networkIds) {
            // Unboxing to long keeps the original behavior (NPE on a null id).
            for (long networkId : group) {
                // NOTE(review): get() returns null for an id missing from columnId; confirm the
                // RemoveNetwork helpers cannot receive such ids.
                columnIndices.add(columnId.get(Long.valueOf(networkId)));
                this.IndexToNetworkIdMap.put(Integer.valueOf(solverIndex), Long.valueOf(networkId));
                solverIndex++;
            }
        }

        try {
            this.weights = Solver.solve(
                    MakeKtK2.RemoveNetwork(ktK, columnIndices),
                    MatrixUtils.extractColumnToVector(MakeKtT.RemoveNetwork(ktT, columnIndices), 0),
                    this.IndexToNetworkIdMap,
                    this.progress);
        } catch (WeightingFailedException e) {
            logger.error("weighting calculation failed, falling back to average: " + e.getMessage());
            this.weights = AverageByNetworkCalculator.average(this.IndexToNetworkIdMap);
        }

        this.progress.setStatus(Constants.PROGRESS_COMBINING_MESSAGE);
        this.progress.setProgress(2);
        this.combinedMatrix = CombineNetworksOnly.combine(this.weights, this.namespace, this.organismId, this.cache, this.progress);
    }

    /**
     * Loads the precomputed combined network for this method from the {@link Data#CORE}
     * namespace of the cache. It then copies that network's weight map and matrix into this
     * calculator.
     *
     * @throws ApplicationException if the cache lookup fails
     */
    void getPrecomputedResultAllNetworks() throws ApplicationException {
        CombinedNetwork combinedNetwork = this.cache.getCombinedNetwork(Data.CORE, this.organismId, this.method.toString());
        this.progress.setStatus(Constants.PROGRESS_COMBINING_MESSAGE);
        this.progress.setProgress(2);
        this.weights = combinedNetwork.getWeightMap();
        this.combinedMatrix = combinedNetwork.getData();
    }

    /**
     * Builds a key identifying this calculation from the combining method and the network list.
     *
     * @return {@link #PARAM_KEY_FORMAT} applied to the method name and the formatted network list
     * @throws ApplicationException if formatting the network list fails
     */
    @Override
    public String getParameterKey() throws ApplicationException {
        return String.format(PARAM_KEY_FORMAT, this.method.toString(), formattedNetworkList(this.networkIds));
    }
}
