import copy, glob, nltk, time, math
import numpy as np

r'''
Implementation of the KMP failure function, used to build a DFA for \mathcal{F}(w).
'''
def fail_function(W, charset, FIRST=False):
    '''
    Build the string-matching (Knuth-Morris-Pratt) DFA for the pattern W.

    State i (0 <= i <= len(W)) means "the longest suffix of the input read so far
    that is a prefix of W has length i"; state len(W) means W has been read.

    Parameters
      W: the pattern, a sequence of symbols (e.g. a string or a list of tags).
      charset: the alphabet; a re-iterable collection of symbols (list, set or
               string). Every symbol of W must occur in it.
      FIRST: if False, state len(W) loops to itself on every symbol (once W has
             been read the automaton stays there); if True, state len(W) has no
             outgoing transitions.

    Returns a list FF of len(W) + 1 dicts, where FF[i][c] is the state reached
    from state i on symbol c.

    Raises ValueError if W contains a symbol that is not in charset.
    '''
    missing = sorted({repr(ch) for ch in W if ch not in charset})
    if missing:
        raise ValueError("symbols of W not in charset: " + ", ".join(missing))

    # KMP failure function: F[i] is the length of the longest proper prefix of
    # W[:i+1] that is also a suffix of W[:i+1].
    F = [0] * len(W)
    k = 0
    for i in range(1, len(W)):
        while k > 0 and W[i] != W[k]:
            k = F[k - 1]
        if W[i] == W[k]:
            k += 1
        F[i] = k

    FF = []
    for i, expected in enumerate(W):
        row = {}
        for c in charset:
            if c == expected:
                row[c] = i + 1
            elif i == 0:
                row[c] = 0
            else:
                # On a mismatch, behave exactly like the state we fall back to.
                # F[i-1] < i, so that row has already been built.
                row[c] = FF[F[i - 1]][c]
        FF.append(row)

    # Accepting state: absorbing unless only the first occurrence matters.
    FF.append({} if FIRST else {c: len(W) for c in charset})
    return FF

'''
These two functions post-process a DFA specification (a list of transition dicts, one per state)
depending on whether it will be used for intersection or for the state-elimination method,
which index states differently.
'''
def post_process(FF):
    '''
    Index a DFA specification by state number for intersection.

    Returns a dict mapping i -> FF[i] for every position of FF. The per-state
    transition dicts are shared with FF, not copied.
    '''
    return {idx: FF[idx] for idx in range(len(FF))}

def post_process_add_one(FF):
    '''
    Re-index a DFA specification for the state-elimination method.

    States are shifted to start at 1: state i of FF becomes state i + 1. Each
    transition target t becomes the one-element list [t + 1].

    The input FF is left untouched; new dicts are returned. The previous version
    rewrote FF's transition dicts in place.
    '''
    return {
        i + 1: {c: [target + 1] for c, target in FF[i].items()}
        for i in range(len(FF))
    }

'''
Basic weighted automaton class.
Weighted automata are a superset of PFAs, so we simply have a parameter to verify whether a weighted automaton (over the real semiring)
is a PFA.

 - states: int, number of states
 - sigma: list, alphabet
 - matrices: dict of numpy matrices; pass this instead of delta/initial/final if you have already built the transition matrices
 - delta: dict, adjacency list. delta[i][c][j] is the weight of the transition from q_i to q_j reading c
 - initial: dict, initial weights. initial[i] is the initial weight of state q_i
 - final: dict, final weights. final[i] is the final weight of state q_i

 isPFA, verify, and just_sigma are optional parameters used only in the testing environment.

'''
class WA(object):
    '''
    Weighted automaton over the real semiring, stored as numpy matrices.

    When built from delta/initial/final, self.matrices contains:
      'initial'          (1, states) row vector of initial weights
      'final'            (states, 1) column vector of final weights
      c                  (states, states) transition matrix for each symbol c in
                         sigma (not built when just_sigma=True)
      'sigma'            the summed transition matrix over all symbols
      'id', 'identity'   the (states, states) identity matrix
      'sigma_star'       inverse of (identity - sigma); this equals the series
                         sum_k sigma^k whenever that series converges
    When `matrices` is passed in, it is stored as given.
    '''
    def __init__(self, states=None, sigma=None, matrices=None, delta=None, initial=None, final=None, isPFA=False, verify=False, just_sigma=False):
        self.sigma = sigma
        self.states = states
        # Kept so callers can still reach the original adjacency list.
        self.forward_transitions = delta
        if matrices is None:
            if any(x is None for x in (states, sigma, delta, initial, final)):
                raise ValueError(
                    "WA needs either `matrices` or all of `states`, `sigma`, "
                    "`delta`, `initial` and `final`")
            self.matrices = self._build_matrices(delta, initial, final, just_sigma)
        else:
            self.matrices = matrices
            if self.states is None and 'initial' in matrices:
                # Infer the state count from the initial vector so that
                # verify_PFA() also works for matrix-only construction.
                self.states = np.shape(matrices['initial'])[-1]

        if isPFA and verify and not self.verify_PFA():
            raise ValueError("weighted automaton is not a valid PFA")

    def _build_matrices(self, delta, initial, final, just_sigma):
        '''
        Build the matrix dict from the dict-based description.

        Uses self.states and self.sigma. When just_sigma is True, only the summed
        'sigma' matrix is built, accumulating weights of transitions between the
        same pair of states across symbols. Otherwise one matrix per symbol is
        built and 'sigma' is their sum.
        Raises numpy.linalg.LinAlgError if identity - sigma is singular.
        '''
        n = self.states
        matrices = {
            'initial': np.zeros((1, n)),
            'final': np.zeros((n, 1)),
        }
        for i in range(n):
            matrices['initial'][0, i] = initial[i]
            matrices['final'][i, 0] = final[i]

        if just_sigma:
            sigma_matrix = np.zeros((n, n))
            for c in self.sigma:
                for q, row in delta.items():
                    if c in row:
                        for p, w in row[c].items():
                            sigma_matrix[q, p] += w
        else:
            for c in self.sigma:
                m = np.zeros((n, n))
                for q, row in delta.items():
                    if c in row:
                        for p, w in row[c].items():
                            m[q, p] = w
                matrices[c] = m
            # Explicit zero start keeps 'sigma' a matrix even for an empty alphabet.
            sigma_matrix = sum((matrices[c] for c in self.sigma), np.zeros((n, n)))

        identity = np.identity(n)
        matrices['sigma'] = sigma_matrix
        matrices['id'] = identity
        matrices['identity'] = identity
        matrices['sigma_star'] = np.linalg.inv(identity - sigma_matrix)
        return matrices

    def verify_PFA(self):
        '''
        Check that this automaton is a PFA.

        It checks three conditions:
          1. the initial weights sum to 1;
          2. for every state, its outgoing transition weights over all symbols
             plus its final weight sum to 1;
          3. initial @ sigma_star @ final equals 1 (not degenerate).
        Returns True if all hold, False otherwise. Uses math.isclose defaults.
        '''
        initial = self.matrices['initial']
        final = self.matrices['final']

        if not math.isclose(float(np.sum(initial)), 1.0):
            return False

        if self.sigma is not None and all(c in self.matrices for c in self.sigma):
            outgoing = sum((self.matrices[c] for c in self.sigma),
                           np.zeros((self.states, self.states)))
        else:
            # just_sigma automata only store the summed transition matrix.
            outgoing = self.matrices['sigma']
        row_totals = np.sum(outgoing, axis=1) + np.ravel(final)
        if not all(math.isclose(float(t), 1.0) for t in row_totals):
            return False

        # Take the scalar explicitly: math.isclose on a 1x1 array relies on the
        # deprecated ndim>0 array-to-scalar conversion.
        total = (initial @ self.matrices['sigma_star'] @ final)[0, 0]
        return math.isclose(float(total), 1.0)

    def weight(self, w):
        '''
        Return the weight of word w.

        The weight is initial @ M[w_1] @ ... @ M[w_k] @ final. Each symbol of w
        must be a key of self.matrices.
        '''
        vec = self.matrices['initial']
        for c in w:
            vec = vec @ self.matrices[c]
        return (vec @ self.matrices['final'])[0, 0]

'''
make_ngram has only been tested with the Brown corpus that is included with this demo.

It should work for any corpus that has one sentence per line and uses the same alphabet (and punctuation) as the Brown corpus.
'''
# POS tags that make_ngram collapses into the single symbol "PUNC".
_PUNCTUATION_TAGS = frozenset(['$', "''", '(', ')', ',', '.', ':', '``'])


def _map_tag(tag):
    '''Map a POS tag string from nltk.pos_tag to the symbol used by make_ngram.'''
    if tag == "CC":
        return "CCX"
    if tag in _PUNCTUATION_TAGS:
        return "PUNC"
    if tag.endswith("$"):
        # e.g. "PRP$" -> "PRPX"
        return tag[:-1] + "X"
    return tag


def make_ngram(n, corpus="brown_corpus/*", verify=False):
    '''
    Build a PFA over POS-tag symbols from a corpus with one sentence per line.

    Each line longer than one character (after stripping) is processed in three
    steps. It is tokenized with nltk.word_tokenize and tagged with nltk.pos_tag.
    Its tags are then mapped with _map_tag. A state is the tuple of the last n
    mapped tags, padded with "START" at the beginning of a sentence. Reading
    tag t moves from state s to s[1:] + (t,). Transition and final weights are
    the relative frequencies of the next tag / end of sentence given the state.

    Parameters
      n: number of previous tags kept in each state.
      corpus: glob pattern matching the corpus files.
      verify: passed to WA to check that the result is a valid PFA.

    Returns a WA built with isPFA=True. State 0 is the all-"START" state and is
    the only state with a non-zero (1) initial weight.

    Raises ValueError if no usable sentence is found.
    '''
    start = ("START",) * n

    # chain[state][tag] = number of times `tag` followed `state`;
    # the pseudo-tag "END" counts sentences that ended in `state`.
    chain = {start: {}}
    for path in glob.glob(corpus):
        with open(path, 'r') as f:
            for line in f:
                line = line.strip()
                if len(line) <= 1:
                    continue
                words = nltk.word_tokenize(line)
                cur = start
                for _, raw_tag in nltk.pos_tag(words):
                    tag = _map_tag(raw_tag)
                    chain[cur][tag] = chain[cur].get(tag, 0) + 1
                    cur = cur[1:] + (tag,)
                    chain.setdefault(cur, {})
                chain[cur]["END"] = chain[cur].get("END", 0) + 1

    # Every processed sentence adds at least one count leaving `start`, and
    # every other state is only created on a transition that is later followed
    # by a tag or "END". So this single check rules out zero totals below.
    if not chain[start]:
        raise ValueError(
            "no usable sentences found for corpus pattern %r" % (corpus,))

    # Number states in first-seen order (start is state 0).
    state_map = {state: idx for idx, state in enumerate(chain)}
    delta = {}
    final = {}
    alphabet = set()
    for state, counts in chain.items():
        q = state_map[state]
        total = sum(counts.values())
        final[q] = counts.get("END", 0) / total
        delta[q] = {}
        for tag, count in counts.items():
            if tag == "END":
                continue
            alphabet.add(tag)
            delta[q][tag] = {state_map[state[1:] + (tag,)]: count / total}

    initial = {q: 0 for q in state_map.values()}
    initial[state_map[start]] = 1
    return WA(states=len(state_map), sigma=sorted(alphabet), delta=delta, initial=initial, final=final, isPFA=True, verify=verify)
