/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.hmm;

import com.aliasi.hmm.AbstractHmmEstimator;
import com.aliasi.hmm.CompiledHmmCharLm;
import com.aliasi.lm.NGramBoundaryLM;
import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Exceptions;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.Tuple;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * An HMM estimator whose per-state emission model is a character n-gram
 * boundary language model, and whose start, end and transition
 * probabilities are relative frequencies of the training counts.
 *
 * <p>A state is registered in the state symbol table the first time any
 * training method sees it. If state smoothing is enabled, each newly
 * registered state also receives one start count and one end count. It
 * additionally receives one transition count to and one from every state
 * already known, plus a single self-transition count.
 */
public class HmmCharLmEstimator
extends AbstractHmmEstimator {
    // The estimator's state symbol table, cast to its concrete type (it is created in the constructor).
    private final MapSymbolTable mStateMapSymbolTable;
    // Per-state count of outgoing transitions PLUS ends; denominator of transitProb,
    // so transition estimates leave probability mass for ending at the state.
    private final ObjectToCounterMap<String> mStateExtensionCounter = new ObjectToCounterMap<String>();
    // Counts of (source, target) transitions.
    private final ObjectToCounterMap<Tuple<String>> mStatePairCounter = new ObjectToCounterMap<Tuple<String>>();
    // Lazily created emission LM for each state.
    private final Map<String, NGramBoundaryLM> mStateToLm = new HashMap<String, NGramBoundaryLM>();
    private final double mCharLmInterpolation;
    private final int mCharLmMaxNGram;
    private final int mMaxCharacters;
    // Total number of starts trained, and per-state start counts.
    private int mNumStarts = 0;
    private final ObjectToCounterMap<String> mStartCounter = new ObjectToCounterMap<String>();
    // Total number of ends trained, and per-state end counts.
    private int mNumEnds = 0;
    private final ObjectToCounterMap<String> mEndCounter = new ObjectToCounterMap<String>();
    private final boolean mSmootheStates;
    // Every state seen so far by any training method.
    final Set<String> mStateSet = new HashSet<String>();

    /**
     * Constructs an estimator with maximum n-gram length 6, 65534 maximum
     * characters, interpolation parameter 6.0 and no state smoothing.
     */
    public HmmCharLmEstimator() {
        this(6, 65534, 6.0);
    }

    /**
     * Constructs an estimator with the given emission LM parameters and no
     * state smoothing.
     *
     * @param charLmMaxNGram maximum n-gram length of each emission LM; at least 1
     * @param maxCharacters maximum number of characters of each emission LM; 1 to 65534
     * @param charLmInterpolation interpolation parameter of each emission LM; non-negative
     * @throws IllegalArgumentException if any parameter is out of range
     */
    public HmmCharLmEstimator(int charLmMaxNGram, int maxCharacters, double charLmInterpolation) {
        this(charLmMaxNGram, maxCharacters, charLmInterpolation, false);
    }

    /**
     * Constructs an estimator with the given emission LM parameters and
     * state-smoothing flag.
     *
     * @param charLmMaxNGram maximum n-gram length of each emission LM; at least 1
     * @param maxCharacters maximum number of characters of each emission LM; 1 to 65534
     * @param charLmInterpolation interpolation parameter of each emission LM;
     *     must be non-negative and not NaN
     * @param smootheStates if true, add one pseudo-count of start, end and
     *     transitions to/from known states whenever a new state is first seen
     * @throws IllegalArgumentException if any parameter is out of range
     */
    public HmmCharLmEstimator(int charLmMaxNGram, int maxCharacters, double charLmInterpolation, boolean smootheStates) {
        super(new MapSymbolTable());
        if (charLmMaxNGram < 1) {
            String msg = "Max n-gram must be greater than 0. Found charLmMaxNGram=" + charLmMaxNGram;
            throw new IllegalArgumentException(msg);
        }
        // 65534 (not 65535): presumably leaves room for the '\uffff' char passed to each LM — confirm.
        if (maxCharacters < 1 || maxCharacters > 65534) {
            String msg = "Require between 1 and 65534 max characters. Found maxCharacters=" + maxCharacters;
            throw new IllegalArgumentException(msg);
        }
        // NaN fails every comparison, so it must be rejected explicitly.
        if (Double.isNaN(charLmInterpolation) || charLmInterpolation < 0.0) {
            String msg = "Char interpolation param must be non-negative. Found charLmInterpolation=" + charLmInterpolation;
            throw new IllegalArgumentException(msg);
        }
        this.mSmootheStates = smootheStates;
        this.mStateMapSymbolTable = (MapSymbolTable)this.stateSymbolTable();
        this.mCharLmInterpolation = charLmInterpolation;
        this.mCharLmMaxNGram = charLmMaxNGram;
        this.mMaxCharacters = maxCharacters;
    }

    /**
     * Registers a state the first time it is seen and, if smoothing is
     * enabled, adds its pseudo-counts. Does nothing for an already-known state.
     *
     * @param state state to register
     */
    void addStateSmoothe(String state) {
        if (!this.mStateSet.add(state)) {
            return;
        }
        this.mStateMapSymbolTable.getOrAddSymbol(state);
        if (!this.mSmootheStates) {
            return;
        }
        // The state is already in mStateSet, so the nested addStateSmoothe calls made by
        // trainStart/trainEnd/trainTransit return immediately and never modify the set
        // while it is being iterated below.
        this.trainStart(state);
        this.trainEnd(state);
        for (String state2 : this.mStateSet) {
            this.trainTransit(state, state2);
            if (state.equals(state2)) continue;  // count the self-transition only once
            this.trainTransit(state2, state);
        }
    }

    /**
     * Trains one occurrence of starting in the given state. A null state is ignored.
     *
     * @param state start state
     */
    public void trainStart(String state) {
        if (state == null) {
            return;
        }
        this.addStateSmoothe(state);
        ++this.mNumStarts;
        this.mStartCounter.increment(state);
    }

    /**
     * Checks that a count is non-negative.
     *
     * @param count count to check
     * @throws IllegalArgumentException if {@code count < 0}
     */
    static void verifyNonNegativeCount(int count) {
        if (count >= 0) {
            return;
        }
        String msg = "Counts must be non-negative. Found count=" + count;
        throw new IllegalArgumentException(msg);
    }

    /**
     * Trains one occurrence of ending in the given state. A null state is
     * ignored. The end is also counted as an extension of the state, which
     * lowers that state's transition estimates.
     *
     * @param state end state
     */
    public void trainEnd(String state) {
        if (state == null) {
            return;
        }
        this.addStateSmoothe(state);
        this.mStateExtensionCounter.increment(state);
        ++this.mNumEnds;
        this.mEndCounter.increment(state);
    }

    /**
     * Trains the given state's emission LM on the given characters. Does
     * nothing if either argument is null.
     *
     * @param state emitting state
     * @param emission emitted characters
     */
    public void trainEmit(String state, CharSequence emission) {
        if (state == null) {
            return;
        }
        if (emission == null) {
            return;
        }
        this.addStateSmoothe(state);
        this.emissionLm(state).train(emission);
    }

    /**
     * Trains one transition from a source state to a target state. Does
     * nothing if either state is null.
     *
     * @param sourceState state transitioned from
     * @param targetState state transitioned to
     */
    public void trainTransit(String sourceState, String targetState) {
        if (sourceState == null || targetState == null) {
            return;
        }
        this.addStateSmoothe(sourceState);
        this.addStateSmoothe(targetState);
        this.mStateExtensionCounter.increment(sourceState);
        this.mStatePairCounter.increment(Tuple.create(sourceState, targetState));
    }

    /**
     * Returns the fraction of trained starts that started in the given state.
     *
     * @param state state
     * @return start probability, or 0.0 if no starts have been trained
     */
    public double startProb(String state) {
        return relativeFrequency(this.mStartCounter.getCount(state), this.mNumStarts);
    }

    /**
     * Returns the fraction of trained ends that ended in the given state.
     *
     * @param state state
     * @return end probability, or 0.0 if no ends have been trained
     */
    public double endProb(String state) {
        return relativeFrequency(this.mEndCounter.getCount(state), this.mNumEnds);
    }

    /**
     * Returns the number of source-to-target transitions divided by the
     * number of transitions out of plus ends at the source state.
     *
     * @param source state transitioned from
     * @param target state transitioned to
     * @return transition probability, or 0.0 if the source was never extended
     */
    public double transitProb(String source, String target) {
        double extCount = this.mStateExtensionCounter.getCount(source);
        double pairCount = this.mStatePairCounter.getCount(Tuple.create(source, target));
        return relativeFrequency(pairCount, extCount);
    }

    // Returns count / total, or 0.0 when total is zero (avoids 0/0 = NaN).
    private static double relativeFrequency(double count, double total) {
        return total == 0.0 ? 0.0 : count / total;
    }

    /**
     * Returns the probability of the emission in the given state, computed
     * as 2 raised to {@link #emitLog2Prob(String, CharSequence)}.
     *
     * @param state emitting state
     * @param emission emitted characters
     * @return emission probability
     */
    public double emitProb(String state, CharSequence emission) {
        return Math.pow(2.0, this.emitLog2Prob(state, emission));
    }

    /**
     * Returns the base-2 log estimate of the emission under the state's emission LM.
     *
     * @param state emitting state
     * @param emission emitted characters
     * @return log (base 2) emission estimate
     */
    public double emitLog2Prob(String state, CharSequence emission) {
        return this.emissionLm(state).log2Estimate(emission);
    }

    /**
     * Returns the emission LM for the given state. The LM is created and
     * stored on first request, so calling this for an unseen state creates
     * an empty LM for it.
     *
     * @param state state
     * @return the state's emission LM
     */
    public NGramBoundaryLM emissionLm(String state) {
        NGramBoundaryLM lm = this.mStateToLm.get(state);
        if (lm == null) {
            // NOTE(review): '\uffff' is presumably the LM's boundary character — confirm.
            lm = new NGramBoundaryLM(this.mCharLmMaxNGram, this.mMaxCharacters, this.mCharLmInterpolation, '\uffff');
            this.mStateToLm.put(state, lm);
        }
        return lm;
    }

    /**
     * Writes a compiled form of this estimator, which reads back as a
     * {@code CompiledHmmCharLm}.
     *
     * @param objOut output to write to
     * @throws IOException if the underlying write fails
     */
    public void compileTo(ObjectOutput objOut) throws IOException {
        objOut.writeObject(new Externalizer(this));
    }

    /**
     * Serialization proxy that writes an estimator's compiled form and reads
     * it back as a {@code CompiledHmmCharLm}.
     */
    static class Externalizer
    extends AbstractExternalizable {
        private static final long serialVersionUID = 8463739963673120677L;
        final HmmCharLmEstimator mEstimator;

        /** No-arg constructor used when deserializing. */
        public Externalizer() {
            this(null);
        }

        /**
         * @param estimator estimator whose compiled form is written
         */
        public Externalizer(HmmCharLmEstimator estimator) {
            this.mEstimator = estimator;
        }

        /**
         * Reads the compiled model written by {@link #writeExternal(ObjectOutput)}.
         *
         * @param in input to read from
         * @return the compiled HMM
         * @throws IOException if reading fails, including a wrapped ClassNotFoundException
         */
        public Object read(ObjectInput in) throws IOException {
            try {
                return new CompiledHmmCharLm(in);
            }
            catch (ClassNotFoundException e) {
                throw Exceptions.toIO("HmmCharLmEstimator.compileTo()", e);
            }
        }

        /**
         * Writes, in order: the state symbol table; numStates x numStates
         * transition probabilities (source-major); each state's compiled
         * emission LM in id order; start probabilities; end probabilities.
         * The order must match what {@code CompiledHmmCharLm} reads.
         *
         * @param objOut output to write to
         * @throws IOException if the underlying write fails
         */
        public void writeExternal(ObjectOutput objOut) throws IOException {
            HmmCharLmEstimator estimator = this.mEstimator;
            MapSymbolTable symbols = estimator.mStateMapSymbolTable;
            objOut.writeObject(symbols);
            int numStates = symbols.numSymbols();
            // Written at full double precision; the old (float) cast here silently
            // truncated transitions while start/end probabilities kept full precision.
            // NOTE(review): transitProb(int, int) is presumed to map ids to symbols and
            // delegate to transitProb(String, String) — confirm in the superclass.
            for (int source = 0; source < numStates; ++source) {
                for (int target = 0; target < numStates; ++target) {
                    objOut.writeDouble(estimator.transitProb(source, target));
                }
            }
            for (int id = 0; id < numStates; ++id) {
                String state = symbols.idToSymbol(id);
                estimator.emissionLm(state).compileTo(objOut);
            }
            for (int id = 0; id < numStates; ++id) {
                objOut.writeDouble(estimator.startProb(id));
            }
            for (int id = 0; id < numStates; ++id) {
                objOut.writeDouble(estimator.endProb(id));
            }
        }
    }
}

