package org.tribuo.anomaly.example;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.MutableDataset;
import org.tribuo.OutputFactory;
import org.tribuo.anomaly.AnomalyFactory;
import org.tribuo.anomaly.Event;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.ConfiguredDataSourceProvenance;
import org.tribuo.provenance.DataSourceProvenance;

/* loaded from: input_file:org/tribuo/anomaly/example/GaussianAnomalyDataSource.class */
/**
 * A synthetic, configurable data source for anomaly detection.
 * <p>
 * Each example has one feature per entry of {@code expectedMeans}, named "A", "B", "C", ... in order,
 * so at most 26 features are supported. Feature values are drawn independently from Gaussians with the
 * configured means and variances. With probability {@code fractionAnomalous} an example is labelled
 * {@link AnomalyFactory#ANOMALOUS_EVENT} and sampled from the anomalous distribution. Otherwise it is
 * labelled {@link AnomalyFactory#EXPECTED_EVENT} and sampled from the expected distribution.
 * <p>
 * All examples are sampled eagerly in {@link #postConfig()} from a {@link Random} seeded with {@code seed}.
 * The same configuration therefore always yields the same examples, and every call to {@link #iterator()}
 * walks the same unmodifiable list.
 */
public final class GaussianAnomalyDataSource implements ConfigurableDataSource<Event> {

    private static final AnomalyFactory factory = new AnomalyFactory();

    /** Candidate feature names; the first {@code expectedMeans.length} of them are used. */
    private static final String[] allFeatureNames = {"A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
            "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"};

    /** Smallest variance accepted by validation; smaller (or NaN) variances are rejected. */
    private static final double MIN_VARIANCE = 1.0e-10;

    @Config(mandatory = true, description = "The number of samples to draw.")
    private int numSamples;

    @Config(description = "Means of the expected events.")
    private double[] expectedMeans = {1.0, 2.0, 1.0, 2.0, 5.0};

    @Config(description = "Variances of the expected events.")
    private double[] expectedVariances = {1.0, 0.5, 0.25, 1.0, 0.1};

    @Config(description = "Means of the anomalous events.")
    private double[] anomalousMeans = {-2.0, 2.0, -2.0, 2.0, -10.0};

    @Config(description = "Variances of the anomalous events.")
    private double[] anomalousVariances = {1.0, 0.5, 0.25, 1.0, 0.1};

    @Config(description = "The RNG seed.")
    private long seed = 12345L;

    @Config(mandatory = true, description = "The fraction of anomalous events.")
    private float fractionAnomalous = 0.3f;

    /** The sampled examples, populated by {@link #postConfig()} and never modified afterwards. */
    private List<Example<Event>> examples;

    /**
     * Provenance for {@link GaussianAnomalyDataSource}.
     */
    public static final class GaussianAnomalyDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance {
        private static final long serialVersionUID = 1L;

        /**
         * Captures the configuration of the supplied data source.
         *
         * @param host The data source to describe.
         */
        GaussianAnomalyDataSourceProvenance(GaussianAnomalyDataSource host) {
            super(host, "DataSource");
        }

        /**
         * Rebuilds a provenance from a map of named provenance entries. This is presumably invoked
         * reflectively when a marshalled provenance is read back; confirm against the OLCUT provenance system.
         *
         * @param map The provenance entries. Must contain "class-name" and "host-short-name" string entries.
         */
        public GaussianAnomalyDataSourceProvenance(Map<String, Provenance> map) {
            this(extractProvenanceInfo(map));
        }

        private GaussianAnomalyDataSourceProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo extractedInfo) {
            super(extractedInfo);
        }

        /**
         * Splits a provenance map into the class name, the host short name, and the remaining
         * configured parameters.
         *
         * @param map The provenance entries. The caller's map is copied, not modified.
         * @return The extracted information.
         */
        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
            // checkAndExtractProvenance removes the entries it reads, so work on a copy. What remains
            // afterwards is the set of configured parameters.
            Map<String, Provenance> configuredParameters = new HashMap<>(map);
            String provClassName = GaussianAnomalyDataSourceProvenance.class.getSimpleName();
            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, "class-name",
                    StringProvenance.class, provClassName).getValue();
            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, "host-short-name",
                    StringProvenance.class, provClassName).getValue();
            return new SkeletalConfiguredObjectProvenance.ExtractedInfo(className, hostTypeStringName,
                    configuredParameters, Collections.emptyMap());
        }
    }

    /**
     * No-argument constructor. Presumably used by the OLCUT configuration system, which sets the
     * {@code @Config} fields and then calls {@link #postConfig()}.
     */
    private GaussianAnomalyDataSource() {}

    /**
     * Samples an anomaly dataset using the default means and variances.
     *
     * @param numSamples        The number of samples to draw; must be positive.
     * @param fractionAnomalous The probability that each sample is anomalous; must be in [0, 1].
     * @param seed              The RNG seed.
     * @throws PropertyException If any parameter fails validation.
     */
    public GaussianAnomalyDataSource(int numSamples, float fractionAnomalous, long seed) {
        this.numSamples = numSamples;
        this.fractionAnomalous = fractionAnomalous;
        this.seed = seed;
        postConfig();
    }

    /**
     * Samples an anomaly dataset from the supplied distributions.
     * <p>
     * The arrays are copied, so later changes by the caller affect neither the examples nor the provenance.
     *
     * @param numSamples         The number of samples to draw; must be positive.
     * @param expectedMeans      Means of the expected distribution; 1 to 26 entries.
     * @param expectedVariances  Variances of the expected distribution; same length as {@code expectedMeans}.
     * @param anomalousMeans     Means of the anomalous distribution. Must match {@code expectedMeans} in
     *                           length unless {@code fractionAnomalous} is zero.
     * @param anomalousVariances Variances of the anomalous distribution; same length as {@code anomalousMeans}.
     * @param fractionAnomalous  The probability that each sample is anomalous; must be in [0, 1].
     * @param seed               The RNG seed.
     * @throws PropertyException If any parameter fails validation, including null arrays.
     */
    public GaussianAnomalyDataSource(int numSamples, double[] expectedMeans, double[] expectedVariances,
                                     double[] anomalousMeans, double[] anomalousVariances,
                                     float fractionAnomalous, long seed) {
        this.numSamples = numSamples;
        this.expectedMeans = copyOrNull(expectedMeans);
        this.expectedVariances = copyOrNull(expectedVariances);
        this.anomalousMeans = copyOrNull(anomalousMeans);
        this.anomalousVariances = copyOrNull(anomalousVariances);
        this.fractionAnomalous = fractionAnomalous;
        this.seed = seed;
        postConfig();
    }

    /**
     * Validates the configuration and eagerly samples all examples.
     *
     * @throws PropertyException If any configured value is invalid.
     */
    @Override
    public void postConfig() {
        validateConfiguration();
        this.examples = sampleExamples();
    }

    /** Checks every configured field, throwing a {@link PropertyException} for the first problem found. */
    private void validateConfiguration() {
        if (numSamples < 1) {
            throw new PropertyException("", "numSamples", "numSamples must be positive, found " + numSamples);
        }
        requireNonNullArray(expectedMeans, "expectedMeans");
        requireNonNullArray(expectedVariances, "expectedVariances");
        requireNonNullArray(anomalousMeans, "anomalousMeans");
        requireNonNullArray(anomalousVariances, "anomalousVariances");
        if (expectedMeans.length > allFeatureNames.length || expectedMeans.length == 0) {
            throw new PropertyException("", "expectedMeans",
                    "Must have 1-" + allFeatureNames.length + " features, found " + expectedMeans.length);
        }
        if (expectedMeans.length != expectedVariances.length) {
            throw new PropertyException("", "expectedMeans",
                    "Must supply the same number of expected means and variances. expectedMeans.length = "
                            + expectedMeans.length + " expectedVariances.length = " + expectedVariances.length);
        }
        if (anomalousMeans.length != anomalousVariances.length) {
            throw new PropertyException("", "anomalousMeans",
                    "Must supply the same number of anomalous means and variances. anomalousMeans.length = "
                            + anomalousMeans.length + " anomalousVariances.length = " + anomalousVariances.length);
        }
        // Written as a negated range test so that NaN (which fails every comparison) is rejected too.
        if (!(fractionAnomalous >= 0.0f && fractionAnomalous <= 1.0f)) {
            throw new PropertyException("", "fractionAnomalous",
                    "fractionAnomalous must be between 0.0 and 1.0, found " + fractionAnomalous);
        }
        if (fractionAnomalous != 0.0f && anomalousMeans.length != expectedMeans.length) {
            throw new PropertyException("", "anomalousMeans",
                    "When sampling anomalous data there must be the same number of anomalous features as expected features. anomalousMeans.length = "
                            + anomalousMeans.length + ", expectedMeans.length = " + expectedMeans.length);
        }
        // Each array is checked over its own length. When fractionAnomalous is zero the anomalous arrays
        // may differ in length from the expected ones.
        checkVariances(anomalousVariances, "anomalousVariances");
        checkVariances(expectedVariances, "expectedVariances");
    }

    /** Throws a {@link PropertyException} naming {@code fieldName} if {@code array} is null. */
    private static void requireNonNullArray(double[] array, String fieldName) {
        if (array == null) {
            throw new PropertyException("", fieldName, fieldName + " must not be null");
        }
    }

    /** Throws a {@link PropertyException} if any entry is below {@link #MIN_VARIANCE} or is NaN. */
    private static void checkVariances(double[] variances, String fieldName) {
        for (double variance : variances) {
            if (!(variance >= MIN_VARIANCE)) {
                throw new PropertyException("", fieldName, "Variances must be positive, found " + Arrays.toString(variances));
            }
        }
    }

    /** Returns a copy of {@code array}, or {@code null} when it is null so validation can report it. */
    private static double[] copyOrNull(double[] array) {
        return array == null ? null : array.clone();
    }

    /**
     * Draws {@code numSamples} labelled examples from a freshly seeded RNG.
     * <p>
     * The label draw precedes the feature draws for each example. This fixes the RNG consumption order,
     * and with it the exact samples produced for a given seed.
     */
    private List<Example<Event>> sampleExamples() {
        String[] featureNames = Arrays.copyOf(allFeatureNames, expectedMeans.length);
        Random rng = new Random(seed);
        List<Example<Event>> sampled = new ArrayList<>(numSamples);
        for (int i = 0; i < numSamples; i++) {
            if (rng.nextDouble() < fractionAnomalous) {
                sampled.add(new ArrayExample<>(AnomalyFactory.ANOMALOUS_EVENT,
                        generateFeatures(rng, featureNames, anomalousMeans, anomalousVariances)));
            } else {
                sampled.add(new ArrayExample<>(AnomalyFactory.EXPECTED_EVENT,
                        generateFeatures(rng, featureNames, expectedMeans, expectedVariances)));
            }
        }
        return Collections.unmodifiableList(sampled);
    }

    @Override
    public OutputFactory<Event> getOutputFactory() {
        return factory;
    }

    @Override
    public DataSourceProvenance getProvenance() {
        return new GaussianAnomalyDataSourceProvenance(this);
    }

    /**
     * Returns an iterator over the pre-sampled examples. The list is unmodifiable, so removal is unsupported.
     *
     * @return An iterator over the examples.
     */
    @Override
    public Iterator<Example<Event>> iterator() {
        return examples.iterator();
    }

    /**
     * Samples one feature per name, each independently from a Gaussian with the matching mean and variance.
     *
     * @param rng       The random source.
     * @param names     The feature names.
     * @param means     The per-feature means.
     * @param variances The per-feature variances.
     * @return The sampled features, in name order.
     * @throws IllegalArgumentException If the three arrays differ in length.
     */
    private static List<Feature> generateFeatures(Random rng, String[] names, double[] means, double[] variances) {
        if (names.length != means.length || names.length != variances.length) {
            throw new IllegalArgumentException("Names, means and variances must be the same length");
        }
        List<Feature> features = new ArrayList<>(names.length);
        for (int i = 0; i < names.length; i++) {
            // Scale a standard normal draw by the standard deviation, then shift it by the mean.
            features.add(new Feature(names[i], rng.nextGaussian() * Math.sqrt(variances[i]) + means[i]));
        }
        return features;
    }

    /**
     * Convenience wrapper that samples a data source and loads it into a {@link MutableDataset}.
     *
     * @param numSamples         The number of samples to draw.
     * @param expectedMeans      Means of the expected distribution.
     * @param expectedVariances  Variances of the expected distribution.
     * @param anomalousMeans     Means of the anomalous distribution.
     * @param anomalousVariances Variances of the anomalous distribution.
     * @param fractionAnomalous  The probability that each sample is anomalous.
     * @param seed               The RNG seed.
     * @return A dataset containing the sampled examples.
     * @throws PropertyException If any parameter fails validation.
     */
    public static MutableDataset<Event> generateDataset(int numSamples, double[] expectedMeans, double[] expectedVariances,
                                                        double[] anomalousMeans, double[] anomalousVariances,
                                                        float fractionAnomalous, long seed) {
        return new MutableDataset<>(new GaussianAnomalyDataSource(numSamples, expectedMeans, expectedVariances,
                anomalousMeans, anomalousVariances, fractionAnomalous, seed));
    }
}
