/*
 * Decompiled with CFR 0.152.
 */
package marytts.htsengine;

import java.util.Arrays;
import marytts.htsengine.HMMData;
import marytts.htsengine.HTSParameterGeneration;
import marytts.util.MaryUtils;
import org.apache.log4j.Logger;

/**
 * Parameter stream used by HTS-style parameter generation.
 *
 * <p>For each of {@code nT} frames the stream holds a mean vector ({@code mseq}) and a sequence of
 * values produced by {@link HTSParameterGeneration#finv} ({@code ivseq}, presumably inverse
 * variances - confirm against {@code finv}), both of length {@code vSize}. Element
 * {@code i * order + m} belongs to window {@code i} (0 = static, 1 and 2 = the windows described
 * by {@link #xcoefs}) and coefficient {@code m}. {@link #mlpg(HMMData, boolean)} solves, per
 * coefficient, the banded system {@code WUW * c = WUM} for the static trajectory {@code c} (stored
 * in {@code par[t][m]}), optionally followed by an iterative global-variance (GV) optimisation.
 *
 * <p>Instances keep per-call scratch state in fields and are not safe for concurrent use.
 */
public class HTSPStream {
    public static final int WLEFT = 0;
    public static final int WRIGHT = 1;
    /** Number of windows (static + 2 dynamic). */
    public static final int NUM = 3;
    /** Band width of the WUW matrix (diagonal plus two super-diagonals). */
    private static final int WIDTH = 3;

    public final HMMData.FeatureType feaType;
    private final int vSize;
    private final int order;
    private final int nT;
    /** Generated static parameters, {@code [nT][order]}. */
    private final double[][] par;
    /** Mean sequence, {@code [nT][vSize]}. */
    private final double[][] mseq;
    /** Values stored via {@code finv} in {@link #setVseq}, {@code [nT][vSize]}. */
    private final double[][] ivseq;
    /** Scratch vector: forward-substitution result, then WUW*c / GV update direction. */
    private final double[] g;
    /**
     * Symmetric banded WUW matrix: {@code wuw[t][0]} is the diagonal entry of row t and
     * {@code wuw[t][k]} (k = 1, 2) the entry coupling frames t and t + k.
     */
    private final double[][] wuw;
    /** Right-hand side WUM, one entry per frame. */
    private final double[] wum;

    /** Left extent (frame offset) of each window. */
    static final int[] leftWidths = {0, -1, -1};
    /** Right extent (frame offset) of each window. */
    static final int[] rightWidths = {0, 1, 1};
    /**
     * Window coefficients, 3 per window for frame offsets -1, 0, +1; see {@link #dwCoef(int, int)}.
     */
    static final double[] xcoefs = {0.0, 1.0, 0.0, -0.5, 0.0, 0.5, 1.0, -2.0, 1.0};

    /** Mean of the GV-enabled frames of the coefficient last passed to {@link #calcGV}. */
    private double mean;
    /** Variance of the GV-enabled frames of the coefficient last passed to {@link #calcGV}. */
    private double var;
    private final int maxGVIter;
    private static final double GVepsilon = 1.0E-4;
    private static final double minEucNorm = 0.01;
    private static final double stepInit = 0.1;
    private static final double stepDec = 0.5;
    private static final double stepInc = 1.2;
    private static final double w1 = 1.0;
    private static final double w2 = 1.0;
    private static final double lzero = -1.0E10;
    private double norm = 0.0;
    private double GVobj = 0.0;
    private double HMMobj = 0.0;
    private double[] gvmean;
    private double[] gvcovInv;
    /** Per-frame flag: whether the frame takes part in the GV statistics / update. */
    private final boolean[] gvSwitch;
    /** Number of frames whose {@link #gvSwitch} flag is {@code true}. */
    private int gvLength;
    private final Logger logger = MaryUtils.getLogger("PStream");

    /** Returns the left frame extent of window {@code i}. */
    public int getDWLeftBoundary(int i) {
        return leftWidths[i];
    }

    /** Returns the right frame extent of window {@code i}. */
    public int getDWRightBoundary(int i) {
        return rightWidths[i];
    }

    /**
     * Creates a stream with all frames GV-enabled.
     *
     * @param vector_size length of each mean/variance vector (static + dynamic parts); the static
     *     order is {@code vector_size / 3}
     * @param utt_length number of frames
     * @param fea_type feature type tag, stored as-is
     * @param maxIterationsGV maximum number of GV optimisation iterations
     */
    public HTSPStream(int vector_size, int utt_length, HMMData.FeatureType fea_type, int maxIterationsGV) throws Exception {
        this.feaType = fea_type;
        this.vSize = vector_size;
        this.order = vector_size / NUM;
        this.nT = utt_length;
        this.maxGVIter = maxIterationsGV;
        this.par = new double[nT][order];
        this.mseq = new double[nT][vSize];
        this.ivseq = new double[nT][vSize];
        this.g = new double[nT];
        this.wuw = new double[nT][WIDTH];
        this.wum = new double[nT];
        this.gvSwitch = new boolean[nT];
        Arrays.fill(this.gvSwitch, true);
        this.gvLength = nT;
    }

    public int getVsize() {
        return vSize;
    }

    public int getOrder() {
        return order;
    }

    public void setPar(int i, int j, double val) {
        par[i][j] = val;
    }

    public double getPar(int i, int j) {
        return par[i][j];
    }

    /** Returns a copy of the generated static parameters of frame {@code i}. */
    public double[] getParVec(int i) {
        return Arrays.copyOf(par[i], par[i].length);
    }

    /** Returns the number of frames. */
    public int getT() {
        return nT;
    }

    public void setMseq(int i, int j, double val) {
        mseq[i][j] = val;
    }

    /** Installs {@code vec} as the mean vector of frame {@code i} (the array itself is stored, not copied). */
    public void setMseq(int i, double[] vec) {
        mseq[i] = vec;
    }

    /** Stores {@code HTSParameterGeneration.finv(vec[j])} for every element of frame {@code i}. */
    public void setVseq(int i, double[] vec) {
        assert (vec.length == ivseq[i].length);
        for (int j = 0; j < ivseq[i].length; j++) {
            ivseq[i][j] = HTSParameterGeneration.finv(vec[j]);
        }
    }

    public void setIvseq(int i, int j, double val) {
        ivseq[i][j] = val;
    }

    /** Sets the GV means and inverse covariances, indexed by static coefficient (arrays are stored, not copied). */
    public void setGvMeanVar(double[] mean, double[] ivar) {
        this.gvmean = mean;
        this.gvcovInv = ivar;
    }

    /**
     * Enables or disables frame {@code i} for GV optimisation. The GV frame count is only
     * adjusted when the flag actually changes, so repeated calls are harmless.
     */
    public void setGvSwitch(int i, boolean bv) {
        if (gvSwitch[i] == bv) {
            return;
        }
        gvLength += bv ? 1 : -1;
        gvSwitch[i] = bv;
    }

    /**
     * Zeroes the {@code ivseq} entries of all dynamic-window elements (indices {@code order}..
     * {@code vSize-1}) on the first and the last frame. Static elements are left untouched; the
     * previous implementation started at index 1, which also cleared static coefficients
     * 1..order-1 whenever {@code order > 1}.
     */
    public void fixDynFeatOnBoundaries() {
        if (nT == 0) {
            return;
        }
        for (int k = order; k < vSize; k++) {
            setIvseq(0, k, 0.0);
            setIvseq(nT - 1, k, 0.0);
        }
    }

    /** Debug helper: prints the band entries of WUW row {@code t} to stdout. */
    private void printWUW(int t) {
        for (int i = 0; i < WIDTH; i++) {
            System.out.print("WUW[" + t + "][" + i + "]=" + wuw[t][i] + "  ");
        }
        System.out.println("");
    }

    /** Runs {@link #mlpg(HMMData, boolean)} with {@code htsData.getUseGV()}. */
    public void mlpg(HMMData htsData) {
        mlpg(htsData, htsData.getUseGV());
    }

    /**
     * Generates the static trajectory for every coefficient by solving {@code WUW * c = WUM} with
     * a banded LDL-style factorisation; when {@code useGV} is set and at least one frame is
     * GV-enabled, the result is then refined by GV optimisation (gradient or derivative method
     * depending on {@code htsData.getGvMethodGradient()}). Requires {@link #setGvMeanVar} to have
     * been called when GV is applied.
     */
    public void mlpg(HMMData htsData, boolean useGV) {
        if (useGV) {
            if (htsData.getUseContextDependentGV()) {
                logger.info("Context-dependent global variance optimization: gvLength = " + gvLength);
            } else {
                logger.info("Global variance optimization");
            }
        }
        for (int m = 0; m < order; m++) {
            calcWUWandWUM(m);
            // factorise a copy: wuw/wum must stay intact for the GV step
            double[][] factor = new double[nT][];
            for (int x = 0; x < wuw.length; x++) {
                factor[x] = Arrays.copyOf(wuw[x], wuw[x].length);
            }
            double[] rhs = Arrays.copyOf(wum, wum.length);
            ldlFactorization(factor);
            forwardSubstitution(rhs, factor);
            backwardSubstitution(m, factor);
            if (useGV && gvLength > 0) {
                if (htsData.getGvMethodGradient()) {
                    gvParmGenGradient(m, false);
                } else {
                    gvParmGenDerivative(m, false);
                }
            }
        }
    }

    /** Coefficient of window {@code i} at frame offset {@code offset} (-1, 0 or +1). */
    private static double dwCoef(int i, int offset) {
        return xcoefs[1 + i * 3 + offset];
    }

    /** Rebuilds the banded {@link #wuw} matrix and the {@link #wum} vector for coefficient {@code m}. */
    private void calcWUWandWUM(int m) {
        Arrays.fill(wum, 0, nT, 0.0);
        for (int t = 0; t < nT; t++) {
            Arrays.fill(wuw[t], 0.0);
            for (int i = 0; i < NUM; i++) {
                final int right = rightWidths[i];
                final int iorder = i * order + m;
                for (int j = leftWidths[i]; j <= right; j++) {
                    if (t + j < 0 || t + j >= nT) {
                        continue;
                    }
                    final double coefIJ = dwCoef(i, -j);
                    if (coefIJ == 0.0) {
                        continue;
                    }
                    final double wu = coefIJ * ivseq[t + j][iorder];
                    wum[t] += wu * mseq[t + j][iorder];
                    for (int k = 0; k < WIDTH && t + k < nT; k++) {
                        // bounds check first: dwCoef is only defined for offsets up to `right`
                        if (k - j <= right) {
                            final double coefIKJ = dwCoef(i, k - j);
                            if (coefIKJ != 0.0) {
                                wuw[t][k] += wu * coefIKJ;
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * In-place factorisation of a banded symmetric matrix stored like {@link #wuw}: afterwards
     * column 0 holds the pivots and columns 1..2 the normalised off-diagonal factors.
     */
    private static void ldlFactorization(double[][] band) {
        for (int t = 0; t < band.length; t++) {
            final double[] row = band[t];
            for (int i = 1; i < WIDTH && t - i >= 0; i++) {
                row[0] -= band[t - i][i] * band[t - i][i] * band[t - i][0];
            }
            for (int i = 2; i <= WIDTH; i++) {
                for (int j = 1; i + j <= WIDTH && t - j >= 0; j++) {
                    row[i - 1] -= band[t - j][j] * band[t - j][i + j - 1] * band[t - j][0];
                }
                row[i - 1] /= row[0];
            }
        }
    }

    /** Forward substitution with the unit lower factor; the result is written to {@link #g}. */
    private void forwardSubstitution(double[] mywum, double[][] mywuw) {
        System.arraycopy(mywum, 0, g, 0, mywum.length);
        for (int t = 0; t < nT; t++) {
            for (int i = 1; i < WIDTH && t - i >= 0; i++) {
                g[t] -= mywuw[t - i][i] * g[t - i];
            }
        }
    }

    /** Diagonal scaling plus backward substitution; writes column {@code m} of {@link #par}. */
    private void backwardSubstitution(int m, double[][] mywuw) {
        for (int t = nT - 1; t >= 0; t--) {
            par[t][m] = g[t] / mywuw[t][0];
            for (int i = 1; i < WIDTH && t + i < nT; i++) {
                par[t][m] -= mywuw[t][i] * par[t + i][m];
            }
        }
    }

    /**
     * GV optimisation with a self-adjusting step: grows the step when the (negated) objective
     * from {@link #calcDerivative} increases ... decreases, and always runs {@code maxGVIter}
     * iterations. {@code debug} is currently unused.
     */
    private void gvParmGenDerivative(int m, boolean debug) {
        double step = stepInit;
        double prev = -lzero;
        mean = 0.0;
        this.var = 0.0;
        Arrays.fill(g, 0, nT, 0.0);
        convGV(m);
        calcWUWandWUM(m);
        int iter;
        for (iter = 1; iter <= maxGVIter; iter++) {
            final double obj = calcDerivative(m);
            if (obj > prev) {
                step *= stepDec;
            } else if (obj < prev) {
                step *= stepInc;
            }
            for (int t = 0; t < nT; t++) {
                par[t][m] += step * g[t];
            }
            prev = obj;
        }
        logger.info("Derivative GV optimization for feature: (" + m + ")  number of iterations=" + (iter - 1));
    }

    /**
     * GV optimisation by gradient ascent on {@link #calcGradient}: on improvement the step grows;
     * on degradation the last update is undone and retried with a smaller step (at most 100
     * consecutive retries). Stops on a small update norm or objective change; if
     * {@code maxGVIter} is exhausted the original trajectory of coefficient {@code m} is restored.
     */
    private void gvParmGenGradient(int m, boolean debug) {
        double step = stepInit;
        double prev = 0.0;
        final double[] diag = new double[nT];
        final double[] parOri = new double[nT];
        mean = 0.0;
        this.var = 0.0;
        int numDown = 0;
        for (int t = 0; t < nT; t++) {
            g[t] = 0.0;
            parOri[t] = par[t][m];
        }
        convGV(m);
        calcWUWandWUM(m);
        int iter;
        for (iter = 1; iter <= maxGVIter; iter++) {
            final double obj = calcGradient(m);
            if (iter > 1) {
                if (obj > prev) {
                    step *= stepInc;
                    numDown = 0;
                } else if (obj < prev) {
                    // undo the previous update and retry it with a smaller step
                    for (int t = 0; t < nT; t++) {
                        par[t][m] -= step * diag[t];
                    }
                    step *= stepDec;
                    for (int t = 0; t < nT; t++) {
                        par[t][m] += step * diag[t];
                    }
                    iter--; // a retry does not count as an iteration
                    if (++numDown >= 100) {
                        logger.info("  ***Convergence problems....optimization stopped. Number of iterations: " + iter);
                        break;
                    }
                    continue;
                }
            } else if (debug) {
                logger.info("  First iteration:  GVobj=" + obj + " (HMMobj=" + HMMobj + "  GVobj=" + GVobj + ")");
            }
            if (norm < minEucNorm || (iter > 1 && Math.abs(obj - prev) < GVepsilon)) {
                if (debug) {
                    logger.info("  Number of iterations: [   " + iter + "   ] GVobj=" + obj + " (HMMobj=" + HMMobj + "  GVobj=" + GVobj + ")");
                    if (iter > 1) {
                        logger.info("  Converged (norm=" + norm + ", change=" + Math.abs(obj - prev) + ")");
                    } else {
                        logger.info("  Converged (norm=" + norm + ")");
                    }
                }
                break;
            }
            for (int t = 0; t < nT; t++) {
                par[t][m] += step * g[t];
                diag[t] = g[t];
            }
            prev = obj;
        }
        if (iter > maxGVIter) {
            logger.info("   optimization stopped by reaching max number of iterations (no global variance applied)");
            for (int t = 0; t < nT; t++) {
                par[t][m] = parOri[t];
            }
        }
        logger.info("Gradient GV optimization for feature: (" + m + ")  number of iterations=" + iter);
    }

    /** Writes {@code WUW * c} into {@link #g}, where {@code c} is column {@code m} of {@link #par}. */
    private void calcWUWTimesPar(int m) {
        for (int t = 0; t < nT; t++) {
            g[t] = wuw[t][0] * par[t][m];
            for (int i = 2; i <= WIDTH; i++) {
                if (t + i - 1 < nT) {
                    g[t] += wuw[t][i - 1] * par[t + i - 1][m];
                }
                if (t - i + 1 >= 0) {
                    // lower band is the transpose of the stored upper band
                    g[t] += wuw[t - i + 1][i - 1] * par[t - i + 1][m];
                }
            }
        }
    }

    /**
     * Computes HMM and GV objectives for coefficient {@code m}, leaves the Newton-scaled update
     * direction in {@link #g} and its Euclidean norm in {@link #norm}.
     *
     * @return {@code HMMobj + GVobj}
     */
    private double calcGradient(int m) {
        final double w = 1.0 / (double) (WIDTH * nT);
        // computed in double: nT * nT in int overflows for nT > 46340
        final double nTSquared = (double) nT * (double) nT;
        calcGV(m);
        final double varDev = this.var - gvmean[m];
        GVobj = -0.5 * varDev * gvcovInv[m] * varDev;
        final double vd = gvcovInv[m] * varDev;
        calcWUWTimesPar(m);
        HMMobj = 0.0;
        norm = 0.0;
        for (int t = 0; t < nT; t++) {
            HMMobj += -0.5 * w * par[t][m] * (g[t] - 2.0 * wum[t]);
            final double dev = par[t][m] - mean;
            // only the diagonal of the Hessian is used
            double h = (double) (nT - 1) * vd + 2.0 * gvcovInv[m] * dev * dev;
            h = -w1 * w * wuw[t][0] - 2.0 / nTSquared * h;
            h = -1.0 / h;
            if (gvSwitch[t]) {
                final double aux = dev * vd;
                g[t] = h * (w1 * w * (-g[t] + wum[t]) + -2.0 / (double) nT * aux);
            } else {
                g[t] = h * (w1 * w * (-g[t] + wum[t]));
            }
            norm += g[t] * g[t];
        }
        norm = Math.sqrt(norm);
        return HMMobj + GVobj;
    }

    /**
     * Derivative-method counterpart of {@link #calcGradient}: leaves the update direction in
     * {@link #g}.
     *
     * @return {@code -(HMMobj + GVobj)}
     */
    private double calcDerivative(int m) {
        final double w = 1.0 / (double) (WIDTH * nT);
        // computed in double: nT * nT in int overflows for nT > 46340
        final double nTSquared = (double) nT * (double) nT;
        calcGV(m);
        GVobj = -0.5 * this.var * gvcovInv[m] * (this.var - 2.0 * gvmean[m]);
        final double vd = -2.0 * gvcovInv[m] * (this.var - gvmean[m]) / (double) nT;
        calcWUWTimesPar(m);
        HMMobj = 0.0;
        for (int t = 0; t < nT; t++) {
            HMMobj += w1 * w * par[t][m] * (wum[t] - 0.5 * g[t]);
            final double dev = par[t][m] - mean;
            final double h = -w1 * w * wuw[t][0] - 2.0 / nTSquared * ((double) (nT - 1) * gvcovInv[m] * (this.var - gvmean[m]) + 2.0 * gvcovInv[m] * dev * dev);
            if (gvSwitch[t]) {
                g[t] = 1.0 / h * (w1 * w * (-g[t] + wum[t]) + w2 * vd * dev);
            } else {
                g[t] = 1.0 / h * (w1 * w * (-g[t] + wum[t]));
            }
        }
        return -(HMMobj + GVobj);
    }

    /**
     * Rescales the GV-enabled frames of coefficient {@code m} around their mean so that their
     * variance becomes {@code gvmean[m]}. Skipped when the current variance is zero or undefined,
     * since the scale factor would then be infinite/NaN and poison the trajectory.
     */
    private void convGV(int m) {
        calcGV(m);
        if (!(this.var > 0.0)) {
            return;
        }
        final double ratio = Math.sqrt(gvmean[m] / this.var);
        for (int t = 0; t < nT; t++) {
            if (gvSwitch[t]) {
                par[t][m] = ratio * (par[t][m] - mean) + mean;
            }
        }
    }

    /**
     * Computes {@link #mean} and (population) {@link #var} of coefficient {@code m} over the
     * GV-enabled frames. Results are NaN when no frame is enabled; callers guard on gvLength.
     */
    private void calcGV(int m) {
        mean = 0.0;
        this.var = 0.0;
        for (int t = 0; t < nT; t++) {
            if (gvSwitch[t]) {
                mean += par[t][m];
            }
        }
        mean /= (double) gvLength;
        for (int t = 0; t < nT; t++) {
            if (gvSwitch[t]) {
                this.var += (par[t][m] - mean) * (par[t][m] - mean);
            }
        }
        this.var /= (double) gvLength;
    }
}

