'''
    Code to construct a KMP DFA and an implementation of weighted automata.

    This code was adapted, with minor changes, from Cognetta et al. (2018).
'''


import copy, glob, nltk, time, math
import numpy as np

'''
Implementation of the KMP failure function, used to generate a KMP DFA.
'''
def fail_function(W, charset, FIRST=False):
    """Build the transition table of the KMP DFA for the pattern ``W``.

    For ``0 <= i < len(W)``, state ``i`` means that the longest suffix of the
    input read so far that is also a prefix of ``W`` has length ``i``. State
    ``len(W)`` is reached once ``W`` has been read.

    Args:
        W: The pattern; any indexable sequence of symbols (e.g. a string).
            May be empty, in which case the DFA consists of state 0 only.
        charset: The alphabet. This is any iterable of hashable symbols, and it
            must contain every symbol of ``W``. It is read once.
        FIRST: If False (default), state ``len(W)`` loops to itself on every
            symbol of ``charset``. If True, state ``len(W)`` is given no
            outgoing transitions (an empty dict).

    Returns:
        A list ``FF`` of ``len(W) + 1`` dicts. ``FF[i][c]`` is the state reached
        from state ``i`` on symbol ``c``.

    Raises:
        ValueError: if ``W`` contains a symbol that is not in ``charset``.
    """
    alphabet = list(charset)
    known = set(alphabet)
    unknown = [c for c in W if c not in known]
    if unknown:
        raise ValueError(f"pattern symbol {unknown[0]!r} is not in charset")

    # KMP prefix (failure) function: F[i] is the length of the longest proper
    # prefix of W[:i + 1] that is also a suffix of it.
    F = [0] * len(W)
    k = 0
    for i in range(1, len(W)):
        while k > 0 and W[i] != W[k]:
            k = F[k - 1]
        if W[i] == W[k]:
            k += 1
        F[i] = k

    FF = []
    for i, symbol in enumerate(W):
        row = {}
        for c in alphabet:
            if c == symbol:
                row[c] = i + 1
            elif i == 0:
                row[c] = 0
            else:
                # On a mismatch, state i behaves exactly like the state it
                # falls back to. That row is already built, because F[i-1] < i.
                row[c] = FF[F[i - 1]][c]
        FF.append(row)

    FF.append({} if FIRST else {c: len(W) for c in alphabet})
    return FF

'''
These two functions post-process an input DFA specification depending on whether it will be used for intersection
or for the state elimination method (the two use different state indexing).
'''
def post_process(FF):
    """Key each state's transition dict by its 0-based position in ``FF``.

    Returns a dict ``{i: FF[i]}``. The per-state dicts are the same objects
    as in ``FF``; they are not copied.
    """
    return dict(enumerate(FF))

def post_process_add_one(FF):
    """Re-index a DFA transition table with 1-based states and list targets.

    State ``i`` of ``FF`` becomes key ``i + 1``. Every transition target ``t``
    becomes the single-element list ``[t + 1]``.

    Returns:
        A new dict ``{i + 1: {c: [FF[i][c] + 1], ...}, ...}``.
    """
    # Build fresh row dicts rather than rewriting FF's rows in place. Otherwise
    # the caller's table is corrupted and a second call fails on `list += 1`.
    return {
        i + 1: {c: [target + 1] for c, target in row.items()}
        for i, row in enumerate(FF)
    }

'''
Basic weighted automaton class.
PFAs are a special case of weighted automata, so instead of a separate class we simply have a parameter that verifies
whether a weighted automaton (over the real semiring) is a PFA.
'''
class WA(object):
    """A weighted automaton over the real semiring, stored as NumPy matrices.

    When built from ``delta``/``initial``/``final``, ``self.matrices`` holds:
      * ``'initial'``: a ``(1, states)`` row vector of initial weights;
      * ``'final'``: a ``(states, 1)`` column vector of final weights;
      * one ``(states, states)`` matrix per symbol in ``sigma``, omitted when
        ``just_sigma=True``;
      * ``'sigma'``: the sum of all symbols' transition weights;
      * ``'id'`` / ``'identity'``: the identity matrix;
      * ``'sigma_star'``: ``inv(I - sigma)``. This equals sum_k sigma^k
        whenever that series converges.
    When ``matrices`` is passed, it is stored as given.
    """

    def __init__(self, states=None, sigma=None, matrices=None, delta=None, initial=None, final=None, isPFA=False, verify=False, just_sigma=False):
        """Create the automaton from a precomputed matrix dict or from transitions.

        Args:
            states: Number of states. States are the integers ``0..states-1``.
            sigma: The alphabet (an iterable of symbols).
            matrices: A precomputed matrix dict. It takes precedence over
                ``delta``/``initial``/``final``.
            delta: Nested mappings with ``delta[q][c][p]`` equal to the weight
                of the transition ``q --c--> p``.
            initial: Per-state initial weights, indexable by state.
            final: Per-state final weights, indexable by state.
            isPFA: Together with ``verify``, requests a PFA check.
            verify: Together with ``isPFA``, requests a PFA check.
            just_sigma: Only build the summed ``'sigma'`` matrix, not the
                per-symbol matrices.

        Raises:
            ValueError: if ``matrices`` is None and any of ``delta``,
                ``initial``, ``final`` is None. Also raised if ``isPFA`` and
                ``verify`` are set and the PFA check fails.
            numpy.linalg.LinAlgError: if ``I - sigma`` is singular.
        """
        self.sigma = sigma
        self.states = states
        # Raw transition structure; None when only `matrices` was supplied.
        self.forward_transitions = delta

        if matrices is not None:
            self.matrices = matrices
        elif any(x is None for x in (delta, initial, final)):
            raise ValueError("WA needs either `matrices` or all of `delta`, `initial` and `final`")
        else:
            self.matrices = self._build_matrices(delta, initial, final, just_sigma)

        if isPFA and verify and not self.verify_PFA():
            raise ValueError("automaton does not satisfy the PFA conditions")

    def _build_matrices(self, delta, initial, final, just_sigma):
        """Assemble the matrix dict from explicit transitions and weight vectors."""
        n = self.states
        matrices = {
            'initial': np.zeros((1, n)),
            'final': np.zeros((n, 1)),
        }
        for i in range(n):
            matrices['final'][i, 0] = final[i]
            matrices['initial'][0, i] = initial[i]

        identity = np.identity(n)

        if just_sigma:
            # Accumulate every symbol's weights into one matrix. No per-symbol
            # matrices are kept.
            total = np.zeros((n, n))
            for c in self.sigma:
                for q in delta:
                    if c in delta[q]:
                        for p in delta[q][c]:
                            total[q, p] += delta[q][c][p]
            matrices['sigma'] = total
        else:
            for c in self.sigma:
                m = np.zeros((n, n))
                for q in delta:
                    if c in delta[q]:
                        for p in delta[q][c]:
                            m[q, p] = delta[q][c][p]
                matrices[c] = m
            matrices['sigma'] = sum(matrices[c] for c in self.sigma)

        matrices['id'] = identity
        matrices['identity'] = identity
        matrices['sigma_star'] = np.linalg.inv(identity - matrices['sigma'])
        return matrices

    def verify_PFA(self):
        """Return True iff the automaton satisfies the PFA conditions.

        All of the following must hold; sums are compared with ``math.isclose``:
          * no initial, final or transition weight is negative;
          * the initial weights sum to 1;
          * for every state, the outgoing transition weights over all symbols
            plus the state's final weight sum to 1;
          * ``initial @ sigma_star @ final`` equals 1.
        If the per-symbol matrices are absent (``just_sigma=True``), the summed
        ``'sigma'`` matrix is used for the transition checks.
        """
        m = self.matrices
        n = self.states
        initial = m['initial']
        final = m['final']
        if all(c in m for c in self.sigma):
            transitions = [m[c] for c in self.sigma]
        else:
            transitions = [m['sigma']]

        if any(np.any(x < 0) for x in (initial, final, *transitions)):
            return False

        if not math.isclose(sum(initial[0, s] for s in range(n)), 1.0):
            return False

        for s in range(n):
            outgoing = sum(t[s, q] for t in transitions for q in range(n)) + final[s, 0]
            if not math.isclose(outgoing, 1):
                return False

        # Extract the scalar explicitly: passing a 1x1 array to math.isclose
        # relies on deprecated ndarray-to-scalar conversion.
        total = (initial @ m['sigma_star'] @ final)[0, 0]
        return math.isclose(total, 1.0)

    def weight(self, w):
        """Return the weight of the string ``w`` (an iterable of symbols).

        Computes ``initial @ M[w_1] @ ... @ M[w_k] @ final``. Per-symbol
        matrices are required, so this raises ``KeyError`` on automata built
        with ``just_sigma=True``.
        """
        vec = self.matrices['initial']
        for c in w:
            vec = vec @ self.matrices[c]
        return (vec @ self.matrices['final'])[0, 0]

'''
    Function to generate random PFAs with a given number of states and a given alphabet.

    This simply constructs a random initial vector, final vector, and matrix for each character
    and then normalizes their weights to meet the PFA conditions.
'''

def random_PFA(states, alphabet):
    """Return a random PFA (as a ``WA``) with ``states`` states over ``alphabet``.

    Generation steps:
      1. Draw uniform [0, 1) weights for the initial vector, one
         ``states x states`` matrix per symbol, and the final vector, using
         ``np.random.rand`` in that order.
      2. Rescale the initial vector to sum to 1.
      3. Rescale each state's row so that its outgoing transition weights
         over all symbols plus its final weight sum to 1.
    """
    matrices = {}
    initial = np.random.rand(1, states)
    matrices['initial'] = initial / np.sum(initial)

    for c in alphabet:
        matrices[c] = np.random.rand(states, states)
    matrices['final'] = np.random.rand(states, 1)

    # Per-state normaliser: total outgoing weight over every symbol plus the
    # state's final weight. All totals are computed before any division.
    totals = matrices['final'][:, 0] + sum(matrices[c].sum(axis=1) for c in alphabet)
    column = totals[:, np.newaxis]
    for c in alphabet:
        matrices[c] /= column
    matrices['final'] /= column

    identity = np.identity(states)
    matrices['sigma'] = sum(matrices[c] for c in alphabet)
    # Mirror the auxiliary entries WA creates when built from transitions.
    # setdefault avoids overwriting a symbol literally named 'id'/'identity'.
    matrices.setdefault('id', identity)
    matrices.setdefault('identity', identity)
    matrices['sigma_star'] = np.linalg.inv(identity - matrices['sigma'])
    P = WA(states, alphabet, matrices=matrices)
    return P
