import numpy as np
from common import *
import gc, time
from itertools import product

class DFA:
    """Deterministic finite automaton backed by a nested-dict transition table.

    Attributes:
        states: count of states (callers in this module pass ``len(delta)``).
        sigma: the alphabet; stored only, not consulted by ``parse``.
        delta: mapping ``state -> {symbol: next_state}``.
        initial: the start state.
        final: container of accepting states (anything supporting ``in``).
    """

    def __init__(self, states=None, sigma=None, delta=None, initial=None, final=None):
        self.sigma, self.states, self.delta = sigma, states, delta
        self.initial, self.final = initial, final

    def parse(self, w):
        """Return True iff the symbol sequence ``w`` drives the DFA to an accepting state.

        Rejects (returns False) as soon as a symbol has no outgoing transition
        from the current state. A current state that is missing entirely from
        ``delta`` raises ``KeyError``.
        """
        state = self.initial
        for symbol in w:
            outgoing = self.delta[state]
            if symbol not in outgoing:
                return False
            state = outgoing[symbol]
        return state in self.final

def intersect_PFA_DFA_helper(P, D):
    """Build the product of PFA ``P`` and DFA ``D`` as raw weighted-automaton parts.

    Product states are pairs ``(p, d)`` flattened to integer indices in
    ``itertools.product(range(P.states), range(D.states))`` order.

    Weights are taken from ``P.matrices``:
      * initial: ``P.matrices['initial'][0, p]`` when ``d == D.initial``, else 0;
      * final:   ``P.matrices['final'][p, 0]`` when ``d in D.final``, else 0;
      * arc ``(p, d) --c--> (p', D.delta[d][c])``: ``P.matrices[c][p, p']`` for
        every ``p'`` listed in ``P.forward_transitions[p][c]``.

    Returns:
        ``(new_states, sigma, new_delta, new_initial, new_final)``, where
        ``new_delta`` maps ``src -> {symbol: {dst: weight}}``. It only contains
        sources/symbols that have at least one arc.
    """
    new_states = P.states * D.states  # product states are (p, d) pairs
    sigma = P.sigma

    # (p, d) -> flat index, p-major (same order as itertools.product).
    state_map = {
        pair: idx
        for idx, pair in enumerate(product(range(P.states), range(D.states)))
    }

    new_initial = {}
    new_final = {}
    for (p, d), idx in state_map.items():
        # Only pairs whose DFA component accepts keep P's final weight.
        new_final[idx] = P.matrices['final'][p, 0] if d in D.final else 0
        # Only pairs whose DFA component is the start state keep P's initial weight.
        new_initial[idx] = P.matrices['initial'][0, p] if d == D.initial else 0

    new_delta = {}
    for d in range(D.states):
        for c, d_prime in D.delta[d].items():
            matrix = P.matrices[c]
            for p, out_edges in P.forward_transitions.items():
                if c not in out_edges:
                    continue
                src = state_map[(p, d)]
                for p_prime in out_edges[c]:
                    # setdefault inside the loop so an empty successor list
                    # creates no (empty) entries in new_delta.
                    arcs = new_delta.setdefault(src, {}).setdefault(c, {})
                    arcs[state_map[(p_prime, d_prime)]] = matrix[p, p_prime]
    return new_states, sigma, new_delta, new_initial, new_final

#### intersection parts
def intersect_test(P, word, trials=10):
    """Benchmark PFA/DFA intersection on every non-empty prefix of ``word``.

    For each prefix ``w = word[:k]`` (k = 1..len(word)):

    * build a DFA from ``post_process(fail_function(w, P.sigma))``, with state 0
      as initial and the last state as the only final state;
    * ``trials`` times: intersect it with ``P``, wrap the result in a ``WA``,
      and compute ``initial @ sigma_star @ final``;
    * print that value as the infix probability of ``w``.

    Args:
        P: the PFA; must expose ``sigma`` plus whatever
            ``intersect_PFA_DFA_helper`` reads.
        word: sequence of symbols from ``P.sigma``.
        trials: number of timed repetitions per prefix; must be >= 1.

    Returns:
        List with one mean time (seconds) per prefix. The timed region covers
        the intersection, ``WA`` construction and probability computation, but
        not printing or garbage collection.

    Raises:
        ValueError: if ``trials < 1`` (the mean would be undefined).
    """
    if trials < 1:
        raise ValueError("trials must be >= 1, got %r" % (trials,))
    tags = P.sigma
    times = []
    for prefix_len in range(1, len(word) + 1):
        w = word[:prefix_len]
        delta = post_process(fail_function(w, tags))
        states = len(delta)
        D = DFA(states=states, sigma=tags, delta=delta, initial=0, final={states - 1})
        elapsed = 0.0

        for _ in range(trials):  # distinct name: don't shadow the prefix index
            start = time.perf_counter()  # monotonic, high-resolution interval clock
            new_states, sigma, new_delta, new_initial, new_final = intersect_PFA_DFA_helper(P, D)
            W = WA(states=new_states, sigma=sigma, delta=new_delta, initial=new_initial, final=new_final, just_sigma=True)
            prob = (W.matrices['initial'] @ W.matrices['sigma_star'] @ W.matrices['final'])[0, 0]
            elapsed += time.perf_counter() - start
            # Print after stopping the clock so stdout I/O doesn't skew timings.
            # Wrap w in a 1-tuple: a tuple prefix would otherwise be taken as
            # the %-format argument list.
            print("Infix probability of %s: " % (w,), prob)
            # Drop the large intermediates before the next trial; not timed.
            del W, new_states, sigma, new_delta, new_initial, new_final
            gc.collect()
        times.append(elapsed / trials)
    return times

if __name__ == '__main__':
    # `make_ngram` comes from `common`; presumably it builds an order-2 n-gram
    # PFA from the files matched by the glob. TODO: confirm against `common`.
    model = make_ngram(2, corpus="brown_corpus/*01", verify=False)
    print("Number of states in P:", model.states)

    # Tag sequence whose prefixes are benchmarked as infixes.
    tag_seq = ["NN", "NN", "PUNC"]
    per_prefix_times = intersect_test(model, tag_seq, trials=10)

    print("Intersect timings", per_prefix_times)
    print("Intersect total time: ", sum(per_prefix_times))
