/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.crf;

import com.aliasi.symbol.SymbolTable;
import com.aliasi.symbol.SymbolTableCompiler;
import com.aliasi.tag.TagLattice;
import com.aliasi.util.Strings;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A tag lattice backed by precomputed forward scores, backward scores,
 * transition scores and a normalizer, all on the log scale.
 *
 * <p>The token and tag lists handed to the public constructor are copied.
 * The score arrays are stored by reference and are NOT copied, so later
 * modifications of those arrays by the caller are visible through this
 * lattice.
 *
 * @param <E> the type of the tokens in the lattice
 */
public class ForwardBackwardTagLattice<E>
extends TagLattice<E> {
    private final List<E> mTokens;
    private final List<String> mTags;
    private final double[][] mLogForwards;
    private final double[][] mLogBackwards;
    private final double[][][] mLogTransitions;
    private final double mLogZ;

    /**
     * Constructs a lattice from the given tokens, tags and log-scale scores,
     * checking that the array dimensions agree with the list sizes.
     *
     * @param tokens the input tokens; copied into a new list
     * @param tags the tags; copied into a new list
     * @param logForwards forward scores indexed {@code [token][tag]}
     * @param logBackwards backward scores indexed {@code [token][tag]}
     * @param logTransitions transition scores indexed
     *     {@code [tokenFrom][tagFrom][tagTo]}; only inspected when there is
     *     at least one token
     * @param logZ the log normalizer subtracted by the probability methods
     * @throws IllegalArgumentException if any array dimension does not match
     *     the number of tokens or tags
     */
    public ForwardBackwardTagLattice(List<E> tokens, List<String> tags, double[][] logForwards, double[][] logBackwards, double[][][] logTransitions, double logZ) {
        this(new ArrayList<E>(tokens), new ArrayList<String>(tags), logForwards, logBackwards, logTransitions, logZ, true);
        // The delegating this(...) call has to be the first statement, so the
        // dimension checks can only run after the fields have been assigned.
        validateDimensions(this.mTokens.size(), this.mTags.size(), logForwards, logBackwards, logTransitions);
    }

    /**
     * Stores the arguments as given, without copying and without any
     * validation. The {@code ignore} flag is never read; it only gives this
     * constructor a signature distinct from the public one.
     */
    ForwardBackwardTagLattice(List<E> tokens, List<String> tags, double[][] logForwards, double[][] logBackwards, double[][][] logTransitions, double logZ, boolean ignore) {
        this.mTokens = tokens;
        this.mTags = tags;
        this.mLogForwards = logForwards;
        this.mLogBackwards = logBackwards;
        this.mLogTransitions = logTransitions;
        this.mLogZ = logZ;
    }

    /**
     * Checks every array dimension against the token count and tag count.
     * Checks run outermost dimension first so that the first reported
     * problem is the coarsest one.
     */
    private static void validateDimensions(int numTokens, int numTags, double[][] logForwards, double[][] logBackwards, double[][][] logTransitions) {
        if (logForwards.length != numTokens) {
            throw new IllegalArgumentException("Log forwards must be length of input. tokens.size()=" + numTokens + " logForwards.length=" + logForwards.length);
        }
        if (logBackwards.length != numTokens) {
            throw new IllegalArgumentException("Log backwards must be length of input. tokens.size()=" + numTokens + " logBackwards.length=" + logBackwards.length);
        }
        // With no tokens the transitions array is not inspected at all.
        if (numTokens > 0 && logTransitions.length != numTokens - 1) {
            throw new IllegalArgumentException("Log transitions length must be one shorter than input, or empty. Found tokens.size()=" + numTokens + " logTransitions.length=" + logTransitions.length);
        }
        for (int n = 0; n < numTokens; ++n) {
            if (logForwards[n].length != numTags) {
                // Report the row length; concatenating the array itself would
                // only print its identity string.
                throw new IllegalArgumentException("Each log forward must be length of tags. Found tags.size()=" + numTags + " logForwards[" + n + "].length=" + logForwards[n].length);
            }
            if (logBackwards[n].length != numTags) {
                throw new IllegalArgumentException("Each log backward must be length of tags. Found tags.size()=" + numTags + " logBackwards[" + n + "].length=" + logBackwards[n].length);
            }
        }
        for (int from = 0; from + 1 < numTokens; ++from) {
            double[][] step = logTransitions[from];
            if (step.length != numTags) {
                throw new IllegalArgumentException("Each transition source must be length of tags. Found tags.size()=" + numTags + " logTransitions[" + from + "].length=" + step.length);
            }
            for (int k = 0; k < numTags; ++k) {
                if (step[k].length != numTags) {
                    throw new IllegalArgumentException("Each transition target must be length of tags. Found tags.size()=" + numTags + " logTransitions[" + from + "][" + k + "].length=" + step[k].length);
                }
            }
        }
    }

    /** Returns an unmodifiable view of the token list. */
    @Override
    public List<E> tokenList() {
        return Collections.unmodifiableList(this.mTokens);
    }

    /** Returns an unmodifiable view of the tag list. */
    @Override
    public List<String> tagList() {
        return Collections.unmodifiableList(this.mTags);
    }

    /** Returns the tag stored at position {@code id} of the tag list. */
    @Override
    public String tag(int id) {
        return this.mTags.get(id);
    }

    /** Returns the number of tags. */
    @Override
    public int numTags() {
        return this.mTags.size();
    }

    /** Returns the token stored at position {@code n} of the token list. */
    @Override
    public E token(int n) {
        return this.mTokens.get(n);
    }

    /** Returns the number of tokens. */
    @Override
    public int numTokens() {
        return this.mTokens.size();
    }

    /**
     * Returns a symbol table built from the tags, in tag-list order, by
     * {@code SymbolTableCompiler.asSymbolTable}. A new table is built on
     * every call.
     */
    @Override
    public SymbolTable tagSymbolTable() {
        String[] tagArray = this.mTags.toArray(Strings.EMPTY_STRING_ARRAY);
        return SymbolTableCompiler.asSymbolTable(tagArray);
    }

    /**
     * Returns the forward score plus the backward score for the given token
     * and tag, minus the log normalizer.
     */
    @Override
    public double logProbability(int token, int tag) {
        return this.mLogForwards[token][tag] + this.mLogBackwards[token][tag] - this.mLogZ;
    }

    /**
     * Returns the normalized log score of moving from {@code tagFrom} at
     * position {@code tokenTo - 1} to {@code tagTo} at position
     * {@code tokenTo}.
     *
     * @param tokenTo position of the target token; position
     *     {@code tokenTo - 1} is indexed as well, so values below 1 fail
     *     with an {@link ArrayIndexOutOfBoundsException}
     */
    @Override
    public double logProbability(int tokenTo, int tagFrom, int tagTo) {
        int tokenFrom = tokenTo - 1;
        return this.mLogForwards[tokenFrom][tagFrom] + this.mLogBackwards[tokenTo][tagTo] + this.mLogTransitions[tokenFrom][tagFrom][tagTo] - this.mLogZ;
    }

    /**
     * Returns the normalized log score of the given tag sequence starting at
     * position {@code tokenFrom}: the forward score of the first tag, the
     * backward score of the last tag and every transition in between, minus
     * the log normalizer.
     *
     * @param tokenFrom position of the first token of the span
     * @param tags tag identifiers for consecutive tokens; an empty array
     *     fails with an {@link ArrayIndexOutOfBoundsException}
     */
    @Override
    public double logProbability(int tokenFrom, int[] tags) {
        int startTag = tags[0];
        int endTag = tags[tags.length - 1];
        int tokenTo = tokenFrom + tags.length - 1;
        double logProb = this.mLogForwards[tokenFrom][startTag] + this.mLogBackwards[tokenTo][endTag] - this.mLogZ;
        for (int n = 1; n < tags.length; ++n) {
            logProb += this.mLogTransitions[tokenFrom + n - 1][tags[n - 1]][tags[n]];
        }
        return logProb;
    }

    /** Returns the stored forward score for the given token and tag. */
    @Override
    public double logForward(int token, int tag) {
        return this.mLogForwards[token][tag];
    }

    /** Returns the stored backward score for the given token and tag. */
    @Override
    public double logBackward(int token, int tag) {
        return this.mLogBackwards[token][tag];
    }

    /**
     * Returns the stored transition score from {@code tagFrom} at position
     * {@code tokenFrom} to {@code tagTo} at the following position.
     */
    @Override
    public double logTransition(int tokenFrom, int tagFrom, int tagTo) {
        return this.mLogTransitions[tokenFrom][tagFrom][tagTo];
    }

    /** Returns the log normalizer supplied at construction. */
    @Override
    public double logZ() {
        return this.mLogZ;
    }

    /**
     * Returns a multi-line dump of the tokens, the tags, the log normalizer
     * and every forward, backward and transition score.
     */
    @Override
    public String toString() {
        int numTokens = this.mTokens.size();
        int numTags = this.mTags.size();
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < numTokens; ++i) {
            sb.append("token[").append(i).append("]=").append(this.mTokens.get(i)).append("\n");
        }
        sb.append("\n");
        for (int k = 0; k < numTags; ++k) {
            sb.append("tag[").append(k).append("]=").append(this.mTags.get(k)).append("\n");
        }
        sb.append("\nlogZ=").append(this.logZ()).append("\n");
        sb.append("\nlogFwd[token][tag]\n");
        for (int i = 0; i < numTokens; ++i) {
            for (int k = 0; k < numTags; ++k) {
                sb.append("logFwd[").append(i).append("][").append(k).append("]=").append(this.logForward(i, k)).append("\n");
            }
        }
        sb.append("\nlogBk[token][tag]\n");
        for (int i = 0; i < numTokens; ++i) {
            for (int k = 0; k < numTags; ++k) {
                sb.append("logBk[").append(i).append("][").append(k).append("]=").append(this.logBackward(i, k)).append("\n");
            }
        }
        sb.append("\nlogTrans[tokenFrom][tagFrom][tagTo]\n");
        // There is one transition slice fewer than there are tokens.
        for (int from = 0; from + 1 < numTokens; ++from) {
            for (int kFrom = 0; kFrom < numTags; ++kFrom) {
                for (int kTo = 0; kTo < numTags; ++kTo) {
                    sb.append("logTrans[").append(from).append("][").append(kFrom).append("][").append(kTo).append("]=").append(this.logTransition(from, kFrom, kTo)).append("\n");
                }
            }
        }
        return sb.toString();
    }

    /**
     * Throws if {@code x} is NaN or strictly positive. Negative infinity is
     * accepted.
     *
     * @param var variable name used in the error message
     * @param x value to check
     * @throws IllegalArgumentException if the value is NaN or positive
     */
    static void verifyNonPos(String var, double x) {
        if (isNaNOrPositive(x)) {
            throw nonPosViolation(var, "", x);
        }
    }

    /**
     * Throws on the first element of {@code xs} that is NaN or strictly
     * positive.
     *
     * @param var variable name used in the error message
     * @param xs values to check
     * @throws IllegalArgumentException if any element is NaN or positive
     */
    static void verifyNonPos(String var, double[] xs) {
        for (int i = 0; i < xs.length; ++i) {
            if (isNaNOrPositive(xs[i])) {
                throw nonPosViolation(var, "[" + i + "]", xs[i]);
            }
        }
    }

    /**
     * Throws on the first element of {@code xs}, in row-major order, that is
     * NaN or strictly positive.
     *
     * @param var variable name used in the error message
     * @param xs values to check
     * @throws IllegalArgumentException if any element is NaN or positive
     */
    static void verifyNonPos(String var, double[][] xs) {
        for (int i = 0; i < xs.length; ++i) {
            for (int j = 0; j < xs[i].length; ++j) {
                if (isNaNOrPositive(xs[i][j])) {
                    throw nonPosViolation(var, "[" + i + "][" + j + "]", xs[i][j]);
                }
            }
        }
    }

    /**
     * Throws on the first element of {@code xs}, in row-major order, that is
     * NaN or strictly positive.
     *
     * @param var variable name used in the error message
     * @param xs values to check
     * @throws IllegalArgumentException if any element is NaN or positive
     */
    static void verifyNonPos(String var, double[][][] xs) {
        for (int i = 0; i < xs.length; ++i) {
            for (int j = 0; j < xs[i].length; ++j) {
                for (int k = 0; k < xs[i][j].length; ++k) {
                    if (isNaNOrPositive(xs[i][j][k])) {
                        throw nonPosViolation(var, "[" + i + "][" + j + "][" + k + "]", xs[i][j][k]);
                    }
                }
            }
        }
    }

    /** Returns true for the values the verifyNonPos methods reject. */
    private static boolean isNaNOrPositive(double x) {
        return Double.isNaN(x) || x > 0.0;
    }

    /** Builds the exception shared by all verifyNonPos overloads. */
    private static IllegalArgumentException nonPosViolation(String var, String index, double x) {
        return new IllegalArgumentException(var + " must be a non-positive number." + " Found " + var + index + "=" + x);
    }
}

