package org.mitre.jcarafe.crf;

import org.mitre.jcarafe.crf.SemiCrf;
import scala.Array$;
import scala.Predef$;
import scala.collection.Seq;
import scala.collection.immutable.IndexedSeq;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;
import scala.runtime.TraitSetter;

/* compiled from: SemiCrf.scala */
@ScalaSignature(bytes = "\u0006\u0001\r3Q!\u0001\u0002\u0002\u0002-\u0011A\u0002R3og\u0016\u001cV-\\5De\u001aT!a\u0001\u0003\u0002\u0007\r\u0014hM\u0003\u0002\u0006\r\u00059!nY1sC\u001a,'BA\u0004\t\u0003\u0015i\u0017\u000e\u001e:f\u0015\u0005I\u0011aA8sO\u000e\u00011c\u0001\u0001\r!A\u0011QBD\u0007\u0002\u0005%\u0011qB\u0001\u0002\t\t\u0016t7/Z\"sMB\u0011Q\"E\u0005\u0003%\t\u0011qaU3nS\u000e\u0013h\r\u0003\u0005\u0015\u0001\t\u0005\t\u0015!\u0003\u0016\u0003\rqGn\u001d\t\u0003-ei\u0011a\u0006\u0006\u00021\u0005)1oY1mC&\u0011!d\u0006\u0002\u0004\u0013:$\b\u0002\u0003\u000f\u0001\u0005\u0003\u0005\u000b\u0011B\u000b\u0002\u0007947\u000f\u0003\u0005\u001f\u0001\t\u0005\t\u0015!\u0003\u0016\u0003\u001d\u0019XmZ*ju\u0016D\u0001\u0002\t\u0001\u0003\u0002\u0003\u0006I!I\u0001\u0007OB\u0013\u0018n\u001c:\u0011\u0005Y\u0011\u0013BA\u0012\u0018\u0005\u0019!u.\u001e2mK\")Q\u0005\u0001C\u0001M\u00051A(\u001b8jiz\"Ra\n\u0015*U-\u0002\"!\u0004\u0001\t\u000bQ!\u0003\u0019A\u000b\t\u000bq!\u0003\u0019A\u000b\t\u000by!\u0003\u0019A\u000b\t\u000b\u0001\"\u0003\u0019A\u0011\t\u000b5\u0002A\u0011\t\u0018\u0002\u0015I,w-\u001e7be&TX\rF\u0001\"\u0011\u0015\u0001\u0004\u0001\"\u00152\u0003-1wN]<be\u0012\u0004\u0016m]:\u0015\u0005\u0005\u0012\u0004\"B\u001a0\u0001\u0004!\u0014\u0001B5tKF\u00042!\u000e\u001e=\u001b\u00051$BA\u001c9\u0003%IW.\\;uC\ndWM\u0003\u0002:/\u0005Q1m\u001c7mK\u000e$\u0018n\u001c8\n\u0005m2$AC%oI\u0016DX\rZ*fcB\u0011Q\"P\u0005\u0003}\t\u0011\u0001#\u00112tiJ\f7\r^%ogR\fgnY3\t\u000b\u0001\u0003A\u0011I!\u0002\u0013\u001d\u0014\u0018\rZ(g'\u0016\fHCA\u0011C\u0011\u0015\u0019t\b1\u00015\u0001")
/* loaded from: input_file:org/mitre/jcarafe/crf/DenseSemiCrf.class */
/*
 * Semi-Markov CRF backed by dense parameter arrays (decompiled from SemiCrf.scala).
 *
 * Extends DenseCrf for its dense lambdas / gradient / featureExpectations arrays and mixes in
 * the SemiCrf trait; the trait's concrete methods are forwarded to the Scala-generated
 * SemiCrf.Cclass implementation class.
 */
public abstract class DenseSemiCrf extends DenseCrf implements SemiCrf {
    /** Feature count (constructor argument {@code i2}); bounds the per-feature loops in {@link #gradOfSeq}. */
    private final int nfs;
    /** Maximum segment length (constructor argument {@code i3}) considered by {@link #forwardPass}. */
    private final int segSize;
    /** Backing field for the {@link SemiCrf} trait's {@code t2} accessor; allocated with {@code nls()} entries in the constructor. */
    private double[] t2;

    /** Trait accessor for {@code t2}. */
    @Override // org.mitre.jcarafe.crf.SemiCrf
    public double[] t2() {
        return this.t2;
    }

    /** Trait setter for {@code t2}. */
    @Override // org.mitre.jcarafe.crf.SemiCrf
    @TraitSetter
    public void t2_$eq(double[] dArr) {
        this.t2 = dArr;
    }

    /** Forwards to the SemiCrf trait implementation. */
    @Override // org.mitre.jcarafe.crf.SemiCrf
    public double logSumExp(double d, double d2) {
        return SemiCrf.Cclass.logSumExp(this, d, d2);
    }

    /** Forwards to the SemiCrf trait implementation. */
    @Override // org.mitre.jcarafe.crf.SemiCrf
    public void setArrayTo(double[] dArr, double d) {
        SemiCrf.Cclass.setArrayTo(this, dArr, d);
    }

    /** Forwards to the SemiCrf trait implementation. */
    @Override // org.mitre.jcarafe.crf.SemiCrf
    public void matrixMultLog(double[][] dArr, double[] dArr2, double[] dArr3, double d, double d2, boolean z) {
        SemiCrf.Cclass.matrixMultLog(this, dArr, dArr2, dArr3, d, d2, z);
    }

    /** Forwards to the SemiCrf trait implementation. */
    @Override // org.mitre.jcarafe.crf.SemiCrf
    public void computeScoresBackwards(Seq<Seq<Feature>> seq, boolean z) {
        SemiCrf.Cclass.computeScoresBackwards(this, seq, z);
    }

    /** Forwards to the SemiCrf trait implementation (fills {@code beta()} used by {@link #forwardPass}). */
    @Override // org.mitre.jcarafe.crf.Crf, org.mitre.jcarafe.crf.GeneralizedEMCrf
    public void backwardPass(Seq<AbstractInstance> seq) {
        SemiCrf.Cclass.backwardPass(this, seq);
    }

    /**
     * Resets the gradient to the derivative of the L2 penalty and clears feature expectations.
     *
     * <p>For every parameter {@code i}: {@code gradient[i] = lambdas[i] * invSigSqr} and
     * {@code featureExpectations[i] = 0}.
     *
     * @return {@code sum_i(lambdas[i]^2 * invSigSqr / 2)}; presumably a Gaussian-prior penalty
     *     with {@code invSigSqr = 1/sigma^2} (confirm in DenseCrf)
     */
    @Override // org.mitre.jcarafe.crf.DenseCrf
    public double regularize() {
        double penalty = 0.0d;
        for (int i = 0; i < lambdas().length; i++) {
            double lambda = lambdas()[i];
            gradient()[i] = lambda * invSigSqr();
            featureExpectations()[i] = 0.0d;
            penalty += ((lambda * lambda) * invSigSqr()) / 2.0d;
        }
        return penalty;
    }

    /**
     * Runs the log-space forward recursion over segments and accumulates training statistics.
     *
     * <p>For each position {@code pos} in {@code 1..n} ({@code n = indexedSeq.length()}), this fills
     * {@code alpha()[pos]} by considering every segment of length {@code k + 1} ending at
     * {@code pos}, for {@code k < min(pos, segSize)}. Each segment combines
     * {@code alpha()[pos - k - 1]} with {@code mi()[k]} and {@code ri()[k]}.
     *
     * <p>For every feature in {@code getCompVec()[k]} the method does two things:
     * <ul>
     *   <li>If the feature matches the observed labels, it subtracts the feature value from
     *       {@code gradient()} and adds the feature weight to the returned score.</li>
     *   <li>It log-sum-exp accumulates the feature's unnormalized mass (alpha + scores + beta)
     *       into {@code featureExpectations()}.</li>
     * </ul>
     *
     * <p>It reads {@code beta()}, which {@link #gradOfSeq} populates via {@link #backwardPass}
     * before calling this method.
     *
     * @param indexedSeq the instances of one sequence, in order
     * @return the summed weights of the features that match the observed labeling
     */
    @Override // org.mitre.jcarafe.crf.DenseCrf, org.mitre.jcarafe.crf.Crf
    public double forwardPass(IndexedSeq<AbstractInstance> indexedSeq) {
        double score = 0.0d;
        setArrayTo(alpha()[0], 0.0d);
        for (int pos = 1; pos <= indexedSeq.length(); pos++) {
            AbstractInstance inst = (AbstractInstance) indexedSeq.apply(pos - 1);
            int label = inst.label();
            Feature[][] compVec = inst.mo396getCompVec();
            int segId = inst.segId();
            // True when pos is the last position or the next instance carries a different segId.
            boolean segmentEnds = pos == indexedSeq.length()
                    || segId != ((AbstractInstance) indexedSeq.apply(pos)).segId();
            computeScores(compVec, false);
            setArrayTo(alpha()[pos], 0.0d);
            double[] betaPrev = beta()[pos - 1];
            int maxBack = Math.min(pos, this.segSize);
            for (int k = 0; k < maxBack; k++) {
                // 0-based index of the first instance of a length-(k+1) segment ending at pos;
                // alpha()[start] is the forward column just before that segment.
                int start = (pos - k) - 1;
                double[] stateScores = ri()[k];
                double[][] transScores = mi()[k];
                Feature[] features = compVec[k];
                int startSegId = ((AbstractInstance) indexedSeq.apply(start)).segId();
                int prevSegId = start > 0 ? ((AbstractInstance) indexedSeq.apply(start - 1)).segId() : -1;
                setArrayTo(tmp(), -Double.MAX_VALUE);
                matrixMultLog(transScores, alpha()[start], tmp(), 1.0d, 0.0d, true);
                assign1(tmp(), stateScores, new DenseSemiCrf$$anonfun$forwardPass$1(this));
                if (k > 0) {
                    assign1(alpha()[pos], tmp(), new DenseSemiCrf$$anonfun$forwardPass$2(this));
                } else {
                    assign1(alpha()[pos], tmp(), new DenseSemiCrf$$anonfun$forwardPass$3(this));
                }
                // The decompiled original wrapped this in `while (true) { if (idx < len) { ... } }`
                // with no exit branch, which never terminated; a for-each ends after the last feature.
                for (Feature feature : features) {
                    int fid = feature.fid();
                    int cur = feature.cur();
                    int prv = feature.prv();
                    // Observed-labeling match: current label agrees and, if segSize >= 2, the
                    // candidate segment ends at a segId change, lies in segId, and follows segId - 1
                    // (presumably segIds are consecutive per segment; confirm). Transition features
                    // (prv >= 0) also need the preceding instance's label to equal prv.
                    if (cur == label
                            && (this.segSize < 2 || (segmentEnds && startSegId == segId && prevSegId == segId - 1))
                            && (prv < 0 || (start > 0 && ((SeqElement) indexedSeq.apply(start - 1)).label() == prv))) {
                        gradient()[fid] -= feature.value();
                        score += lambdas()[fid];
                    }
                    double mass = prv < 0
                            ? tmp()[cur] + betaPrev[cur]
                            : alpha()[start][prv] + stateScores[cur] + transScores[prv][cur] + betaPrev[cur];
                    featureExpectations()[fid] = logSumExp(featureExpectations()[fid], mass);
                }
            }
        }
        return score;
    }

    /**
     * Computes one sequence's contribution to the objective and adds its expected feature counts
     * to the gradient.
     *
     * <p>The method proceeds in four steps:
     * <ol>
     *   <li>Reset the per-sequence state and set {@code featureExpectations[0..nfs)} to
     *       {@code -Double.MAX_VALUE}, which serves as log(0).</li>
     *   <li>Run {@link #backwardPass} then {@link #forwardPass}.</li>
     *   <li>Fold the final alpha column into a normalizer.</li>
     *   <li>Add {@code exp(featureExpectations[f] - normalizer)} to {@code gradient[f]} for each
     *       feature.</li>
     * </ol>
     *
     * @param indexedSeq the instances of one sequence, in order
     * @return the observed-labeling score from {@link #forwardPass} minus the normalizer
     */
    @Override // org.mitre.jcarafe.crf.DenseCrf
    public double gradOfSeq(IndexedSeq<AbstractInstance> indexedSeq) {
        reset(true, indexedSeq.length());
        for (int f = 0; f < this.nfs; f++) {
            featureExpectations()[f] = -Double.MAX_VALUE;
        }
        backwardPass(indexedSeq);
        double seqScore = forwardPass(indexedSeq);
        // Presumably log Z: a fold over the last alpha column starting from -MAX_VALUE;
        // anonfun$2 is likely logSumExp. TODO confirm against SemiCrf.scala.
        double logZ = BoxesRunTime.unboxToDouble(Predef$.MODULE$.doubleArrayOps(alpha()[indexedSeq.length()]).foldLeft(BoxesRunTime.boxToDouble(-Double.MAX_VALUE), new DenseSemiCrf$$anonfun$2(this)));
        double result = seqScore - logZ;
        for (int f = 0; f < this.nfs; f++) {
            gradient()[f] += package$.MODULE$.exp(featureExpectations()[f] - logZ);
        }
        return result;
    }

    /**
     * Creates the model and allocates the trait's {@code t2} array with {@code nls()} entries.
     *
     * @param i  forwarded as DenseCrf's first constructor argument (presumably the label count; confirm)
     * @param i2 feature count, stored as {@code nfs}
     * @param i3 maximum segment length, stored as {@code segSize}
     * @param d  forwarded as DenseCrf's fourth constructor argument (presumably the prior; confirm)
     */
    public DenseSemiCrf(int i, int i2, int i3, double d) {
        super(i, i2, i3, d, DenseCrf$.MODULE$.$lessinit$greater$default$5(), DenseCrf$.MODULE$.$lessinit$greater$default$6());
        // NOTE(review): JADX reported moving this super(...) call ahead of the nfs/segSize
        // assignments, which can change semantics. If DenseCrf's constructor invokes an overridden
        // method that reads these fields, it observes 0. Confirm against SemiCrf.scala.
        this.nfs = i2;
        this.segSize = i3;
        t2_$eq((double[]) Array$.MODULE$.fill(nls(), new SemiCrf$$anonfun$1(this), ClassTag$.MODULE$.Double()));
    }
}
