package net.maizegenetics.analysis.imputation;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.MinMaxPriorityQueue;
import com.google.common.collect.Multimap;
import com.google.common.collect.Range;
import java.awt.Frame;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.swing.ImageIcon;
import net.maizegenetics.analysis.popgen.LDResult;
import net.maizegenetics.analysis.popgen.LinkageDisequilibrium;
import net.maizegenetics.dna.map.Position;
import net.maizegenetics.dna.map.PositionList;
import net.maizegenetics.dna.map.PositionListBuilder;
import net.maizegenetics.dna.snp.FilterGenotypeTable;
import net.maizegenetics.dna.snp.GenotypeTable;
import net.maizegenetics.dna.snp.GenotypeTableBuilder;
import net.maizegenetics.dna.snp.GenotypeTableUtils;
import net.maizegenetics.plugindef.AbstractPlugin;
import net.maizegenetics.plugindef.DataSet;
import net.maizegenetics.plugindef.Datum;
import net.maizegenetics.plugindef.GeneratePluginCode;
import net.maizegenetics.plugindef.PluginParameter;
import net.maizegenetics.taxa.Taxon;
import net.maizegenetics.util.Tuple;
import org.apache.log4j.Logger;

/* loaded from: input_file:net/maizegenetics/analysis/imputation/LDKNNiImputationPlugin.class */
public class LDKNNiImputationPlugin extends AbstractPlugin {
    private PluginParameter<Integer> highLDSSites;
    private PluginParameter<Integer> knnTaxa;
    private PluginParameter<Integer> maxDistance;
    private static final Logger myLogger = Logger.getLogger(LDKNNiImputationPlugin.class);

    public LDKNNiImputationPlugin() {
        super(null, false);
        this.highLDSSites = new PluginParameter.Builder("highLDSSites", 30, Integer.class).range(Range.closed(2, 2000)).guiName("High LD Sites").description("Number of sites in high LD to use in imputation").build();
        this.knnTaxa = new PluginParameter.Builder("knnTaxa", 10, Integer.class).range(Range.closed(2, 200)).guiName("Number of nearest neighbors").description("Number of neighbors to use in imputation").build();
        this.maxDistance = new PluginParameter.Builder("maxLDDistance", 10000000, Integer.class).guiName("Max distance between site to find LD").description("Maximum physical distance between sites to search for LD (-1 for no distance cutoff - unlinked chromosomes will be tested)").build();
    }

    public LDKNNiImputationPlugin(Frame frame, boolean z) {
        super(frame, z);
        this.highLDSSites = new PluginParameter.Builder("highLDSSites", 30, Integer.class).range(Range.closed(2, 2000)).guiName("High LD Sites").description("Number of sites in high LD to use in imputation").build();
        this.knnTaxa = new PluginParameter.Builder("knnTaxa", 10, Integer.class).range(Range.closed(2, 200)).guiName("Number of nearest neighbors").description("Number of neighbors to use in imputation").build();
        this.maxDistance = new PluginParameter.Builder("maxLDDistance", 10000000, Integer.class).guiName("Max distance between site to find LD").description("Maximum physical distance between sites to search for LD (-1 for no distance cutoff - unlinked chromosomes will be tested)").build();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    @Override // net.maizegenetics.plugindef.AbstractPlugin
    public void preProcessParameters(DataSet dataSet) {
        if (dataSet.getDataOfType(GenotypeTable.class).size() != 1) {
            throw new IllegalArgumentException("LDKNNiImputationPlugin: preProcessParameters: Please select one Genotype Table.");
        }
    }

    @Override // net.maizegenetics.plugindef.AbstractPlugin, net.maizegenetics.plugindef.Plugin
    public DataSet processData(DataSet dataSet) {
        Datum datum = dataSet.getDataOfType(GenotypeTable.class).get(0);
        GenotypeTable genotypeTable = (GenotypeTable) datum.getData();
        Multimap<Position, Position> highLDMap = getHighLDMap(genotypeTable, highLDSSites().intValue());
        System.out.println("LD calculated");
        GenotypeTableBuilder siteIncremental = GenotypeTableBuilder.getSiteIncremental(genotypeTable.taxa());
        long nanoTime = System.nanoTime();
        LongAdder longAdder = new LongAdder();
        IntStream.range(0, genotypeTable.numberOfSites()).parallel().forEach(i -> {
            Position position = genotypeTable.positions().get(i);
            PositionList positionListBuilder = PositionListBuilder.getInstance(new ArrayList(highLDMap.get(position)));
            byte[] genotypeAllTaxa = genotypeTable.genotypeAllTaxa(i);
            byte[] bArr = new byte[genotypeAllTaxa.length];
            if (genotypeTable.isPolymorphic(i)) {
                GenotypeTable genotypeCopyInstance = GenotypeTableBuilder.getGenotypeCopyInstance(FilterGenotypeTable.getInstance(genotypeTable, positionListBuilder));
                int numberOfSites = genotypeCopyInstance.numberOfSites();
                double[] array = IntStream.range(0, genotypeCopyInstance.numberOfTaxa()).sequential().mapToDouble(i -> {
                    return genotypeCopyInstance.totalNonMissingForTaxon(i) / numberOfSites;
                }).toArray();
                for (int i2 = 0; i2 < genotypeAllTaxa.length; i2++) {
                    bArr[i2] = genotypeAllTaxa[i2];
                    if (genotypeAllTaxa[i2] == -1) {
                        Multimap<Double, Byte> closestNonMissingTaxa = getClosestNonMissingTaxa(genotypeTable.taxa().get(i2), genotypeTable, genotypeCopyInstance, position, array, knnTaxa().intValue());
                        if (closestNonMissingTaxa.isEmpty()) {
                            bArr[i2] = -1;
                        } else {
                            bArr[i2] = impute(closestNonMissingTaxa, highLDSSites().intValue());
                        }
                    }
                }
            } else {
                byte diploidValue = GenotypeTableUtils.getDiploidValue(genotypeTable.majorAllele(i), genotypeTable.majorAllele(i));
                for (int i3 = 0; i3 < bArr.length; i3++) {
                    bArr[i3] = genotypeAllTaxa[i3] == -1 ? diploidValue : genotypeAllTaxa[i3];
                }
            }
            siteIncremental.addSite(position, bArr);
            if ((i + 1) % 100 == 0) {
                longAdder.add(100L);
                fireProgress(Integer.valueOf(33 + (((int) (66 * longAdder.longValue())) / genotypeTable.numberOfSites())));
                System.out.println(longAdder.longValue() + Taxon.DELIMITER + (((System.nanoTime() - nanoTime) / 1000000) / longAdder.longValue()));
            }
        });
        return new DataSet(new Datum(datum.getName() + "_KNNimp", siteIncremental.build(), "Imputed genotypes by KNN imputation"), this);
    }

    private Multimap<Double, Byte> getClosestNonMissingTaxa(Taxon taxon, GenotypeTable genotypeTable, GenotypeTable genotypeTable2, Position position, double[] dArr, int i) {
        int indexOf = genotypeTable.positions().indexOf(position);
        int indexOf2 = genotypeTable.taxa().indexOf(taxon);
        byte[] genotypeAllSites = genotypeTable2.genotypeAllSites(indexOf2);
        MinMaxPriorityQueue minMaxPriorityQueue = (MinMaxPriorityQueue) IntStream.range(0, genotypeTable.numberOfTaxa()).filter(i2 -> {
            return i2 != indexOf2;
        }).filter(i3 -> {
            return (dArr[i3] * dArr[indexOf2]) * ((double) genotypeTable2.numberOfSites()) > 10.0d;
        }).filter(i4 -> {
            return genotypeTable.genotype(i4, indexOf) != -1;
        }).mapToObj(i5 -> {
            return new Tuple(Double.valueOf(dist(genotypeAllSites, genotypeTable2.genotypeAllSites(i5), 10)[0]), Byte.valueOf(genotypeTable.genotype(i5, indexOf)));
        }).filter(tuple -> {
            return !Double.isNaN(((Double) tuple.x).doubleValue());
        }).collect(Collectors.toCollection(() -> {
            return MinMaxPriorityQueue.maximumSize(i).create();
        }));
        ArrayListMultimap create = ArrayListMultimap.create();
        minMaxPriorityQueue.stream().forEach(tuple2 -> {
            create.put(tuple2.x, tuple2.y);
        });
        return create;
    }

    private Multimap<Position, Position> getHighLDMap(GenotypeTable genotypeTable, int i) {
        ArrayListMultimap create = ArrayListMultimap.create();
        int numberOfSites = genotypeTable.numberOfSites();
        LongAdder longAdder = new LongAdder();
        IntStream.range(0, genotypeTable.numberOfSites()).parallel().forEach(i2 -> {
            MinMaxPriorityQueue create2 = MinMaxPriorityQueue.orderedBy(LDResult.byR2Ordering.reverse()).maximumSize(i).create();
            for (int i2 = 0; i2 < numberOfSites; i2++) {
                if (i2 != i2 && (maxDistance().intValue() <= -1 || Math.abs(genotypeTable.chromosomalPosition(i2) - genotypeTable.chromosomalPosition(i2)) <= maxDistance().intValue())) {
                    LDResult calculateBitLDForHaplotype = LinkageDisequilibrium.calculateBitLDForHaplotype(false, 20, genotypeTable, i2, i2);
                    if (!Double.isNaN(calculateBitLDForHaplotype.r2())) {
                        create2.add(calculateBitLDForHaplotype);
                    }
                }
            }
            ArrayList arrayList = new ArrayList();
            Iterator it = create2.iterator();
            while (it.hasNext()) {
                arrayList.add(genotypeTable.positions().get(((LDResult) it.next()).site2()));
            }
            create.putAll(genotypeTable.positions().get(i2), arrayList);
            if ((i2 + 1) % 1000 == 0) {
                longAdder.add(1000L);
                fireProgress(Integer.valueOf(((int) (33 * longAdder.longValue())) / numberOfSites));
                System.out.println(longAdder.longValue());
            }
        });
        return create;
    }

    static byte impute(Multimap<Double, Byte> multimap, int i) {
        double[] dArr = new double[256];
        multimap.entries().forEach(entry -> {
            int byteValue = ((Byte) entry.getValue()).byteValue() + 128;
            dArr[byteValue] = dArr[byteValue] + (1.0d / (1.0d + (i * ((Double) entry.getKey()).doubleValue())));
        });
        int i2 = 0;
        double d = dArr[0];
        for (int i3 = 1; i3 < 256; i3++) {
            if (dArr[i3] > d) {
                d = dArr[i3];
                i2 = i3;
            }
        }
        return (byte) (i2 - 128);
    }

    @Override // net.maizegenetics.plugindef.AbstractPlugin, net.maizegenetics.plugindef.Plugin
    public String getCitation() {
        return "Daniel Money, Kyle Gardner, Heidi Schwaninger, Gan-Yuan Zhong, Sean Myles. (In Review)  LinkImpute: fast and accurate genotype imputation for non-model organisms";
    }

    @Override // net.maizegenetics.plugindef.Plugin
    public ImageIcon getIcon() {
        return null;
    }

    @Override // net.maizegenetics.plugindef.Plugin
    public String getButtonName() {
        return "LD KNNi Imputation";
    }

    @Override // net.maizegenetics.plugindef.Plugin
    public String getToolTipText() {
        return "LD KNNi Imputation";
    }

    public static void main(String[] strArr) {
        GeneratePluginCode.generate(LDKNNiImputationPlugin.class);
    }

    public GenotypeTable runPlugin(DataSet dataSet) {
        return (GenotypeTable) performFunction(dataSet).getData(0).getData();
    }

    public Integer highLDSSites() {
        return this.highLDSSites.value();
    }

    public LDKNNiImputationPlugin highLDSSites(Integer num) {
        this.highLDSSites = new PluginParameter<>(this.highLDSSites, num);
        return this;
    }

    public Integer knnTaxa() {
        return this.knnTaxa.value();
    }

    public LDKNNiImputationPlugin knnTaxa(Integer num) {
        this.knnTaxa = new PluginParameter<>(this.knnTaxa, num);
        return this;
    }

    public Integer maxDistance() {
        return this.maxDistance.value();
    }

    public LDKNNiImputationPlugin maxDistance(Integer num) {
        this.maxDistance = new PluginParameter<>(this.maxDistance, num);
        return this;
    }

    public static double[] dist(byte[] bArr, byte[] bArr2, int i) {
        int i2 = 0;
        int i3 = 0;
        for (int i4 = 0; i4 < bArr.length; i4++) {
            byte unphasedSortedDiploidValue = GenotypeTableUtils.getUnphasedSortedDiploidValue(bArr[i4]);
            byte unphasedSortedDiploidValue2 = GenotypeTableUtils.getUnphasedSortedDiploidValue(bArr2[i4]);
            if (unphasedSortedDiploidValue != -1 && unphasedSortedDiploidValue2 != -1) {
                i3++;
                if (unphasedSortedDiploidValue != unphasedSortedDiploidValue2) {
                    i2 = (GenotypeTableUtils.isHeterozygous(unphasedSortedDiploidValue) || GenotypeTableUtils.isHeterozygous(unphasedSortedDiploidValue2)) ? i2 + 1 : i2 + 2;
                }
            }
        }
        return i3 < i ? new double[]{Double.NaN, i3} : new double[]{i2 / (2 * i3), i3};
    }
}
