import numpy as np
from common import *
import time, gc

def incremental_test(P, word, trials=10):
    """Time the incremental infix-probability update for every prefix of ``word``.

    A DFA for ``word`` is built with ``fail_function`` /
    ``post_process_add_one`` and its transitions are loaded into a table of
    matrices taken from ``P.matrices``.  For each state ``k`` (except the last)
    the infix probability of ``word[:k]`` is evaluated and the table is
    updated with

        table2[i][j] = table[i][j] + table[i][k] @ star(table[k][k]) @ table[k][j]

    Only the evaluation of the probability and the table update are timed; the
    matrix inversion performed by ``star`` and the console output are not.

    Args:
        P: Model object.  This function reads ``P.sigma`` (iterable of
            symbols), ``P.states`` (passed to ``np.zeros`` / ``np.identity``)
            and ``P.matrices`` (indexed by each symbol and by the keys
            ``'initial'``, ``'sigma_star'`` and ``'final'``).
        word: Sequence of symbols; its prefixes ``word[:k]`` are processed in
            order of increasing length.
        trials: Number of times the timed section is repeated per prefix;
            the reported value is the mean over these repetitions.

    Returns:
        A list with one entry per prefix (``len(word)`` entries): the average
        time in seconds of the timed section for that prefix.

    Raises:
        ValueError: If ``trials`` is smaller than 1 (an average over zero
            repetitions is undefined).

    Side effects:
        Prints the infix probability of each prefix once per trial.
    """
    if trials < 1:
        raise ValueError("trials must be at least 1, got %r" % (trials,))

    # Set up the DFA recognizing F(word).
    sigma = P.sigma
    delta = post_process_add_one(fail_function(word, sigma, FIRST=True))
    states = range(1, len(word) + 2)
    N = len(states)
    initial = {1}
    final = {N}

    # Average timing of each prefix, in processing order.
    timings = []

    # Additive and multiplicative constants shared by all table cells (the
    # original note labelled these "λ φ").
    # NOTE(review): `zero` is a 1-D vector of length P.states whereas `one` is
    # a 2-D matrix; mixing them relies on NumPy broadcasting - confirm this is
    # intended before changing it, since it affects the cost of the timed work.
    zero = np.zeros(P.states)
    one = np.identity(P.states)

    # Cache M(Sigma^*) F, which is used to emit the infix probability.
    sigma_star_final = P.matrices['sigma_star'] @ P.matrices['final']

    def star(mat):
        # (I - mat)^-1, the closed form of I + mat + mat^2 + ... when that
        # series converges.
        return np.linalg.inv(one - mat)

    # DP tables; rows/columns 0 and N+1 are the extra source and sink entries.
    # Cells share the `zero`/`one` objects, which is safe because cells are
    # always rebound to new arrays and never modified in place.
    table = [[zero for _ in range(N + 2)] for _ in range(N + 2)]
    table2 = [[zero for _ in range(N + 2)] for _ in range(N + 2)]

    for i in states:
        if i in initial:
            table[0][i] = one
        if i in final:
            table[i][N + 1] = one

        for c in sigma:
            if c in delta[i]:
                for j in delta[i][c]:
                    table[i][j] = table[i][j] + P.matrices[c]

    V = P.matrices['initial']

    for k in states[:-1]:
        # Both values are identical for every trial and lie outside the timed
        # section, so they are computed once per prefix.
        s_k_k = star(table[k][k])            # holds (a^{k}{k})*
        X = V @ s_k_k @ table[k][k + 1]      # V advanced by one symbol

        elapsed = 0.0
        for _ in range(trials):
            # perf_counter is monotonic and high resolution, unlike time.time.
            start = time.perf_counter()

            probability = (X @ sigma_star_final)[0, 0]

            for i in range(N + 2):           # update the table
                t_i_k_k_k = table[i][k] @ s_k_k
                for j in range(N + 2):
                    table2[i][j] = table[i][j] + t_i_k_k_k @ table[k][j]

            elapsed += time.perf_counter() - start

            # Print after the clock is stopped so console I/O does not pollute
            # the measurement.  The prefix is wrapped in a 1-tuple so that a
            # tuple-valued `word` is formatted rather than unpacked by `%`.
            print("Infix probability of %s: " % (word[:k],), probability)

        timings.append(elapsed / trials)
        V = X                                # commit the advanced vector
        table, table2 = table2, table

    return timings

if __name__ == '__main__':
    # Script entry point: build a bigram model from part of the Brown corpus,
    # run the incremental benchmark on a fixed tag sequence and report the
    # per-prefix and total timings.
    model = make_ngram(2, corpus="brown_corpus/*01", verify=False)
    print("Number of states in P:", model.states)

    tag_sequence = ["NN", "NN", "PUNC"]
    per_prefix_timings = incremental_test(model, tag_sequence, trials=10)

    print("Incremental timings", per_prefix_timings)
    total_time = sum(per_prefix_timings)
    print("Incremental total time: ", total_time)
