package org.nd4j.examples;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.BlasWrapper;
import org.nd4j.linalg.factory.NDArrayFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jblas.JblasNDArrayFactory;
import org.nd4j.linalg.jcublas.JCublasNDArrayFactory;
import org.nd4j.linalg.jcublas.JCublasWrapper;

/**
 * Example that performs the same sequence of array operations with two
 * {@link NDArrayFactory} implementations (Jblas and JCublas) and asserts
 * that both produce equal results.
 *
 * <p>Before each backend-specific operation the global {@link Nd4j} factory
 * and BLAS wrapper are switched via {@link #setJblas()} or
 * {@link #setJcublas()}.
 *
 * <p>The checks use the {@code assert} statement, so they only run when the
 * JVM is started with assertions enabled ({@code -ea}).
 */
public class MultiINDArrayInterop {
    /** Factory used for the Jblas-backed arrays. */
    static NDArrayFactory jblas = new JblasNDArrayFactory("double", 'f');
    /** Factory used for the JCublas-backed arrays. */
    static NDArrayFactory jcublas = new JCublasNDArrayFactory("double", 'f');
    /** BLAS wrapper installed together with the Jblas factory. */
    static BlasWrapper wrapper = new org.nd4j.linalg.jblas.BlasWrapper();
    /** BLAS wrapper installed together with the JCublas factory. */
    static BlasWrapper jcublasWrapper = new JCublasWrapper();

    /**
     * Runs the interop comparison: transpose, matrix multiply, reshape,
     * explicit creation from a float buffer, and a BLAS dot product, each
     * executed once per backend and compared.
     *
     * @param args ignored
     * @throws AssertionError if assertions are enabled and the two backends disagree
     */
    public static void main(String[] args) {
        INDArray jblasVec = jblas.linspace(1, 8, 8);
        INDArray jcublasVec = jcublas.linspace(1, 8, 8);

        // --- vector times its own transpose ---
        setJblas();
        INDArray jblasVecT = jblasVec.transpose();
        setJcublas();
        INDArray jcublasVecT = jcublasVec.transpose();

        setJblas();
        INDArray jblasProduct = jblasVec.mmul(jblasVecT);
        setJcublas();
        // Must mirror the Jblas expression above (vector * transpose); multiplying
        // the transpose by itself would not be the counterpart operation.
        INDArray jcublasProduct = jcublasVec.mmul(jcublasVecT);
        assert jblasProduct.equals(jcublasProduct) : "vector mmul differs between backends";

        // --- reshape to 2x4, transpose, multiply ---
        setJblas();
        INDArray jblasMatrix = jblasVec.reshape(2, 4);
        INDArray jblasMatrixT = jblasMatrix.transpose();
        INDArray jblasGram = jblasMatrixT.mmul(jblasMatrix);

        setJcublas();
        INDArray jcublasMatrix = jcublasVec.reshape(2, 4);
        // Transpose the JCublas matrix itself; deriving this from the Jblas
        // matrix would make the comparisons below compare Jblas with Jblas.
        INDArray jcublasMatrixT = jcublasMatrix.transpose();
        assert jblasMatrixT.equals(jcublasMatrixT) : "reshaped transpose differs between backends";
        assert jcublasMatrix.equals(jblasMatrix) : "reshape differs between backends";

        INDArray jcublasGram = jcublasMatrixT.mmul(jcublasMatrix);
        assert jblasGram.equals(jcublasGram) : "matrix mmul differs between backends";

        // --- explicit 2x2 creation and transpose ---
        setJblas();
        INDArray jblasSquare = jblas.create(new float[]{1.0f, 2.0f, 3.0f, 4.0f}, new int[]{2, 2});
        setJcublas();
        INDArray jcublasSquare = jcublas.create(new float[]{1.0f, 2.0f, 3.0f, 4.0f}, new int[]{2, 2});

        setJblas();
        INDArray jblasSquareT = jblasSquare.transpose();
        setJcublas();
        INDArray jcublasSquareT = jcublasSquare.transpose();
        assert jblasSquare.equals(jcublasSquare) : "created arrays differ between backends";
        assert jblasSquareT.equals(jcublasSquareT) : "2x2 transpose differs between backends";

        // --- BLAS dot product through the globally installed wrapper ---
        setJblas();
        double jblasDot = Nd4j.getBlasWrapper().dot(jblasVec, jblasVec);
        setJcublas();
        double jcublasDot = Nd4j.getBlasWrapper().dot(jcublasVec, jcublasVec);
        // NOTE(review): exact floating-point equality is kept from the original
        // check; confirm both backends are expected to agree bit-for-bit.
        assert jblasDot == jcublasDot : "dot product differs between backends";
    }

    /** Installs the JCublas factory and its BLAS wrapper as the global {@link Nd4j} backend. */
    public static void setJcublas() {
        Nd4j.setFactory(jcublas);
        Nd4j.setBlasWrapper(jcublasWrapper);
    }

    /** Installs the Jblas factory and its BLAS wrapper as the global {@link Nd4j} backend. */
    public static void setJblas() {
        Nd4j.setFactory(jblas);
        Nd4j.setBlasWrapper(wrapper);
    }
}
