/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.lm;

import com.aliasi.corpus.ObjectHandler;
import com.aliasi.corpus.TextHandler;
import com.aliasi.io.BitInput;
import com.aliasi.io.BitOutput;
import com.aliasi.lm.BitTrieReader;
import com.aliasi.lm.BitTrieWriter;
import com.aliasi.lm.CompiledNGramProcessLM;
import com.aliasi.lm.LanguageModel;
import com.aliasi.lm.Node;
import com.aliasi.lm.TrieCharSeqCounter;
import com.aliasi.stats.Model;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Arrays;
import com.aliasi.util.Math;
import com.aliasi.util.Strings;
import java.io.Externalizable;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.LinkedList;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A character-sequence language model whose estimates come from a
 * {@link TrieCharSeqCounter} of substring counts.
 *
 * <p>Conditional estimates start from a uniform estimate of
 * {@code 1/numChars}. They are then interpolated with relative-frequency
 * estimates for successively longer contexts. The interpolation weight for a
 * context is {@code count / (count + lambdaFactor * numOutcomes)}.
 *
 * <p>Java serialization goes through {@link Serializer}, which writes the bit
 * format produced by {@link #writeTo(OutputStream)}. {@link #compileTo} writes
 * an {@link Externalizer} that is read back as a {@link CompiledNGramProcessLM}.
 */
public class NGramProcessLM
    implements TextHandler,
               Model<CharSequence>,
               LanguageModel.Process,
               LanguageModel.Conditional,
               LanguageModel.Dynamic,
               ObjectHandler<CharSequence>,
               Serializable {

    static final long serialVersionUID = -2865886217715962249L;

    /** Fixed-point scale used to store the lambda factor as an int in the bit format. */
    private static final double LAMBDA_FACTOR_SCALE = 1000000.0;

    /** Substring counts updated by the train methods; its max length fixes {@link #mMaxNGram}. */
    private final TrieCharSeqCounter mTrieCharSeqCounter;
    /** Maximum n-gram length, copied from the counter's {@code mMaxLength}. */
    private final int mMaxNGram;
    /** Interpolation hyperparameter; always finite and non-negative. */
    private double mLambdaFactor;
    /** Size of the character set used for the uniform base estimate. */
    private int mNumChars;
    /** {@code 1.0 / mNumChars}. */
    private double mUniformEstimate;
    /** {@code log2(mUniformEstimate)}, cached for compilation. */
    private double mLog2UniformEstimate;

    /**
     * Constructs a model with the given maximum n-gram length. It uses
     * {@code Character.MAX_VALUE} (65535) characters and a lambda factor equal
     * to {@code maxNGram}.
     *
     * @param maxNGram maximum n-gram length
     */
    public NGramProcessLM(int maxNGram) {
        this(maxNGram, Character.MAX_VALUE);
    }

    /**
     * Constructs a model with the given maximum n-gram length and character
     * set size, and a lambda factor equal to {@code maxNGram}.
     *
     * @param maxNGram maximum n-gram length
     * @param numChars number of distinct characters, in {@code [1, Character.MAX_VALUE]}
     */
    public NGramProcessLM(int maxNGram, int numChars) {
        this(maxNGram, numChars, maxNGram);
    }

    /**
     * Constructs a model with a fresh counter of the given maximum length.
     *
     * @param maxNGram maximum n-gram length
     * @param numChars number of distinct characters, in {@code [1, Character.MAX_VALUE]}
     * @param lambdaFactor finite, non-negative interpolation factor
     */
    public NGramProcessLM(int maxNGram, int numChars, double lambdaFactor) {
        this(numChars, lambdaFactor, new TrieCharSeqCounter(maxNGram));
    }

    /**
     * Constructs a model backed by an existing counter. The maximum n-gram
     * length is taken from the counter.
     *
     * @param numChars number of distinct characters, in {@code [1, Character.MAX_VALUE]}
     * @param lambdaFactor finite, non-negative interpolation factor
     * @param counter substring counter backing this model (used directly, not copied)
     * @throws IllegalArgumentException if {@code numChars} or {@code lambdaFactor} is out of range
     */
    public NGramProcessLM(int numChars, double lambdaFactor, TrieCharSeqCounter counter) {
        this.mMaxNGram = counter.mMaxLength;
        setLambdaFactor(lambdaFactor);
        setNumChars(numChars);
        this.mTrieCharSeqCounter = counter;
    }

    /**
     * Writes this model in the bit format read by {@link #readFrom(InputStream)}
     * and flushes the bit output.
     *
     * @param out stream to write to
     * @throws IOException if the underlying stream fails
     */
    public void writeTo(OutputStream out) throws IOException {
        BitOutput bitOut = new BitOutput(out);
        writeTo(bitOut);
        bitOut.flush();
    }

    /**
     * Writes, in order: the max n-gram, the number of characters, the lambda
     * factor as a fixed-point int (scaled by 1e6), and then the counter trie.
     */
    void writeTo(BitOutput bitOut) throws IOException {
        bitOut.writeDelta(mMaxNGram);
        bitOut.writeDelta(mNumChars);
        // Round instead of truncating. For example, 1.005 * 1e6 evaluates to
        // just under 1005000, and truncation would store 1004999. Using rint
        // keeps the double->int narrowing, which saturates at Integer.MAX_VALUE
        // rather than wrapping.
        // NOTE(review): factors that round to 0 are still handed to writeDelta;
        // confirm whether BitOutput accepts non-positive values.
        bitOut.writeDelta((int) java.lang.Math.rint(mLambdaFactor * LAMBDA_FACTOR_SCALE));
        BitTrieWriter trieWriter = new BitTrieWriter(bitOut);
        TrieCharSeqCounter.writeCounter(mTrieCharSeqCounter, trieWriter, mMaxNGram);
    }

    /**
     * Reads a model written by {@link #writeTo(OutputStream)}.
     *
     * @param in stream to read from
     * @return the reconstructed model
     * @throws IOException if the underlying stream fails
     */
    public static NGramProcessLM readFrom(InputStream in) throws IOException {
        BitInput bitIn = new BitInput(in);
        return readFrom(bitIn);
    }

    /** Reads the fields in the order written by {@link #writeTo(BitOutput)}. */
    static NGramProcessLM readFrom(BitInput bitIn) throws IOException {
        int maxNGram = (int) bitIn.readDelta();
        int numChars = (int) bitIn.readDelta();
        double lambdaFactor = (double) bitIn.readDelta() / LAMBDA_FACTOR_SCALE;
        BitTrieReader trieReader = new BitTrieReader(bitIn);
        TrieCharSeqCounter counter = TrieCharSeqCounter.readCounter(trieReader, maxNGram);
        return new NGramProcessLM(numChars, lambdaFactor, counter);
    }

    /** Returns {@link #log2Estimate(CharSequence)}. */
    @Override
    public double log2Prob(CharSequence cSeq) {
        return log2Estimate(cSeq);
    }

    /** Returns {@code 2^log2Estimate(cSeq)}. */
    @Override
    public double prob(CharSequence cSeq) {
        return java.lang.Math.pow(2.0, log2Estimate(cSeq));
    }

    /** Returns the log (base 2) estimate of the whole character sequence. */
    @Override
    public final double log2Estimate(CharSequence cSeq) {
        char[] cs = Strings.toCharArray(cSeq);
        return log2Estimate(cs, 0, cs.length);
    }

    /**
     * Returns the sum of the conditional log2 estimates of each character in
     * {@code cs[start, end)} given the characters before it in the slice.
     * An empty slice yields {@code 0.0}.
     */
    @Override
    public final double log2Estimate(char[] cs, int start, int end) {
        Strings.checkArgsStartEnd(cs, start, end);
        double sum = 0.0;
        for (int i = start + 1; i <= end; ++i) {
            sum += log2ConditionalEstimate(cs, start, i);
        }
        return sum;
    }

    /** Trains on the sequence with a count increment of 1. */
    @Override
    public void train(CharSequence cSeq) {
        train(cSeq, 1);
    }

    /** Trains on the sequence, incrementing each substring count by {@code incr}. */
    @Override
    public void train(CharSequence cSeq, int incr) {
        char[] cs = Strings.toCharArray(cSeq);
        train(cs, 0, cs.length, incr);
    }

    /** Trains on {@code cs[start, end)} with a count increment of 1. */
    @Override
    public void train(char[] cs, int start, int end) {
        train(cs, start, end, 1);
    }

    /** Increments the substring counts of {@code cs[start, end)} by {@code incr}. */
    @Override
    public void train(char[] cs, int start, int end, int incr) {
        Strings.checkArgsStartEnd(cs, start, end);
        mTrieCharSeqCounter.incrementSubstrings(cs, start, end, incr);
    }

    /**
     * Trains on {@code cs[start, start + length)}.
     *
     * @deprecated use {@link #handle(CharSequence)} or {@link #train(char[], int, int)}
     */
    @Override
    @Deprecated
    public void handle(char[] cs, int start, int length) {
        train(cs, start, start + length);
    }

    /** Trains on the sequence; equivalent to {@link #train(CharSequence)}. */
    @Override
    public void handle(CharSequence cSeq) {
        train(cSeq);
    }

    /**
     * Increments the substring counts of {@code cs[start, end)} and then
     * decrements those of the prefix {@code cs[start, condEnd)}. Does nothing
     * if {@code condEnd == end}.
     *
     * @throws IllegalArgumentException if {@code condEnd > end}
     */
    public void trainConditional(char[] cs, int start, int end, int condEnd) {
        Strings.checkArgsStartEnd(cs, start, end);
        Strings.checkArgsStartEnd(cs, start, condEnd);
        if (condEnd > end) {
            // condEnd == end is accepted (no-op), so the bound is inclusive.
            String msg = "Conditional end must be <= end. Found condEnd=" + condEnd + " end=" + end;
            throw new IllegalArgumentException(msg);
        }
        if (condEnd == end) {
            return;
        }
        mTrieCharSeqCounter.incrementSubstrings(cs, start, end);
        mTrieCharSeqCounter.decrementSubstrings(cs, start, condEnd);
    }

    /** Returns the characters observed by the underlying counter. */
    @Override
    public char[] observedCharacters() {
        return mTrieCharSeqCounter.observedCharacters();
    }

    /** Writes a compiled form of this model; it reads back as a {@link CompiledNGramProcessLM}. */
    @Override
    public void compileTo(ObjectOutput objOut) throws IOException {
        objOut.writeObject(new Externalizer(this));
    }

    /** Conditional log2 estimate of the last character using this model's settings. */
    @Override
    public double log2ConditionalEstimate(CharSequence cSeq) {
        return log2ConditionalEstimate(cSeq, mMaxNGram, mLambdaFactor);
    }

    /** Conditional log2 estimate of {@code cs[end-1]} using this model's settings. */
    @Override
    public double log2ConditionalEstimate(char[] cs, int start, int end) {
        return log2ConditionalEstimate(cs, start, end, mMaxNGram, mLambdaFactor);
    }

    /** Returns the underlying substring counter (live, not a copy). */
    public TrieCharSeqCounter substringCounter() {
        return mTrieCharSeqCounter;
    }

    /** Returns the maximum n-gram length. */
    public int maxNGram() {
        return mMaxNGram;
    }

    /**
     * Conditional log2 estimate of the last character of {@code cSeq}, using
     * the given n-gram bound and lambda factor.
     */
    public double log2ConditionalEstimate(CharSequence cSeq, int maxNGram, double lambdaFactor) {
        char[] cs = Strings.toCharArray(cSeq);
        return log2ConditionalEstimate(cs, 0, cs.length, maxNGram, lambdaFactor);
    }

    /**
     * Returns the log (base 2) estimate of {@code cs[end-1]} given the
     * preceding characters in {@code cs[start, end-1)}.
     *
     * <p>The estimate starts at {@code 1/numChars}. Contexts ending at
     * {@code end-1} are then visited from the empty context up to length
     * {@code min(maxNGram, maxNGram()) - 1}. For each one, the estimate is
     * replaced by {@code lambda * count(context + c) / extensionCount(context)
     * + (1 - lambda) * estimate}. Backing off stops at the first context whose
     * extension count is zero.
     *
     * @throws IllegalArgumentException if {@code end <= start}, the indices are
     *     out of bounds, {@code maxNGram < 1}, or {@code lambdaFactor} is invalid
     */
    public double log2ConditionalEstimate(char[] cs, int start, int end, int maxNGram, double lambdaFactor) {
        if (end <= start) {
            String msg = "Conditional estimates require at least one character.";
            throw new IllegalArgumentException(msg);
        }
        Strings.checkArgsStartEnd(cs, start, end);
        checkMaxNGram(maxNGram);
        checkLambdaFactor(lambdaFactor);
        int maxUsableNGram = java.lang.Math.min(maxNGram, mMaxNGram);
        int contextEnd = end - 1;
        int longestContextStart = java.lang.Math.max(start, end - maxUsableNGram);
        double currentEstimate = mUniformEstimate;
        for (int contextStart = contextEnd; contextStart >= longestContextStart; --contextStart) {
            long contextCount = mTrieCharSeqCounter.extensionCount(cs, contextStart, contextEnd);
            if (contextCount == 0L) {
                break; // unseen context: stop extending to longer contexts
            }
            long outcomeCount = mTrieCharSeqCounter.count(cs, contextStart, end);
            double lambda = lambda(cs, contextStart, contextEnd, lambdaFactor);
            currentEstimate = lambda * ((double) outcomeCount / (double) contextCount)
                + (1.0 - lambda) * currentEstimate;
        }
        return Math.log2(currentEstimate);
    }

    /** Interpolation weight for context {@code cs[start, end)} using this model's lambda factor. */
    double lambda(char[] cs, int start, int end) {
        return lambda(cs, start, end, getLambdaFactor());
    }

    /**
     * Interpolation weight for context {@code cs[start, end)}. Returns 0 when
     * the context has no extensions.
     */
    double lambda(char[] cs, int start, int end, double lambdaFactor) {
        checkLambdaFactor(lambdaFactor);
        Strings.checkArgsStartEnd(cs, start, end);
        double count = mTrieCharSeqCounter.extensionCount(cs, start, end);
        if (count <= 0.0) {
            return 0.0;
        }
        double numOutcomes = mTrieCharSeqCounter.numCharactersFollowing(cs, start, end);
        return lambda(count, numOutcomes, lambdaFactor);
    }

    /** Returns the lambda factor. */
    public double getLambdaFactor() {
        return mLambdaFactor;
    }

    /**
     * Sets the lambda factor.
     *
     * @throws IllegalArgumentException if the factor is negative, infinite, or NaN
     */
    public final void setLambdaFactor(double lambdaFactor) {
        checkLambdaFactor(lambdaFactor);
        mLambdaFactor = lambdaFactor;
    }

    /**
     * Sets the number of characters and recomputes the uniform base estimate
     * {@code 1/numChars} and its log2.
     *
     * @throws IllegalArgumentException if {@code numChars < 1} or
     *     {@code numChars > Character.MAX_VALUE}
     */
    public final void setNumChars(int numChars) {
        checkNumChars(numChars);
        mNumChars = numChars;
        mUniformEstimate = 1.0 / (double) mNumChars;
        mLog2UniformEstimate = Math.log2(mUniformEstimate);
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        toStringBuilder(sb);
        return sb.toString();
    }

    /** Appends the max n-gram, number of characters, and the counter's trie dump. */
    void toStringBuilder(StringBuilder sb) {
        sb.append("Max NGram=").append(mMaxNGram).append(' ')
          .append("Num characters=").append(mNumChars).append('\n')
          .append("Trie of counts=\n");
        mTrieCharSeqCounter.toStringBuilder(sb);
    }

    /** Decrements the count of unigram {@code c} by 1. */
    void decrementUnigram(char c) {
        decrementUnigram(c, 1);
    }

    /** Decrements the count of unigram {@code c} by {@code count}. */
    void decrementUnigram(char c, int count) {
        mTrieCharSeqCounter.decrementUnigram(c, count);
    }

    /** {@code count / (count + lambdaFactor * numOutcomes)}; NaN when both counts are 0. */
    private double lambda(double count, double numOutcomes, double lambdaFactor) {
        return count / (count + lambdaFactor * numOutcomes);
    }

    /** Interpolation weight for the context represented by {@code node}. */
    private double lambda(Node node) {
        double count = node.contextCount(Strings.EMPTY_CHAR_ARRAY, 0, 0);
        double numOutcomes = node.numOutcomes(Strings.EMPTY_CHAR_ARRAY, 0, 0);
        return lambda(count, numOutcomes, mLambdaFactor);
    }

    /**
     * Returns the breadth-first index (root = 0) of the last trie node with
     * at least one outcome. Returns 0 if no node has outcomes.
     */
    private int lastInternalNodeIndex() {
        int last = 1;
        LinkedList<Node> queue = new LinkedList<Node>();
        queue.add(mTrieCharSeqCounter.mRootNode);
        for (int i = 1; !queue.isEmpty(); ++i) {
            Node node = queue.removeFirst();
            if (node.numOutcomes(Strings.EMPTY_CHAR_ARRAY, 0, 0) > 0) {
                last = i;
            }
            node.addDaughters(queue);
        }
        return last - 1;
    }

    /** Serializes through {@link Serializer}, which uses the bit format. */
    private Object writeReplace() {
        return new Serializer(this);
    }

    /** @throws IllegalArgumentException if the factor is negative, infinite, or NaN */
    static void checkLambdaFactor(double lambdaFactor) {
        if (lambdaFactor < 0.0 || Double.isInfinite(lambdaFactor) || Double.isNaN(lambdaFactor)) {
            String msg = "Lambda factor must be ordinary non-negative double. Found lambdaFactor=" + lambdaFactor;
            throw new IllegalArgumentException(msg);
        }
    }

    /** @throws IllegalArgumentException if {@code maxNGram < 1} */
    static void checkMaxNGram(int maxNGram) {
        if (maxNGram < 1) {
            String msg = "Maximum n-gram must be greater than zero. Found max n-gram=" + maxNGram;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Rejects character counts outside {@code [1, Character.MAX_VALUE]}.
     * Zero is rejected because it would make the uniform estimate
     * {@code 1/0 = Infinity}.
     */
    private static void checkNumChars(int numChars) {
        if (numChars < 1 || numChars > Character.MAX_VALUE) {
            String msg = "Number of characters must be > 0 and <= Character.MAX_VALUE. Found numChars=" + numChars;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Writes the compiled form of a model. {@link #read} reconstructs it as a
     * {@link CompiledNGramProcessLM}.
     */
    static class Externalizer extends AbstractExternalizable {
        static final long serialVersionUID = -3623859317152451545L;
        final NGramProcessLM mLM;

        /** No-arg constructor required for deserialization. */
        public Externalizer() {
            this(null);
        }

        public Externalizer(NGramProcessLM lm) {
            this.mLM = lm;
        }

        public Object read(ObjectInput in) throws IOException {
            return new CompiledNGramProcessLM(in);
        }

        /**
         * Writes a header and then the nodes in breadth-first order.
         *
         * <p>Header: max n-gram, log2 uniform estimate, node count, and last
         * internal node index. Root node (index 0): {@code Character.MAX_VALUE}
         * as its char, the log2 uniform estimate, {@code log2(1 - lambda)}
         * (0 if NaN), and first-child index 1. Every other node writes its last
         * char and its conditional log2 estimate. If its index is at most the
         * last internal index, it also writes {@code log2(1 - lambda)} and its
         * first-child index.
         *
         * @throws IllegalArgumentException if the node count exceeds {@code Integer.MAX_VALUE}
         */
        public void writeExternal(ObjectOutput dataOut) throws IOException {
            NGramProcessLM lm = mLM;
            TrieCharSeqCounter counter = lm.mTrieCharSeqCounter;

            dataOut.writeInt(lm.mMaxNGram);
            dataOut.writeFloat((float) lm.mLog2UniformEstimate);
            long numNodes = counter.uniqueSequenceCount();
            if (numNodes > Integer.MAX_VALUE) {
                String msg = "Maximum number of compiled nodes is Integer.MAX_VALUE = 2147483647 Found number of nodes=" + numNodes;
                throw new IllegalArgumentException(msg);
            }
            dataOut.writeInt((int) numNodes);
            int lastInternalNodeIndex = lm.lastInternalNodeIndex();
            dataOut.writeInt(lastInternalNodeIndex);

            // Root node (index 0): sentinel char, uniform estimate, backoff weight, first child at 1.
            dataOut.writeChar(Character.MAX_VALUE);
            dataOut.writeFloat((float) lm.mLog2UniformEstimate);
            double rootOneMinusLambda = 1.0 - lm.lambda(counter.mRootNode);
            // lambda is 0/0 = NaN for an untrained root; write 0 in that case.
            float rootLog2OneMinusLambda = Double.isNaN(rootOneMinusLambda)
                ? 0.0f
                : (float) Math.log2(rootOneMinusLambda);
            dataOut.writeFloat(rootLog2OneMinusLambda);
            dataOut.writeInt(1);

            char[] unigrams = counter.observedCharacters();
            LinkedList<char[]> queue = new LinkedList<char[]>();
            for (int i = 0; i < unigrams.length; ++i) {
                queue.add(new char[] { unigrams[i] });
            }
            for (int index = 1; !queue.isEmpty(); ++index) {
                char[] nGram = queue.removeFirst();
                dataOut.writeChar(nGram[nGram.length - 1]);
                dataOut.writeFloat((float) lm.log2ConditionalEstimate(nGram, 0, nGram.length));
                if (index <= lastInternalNodeIndex) {
                    double oneMinusLambda = 1.0 - lm.lambda(nGram, 0, nGram.length);
                    dataOut.writeFloat((float) Math.log2(oneMinusLambda));
                    // Nodes index+1 .. index+queue.size() are already queued,
                    // so this node's children start right after them.
                    dataOut.writeInt(index + queue.size() + 1);
                }
                char[] following = counter.charactersFollowing(nGram, 0, nGram.length);
                for (int i = 0; i < following.length; ++i) {
                    queue.add(Arrays.concatenate(nGram, following[i]));
                }
            }
        }
    }

    /**
     * Java-serialization proxy that stores a model in the bit format of
     * {@link NGramProcessLM#writeTo(OutputStream)}.
     */
    static class Serializer implements Externalizable {
        static final long serialVersionUID = -7101238964823109652L;
        NGramProcessLM mLM;

        /** No-arg constructor required by {@link Externalizable}. */
        public Serializer() {
        }

        public Serializer(NGramProcessLM lm) {
            this.mLM = lm;
        }

        // NOTE(review): the casts assume the stream objects are
        // ObjectOutputStream / ObjectInputStream, which extend OutputStream /
        // InputStream; any other implementation throws ClassCastException.
        public void writeExternal(ObjectOutput out) throws IOException {
            mLM.writeTo((OutputStream) out);
        }

        public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
            mLM = NGramProcessLM.readFrom((InputStream) in);
        }

        /** Replaces this proxy with the deserialized model. */
        public Object readResolve() {
            return mLM;
        }
    }
}

