'''
    Code to run the experiments and implementations for each algorithm.

    This code was adapted, with minor changes, from Cognetta et al. (2018).
'''

import random, pickle, time, gc
import numpy as np
from common import *
from copy import deepcopy
from itertools import product

'''
    An implementation for Algorithm 1

    This was taken directly from the implementation by Cognetta et al. (2018).

    One small bug in their implementation was fixed. See the NOTE comment in the code and Appendix B.
'''

def incremental_test(P, word, trials=10):
    """Run Algorithm 1 (incremental state elimination) on every prefix of ``word``.

    For each k in 1..len(word), state k of the DFA recognizing F(word) is
    eliminated from the DP table, the infix probability of ``word[:k]`` is
    printed, and the median time of that elimination step over ``trials``
    repetitions is recorded. Every trial recomputes the same step from the same
    input table, so the values left by the last trial are the ones kept.

    Args:
        P: the PFA. Reads ``P.sigma``, ``P.states`` and the ``P.matrices``
            entries ``'initial'``, ``'final'``, ``'sigma_star'`` plus one
            entry per symbol of ``P.sigma``.
        word: the pattern whose prefixes are tested (sliceable sequence of
            symbols).
        trials: timed repetitions per prefix; the (upper) median is kept.

    Returns:
        A list with one timing, in seconds, per prefix ``word[:k]``.

    Raises:
        ValueError: if ``trials`` is smaller than 1 (no timing could be taken).
    """
    if trials < 1:
        raise ValueError("trials must be at least 1")

    # setting up the DFA recognizing F(word)
    sigma = P.sigma
    delta = post_process_add_one(fail_function(word, sigma, FIRST=True))
    states = range(1, len(word)+2)
    N = len(states)
    initial = set([1])
    final = set([N])

    # table to hold the timings
    timings = []

    # cache these matrices
    zero = np.zeros((P.states, P.states))  # NOTE: THIS LINE WAS CHANGED FROM THE ORIGINAL IMPLEMENTATION
                                           # In Cognetta et al.'s implementation, this was zero = np.zeros(P.states)
    one = np.identity(P.states)

    # Cache the matrix for M(\Sigma^*)F, which is used to emit the infix probability
    sigma_star_final = P.matrices['sigma_star']@P.matrices['final']

    def star(mat):
        # (I - mat)^{-1}, i.e. I + mat + mat^2 + ... when the series converges
        return np.linalg.inv(one-mat)

    # DP tables (double-buffered: table is read, table2 is written, then swapped).
    # Every cell initially aliases the same `zero` array; cells are only ever
    # rebound to new arrays, never mutated in place, so the aliasing is safe.
    table = [[zero for i in range(N+2)] for i in range(N+2)]
    table2 = [[zero for i in range(N+2)] for i in range(N+2)]

    for i in states:
        if i in initial:
            table[0][i] = one
        if i in final:
            table[i][N+1] = one

        for c in sigma:
            if c in delta[i]:
                for j in delta[i][c]:
                    table[i][j] = table[i][j] + P.matrices[c]

    V = P.matrices['initial']

    for k in states[:-1]:
        temp_timings = []
        for _ in range(trials):
            X = V                            # We use X inside the timing loop so we don't corrupt V

            # perf_counter is monotonic and high-resolution, unlike time.time()
            start = time.perf_counter()
            s_k_k = star(table[k][k])        # holds (a^{k}{k})*. Pulled out of the loop to remove redundant computation
            X = X@s_k_k@table[k][k+1]

            for i in range(N+2):             # update the table
                t_i_k_k_k = table[i][k]@s_k_k
                for j in range(N+2):
                    table2[i][j] = table[i][j] + t_i_k_k_k@table[k][j]
            temp_timings.append(time.perf_counter()-start)
        temp_timings.sort()
        # word[:k] is wrapped in a 1-tuple so a tuple-typed word is not unpacked by %
        print("Infix probability of %s: " % (word[:k],), (X@sigma_star_final)[0,0])
        timings.append(temp_timings[len(temp_timings)//2])
        gc.collect()
        V = V@s_k_k@table[k][k+1]            # update V after the timings for this iteration are done
        table, table2 = table2, table

    return timings

'''
    The improved version of Algorithm 1

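    The algorithm is nearly identical to Algorithm 1 except for the table-update
    loops: at step k only rows and columns k+1..N+1 are touched, and only column
    k+1 is recomputed (every other entry in that range is copied unchanged).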

'''

def fast_incremental_test(P, word, trials=10):
    """Run the improved Algorithm 1 on every prefix of ``word``.

    This is the same state-elimination loop as Algorithm 1, except that step k
    only writes rows and columns k+1..N+1 of the buffer table. Later steps
    only read rows and columns greater than k, so the other cells may be
    left stale. Within that range, only column k+1 is recomputed. Every
    entry with j > k+1 is copied unchanged, which presumes
    ``table[k][j] == 0`` for j > k+1 (state k's only forward edge goes to
    k+1). That assumption is inherited from the original code and is not
    checked here.

    Args:
        P: the PFA. Reads ``P.sigma``, ``P.states`` and the ``P.matrices``
            entries ``'initial'``, ``'final'``, ``'sigma_star'`` plus one
            entry per symbol of ``P.sigma``.
        word: the pattern whose prefixes are tested (sliceable sequence of
            symbols).
        trials: timed repetitions per prefix; the (upper) median is kept.

    Returns:
        A list with one timing, in seconds, per prefix ``word[:k]``.

    Raises:
        ValueError: if ``trials`` is smaller than 1 (no timing could be taken).
    """
    if trials < 1:
        raise ValueError("trials must be at least 1")

    # setting up the DFA recognizing F(word)
    sigma = P.sigma
    delta = post_process_add_one(fail_function(word, sigma, FIRST=True))
    states = range(1, len(word)+2)
    N = len(states)
    initial = set([1])
    final = set([N])

    # table to hold the timings
    timings = []

    # cache these matrices
    zero = np.zeros((P.states, P.states))
    one = np.identity(P.states)

    # Cache the matrix for M(\Sigma^*)F, which is used to emit the infix probability
    sigma_star_final = P.matrices['sigma_star']@P.matrices['final']

    def star(mat):
        # (I - mat)^{-1}, i.e. I + mat + mat^2 + ... when the series converges
        return np.linalg.inv(one-mat)

    # DP tables (double-buffered: table is read, table2 is written, then swapped).
    # Cells are rebound to new arrays, never mutated in place, so sharing
    # array objects between cells/tables is safe.
    table = [[zero for i in range(N+2)] for i in range(N+2)]
    table2 = [[zero for i in range(N+2)] for i in range(N+2)]

    for i in states:
        if i in initial:
            table[0][i] = one
        if i in final:
            table[i][N+1] = one

        for c in sigma:
            if c in delta[i]:
                for j in delta[i][c]:
                    table[i][j] = table[i][j] + P.matrices[c]

    V = P.matrices['initial']

    for k in states[:-1]:
        temp_timings = []
        for _ in range(trials):
            X = V                            # We use X inside the timing loop so we don't corrupt V

            # perf_counter is monotonic and high-resolution, unlike time.time()
            start = time.perf_counter()
            s_k_k = star(table[k][k])        # holds (a^{k}{k})*. Pulled out of the loop to remove redundant computation
            X = X@s_k_k@table[k][k+1]

            for i in range(k+1, N+2):        # update the table
                # only column k+1 picks up the paths i -> k -> k+1
                table2[i][k+1] = table[i][k+1] + table[i][k]@s_k_k@table[k][k+1]
                # columns k+2..N+1 are carried over unchanged
                table2[i][k+2:N+2] = table[i][k+2:N+2]
            temp_timings.append(time.perf_counter()-start)
        temp_timings.sort()
        # word[:k] is wrapped in a 1-tuple so a tuple-typed word is not unpacked by %
        print("Infix probability of %s: " % (word[:k],), (X@sigma_star_final)[0,0])
        timings.append(temp_timings[len(temp_timings)//2])
        gc.collect()
        V = V@s_k_k@table[k][k+1]            # update V after the timings for this iteration are done
        table, table2 = table2, table

    return timings

'''
    The Online Algorithm (Algorithm 2)

    To simulate an online setting, we take the whole string at once and build the KMP DFA
    but only expose the transitions of the first k states at iteration k.
'''
def online_incremental_test(P, word, trials=10):
    """Run the online algorithm (Algorithm 2) on every prefix of ``word``.

    The whole DFA for ``word`` is built up front, but step k only reads the
    outgoing transitions of state k+1 from it. Those reads are
    ``transition_char_table[k+1]``, ``transition_matrix[k+1][k+1]`` and
    ``transition_matrix[k+1][k+2]``. Each step rewrites column k+1 of the
    DP table. After the step, the infix probability of ``word[:k]`` is
    printed and the median time of the step over ``trials`` repetitions is
    recorded.

    Args:
        P: the PFA. Reads ``P.sigma``, ``P.states`` and the ``P.matrices``
            entries ``'initial'``, ``'final'``, ``'sigma_star'`` plus one
            entry per symbol of ``P.sigma``.
        word: the pattern whose prefixes are tested (sliceable sequence of
            symbols usable as keys of ``P.matrices``).
        trials: timed repetitions per prefix; the (upper) median is kept.

    Returns:
        A list with one timing, in seconds, per prefix ``word[:k]``.

    Raises:
        ValueError: if ``trials`` is smaller than 1 (no timing could be taken).
    """
    if trials < 1:
        raise ValueError("trials must be at least 1")

    delta = post_process_add_one(fail_function(word, P.sigma, FIRST=True))
    states = range(1, len(word)+2)
    N = len(states)
    initial = set([1])
    final = set([N])

    # table to hold the timings
    timings = []

    # cache these matrices
    zero = np.zeros((P.states, P.states))
    one = np.identity(P.states)

    # Cache the matrix for M(\Sigma^*)F, which is used to emit the infix probability
    sigma_star_final = P.matrices['sigma_star']@P.matrices['final']

    # Full transition structure of the DFA: summed symbol matrices per (i, j)
    # edge, and the list of symbols labelling that edge.
    transition_matrix = [[zero for i in range(N+2)] for i in range(N+2)]
    transition_char_table = [[[] for i in range(N+2)] for i in range(N+2)]

    def star(mat):
        # (I - mat)^{-1}, i.e. I + mat + mat^2 + ... when the series converges
        return np.linalg.inv(one-mat)

    # DP tables
    table = [[zero for i in range(N+2)] for i in range(N+2)]
    table2 = [[zero for i in range(N+2)] for i in range(N+2)]
    for i in states:
        if i in initial:
            table[0][i] = one
        if i in final:
            table[i][N+1] = one

        for c in P.sigma:
            if c in delta[i]:
                for j in delta[i][c]:
                    transition_matrix[i][j] = transition_matrix[i][j] + P.matrices[c]
                    transition_char_table[i][j].append(c)

    # expose the start edge and state 1's transitions before the first step
    table[0][1] = one
    transition_matrix[0][1] = one
    table[1][1] = transition_matrix[1][1]
    table[1][2] = transition_matrix[1][2]

    for k in range(1, N):
        j = k+1  # every step only recomputes column k+1
        temp_timings = []
        for _ in range(trials):  # the timing loop
            # a temporary table to hold the values during the timing runs
            table2_temp = deepcopy(table2)
            # perf_counter is monotonic and high-resolution, unlike time.time()
            start = time.perf_counter()

            # holds (a^{k}{k})*. Computed inside the timed region, consistent
            # with the other algorithms, and reused by the Lemma 4 sum below
            # instead of re-inverting once per back transition.
            s_k_k = star(table[k][k])

            # rows 0..k: fold the paths through state k into column k+1
            for i in range(k+1):
                table2_temp[i][j] = table[i][j] + table[i][k]@s_k_k@table[k][j]

            # row k+1 (k >= 1 here, so state k+1 always has earlier states to fall back to)
            i = k+1
            # retrieving the states that can be reached by a back transition from state i
            backtransitions = [idx for idx in range(1, k+2) if transition_char_table[i][idx]]

            # The sum from Lemma 4
            table2_temp[i][j] = sum(
                (sum((P.matrices[c] for c in transition_char_table[i][idx]), zero)
                 @ (table[idx][j] + table[idx][k]@s_k_k@P.matrices[word[k-1]])
                 for idx in backtransitions),
                zero)

            if i in backtransitions:  # for the self loop on states with index greater than 1.
                                      # This gets called only once if there is unary prefix of length > 1 and never otherwise.
                # In-place += is safe: backtransitions is non-empty here, so the
                # sum above returned a fresh array rather than `zero` itself.
                table2_temp[i][j] += transition_matrix[i][j]

            # expose the forward transition out of state k+1
            table2_temp[i][i+1] = transition_matrix[i][i+1]

            temp_timings.append(time.perf_counter()-start)
        temp_timings.sort()

        timings.append(temp_timings[len(temp_timings)//2])
        # word[:k] is wrapped in a 1-tuple so a tuple-typed word is not unpacked by %
        print("Infix probability of %s: " % (word[:k],), (P.matrices['initial']@table2_temp[0][k+1]@sigma_star_final)[0,0])
        table, table2 = table2_temp, table
        gc.collect()
    return timings

if __name__ == '__main__':

    # timed repetitions per prefix; raise this to get median times over many runs
    TRIALS = 1

    # setting up test alphabets
    alphabet_26 = list('abcdefghijklmnopqrstuvwxyz')
    # 50 two-letter symbols: the first 50 pairs of alphabet_26 in lexicographic order
    alphabet_50 = list(product(alphabet_26, alphabet_26))[:50]

    # generating random strings of length 10 from each alphabet
    # (word_26 is not used by the runs below)
    word_26 = [random.choice(alphabet_26) for i in range(10)]
    word_50 = [random.choice(alphabet_50) for i in range(10)]

    # generating a random PFA with 200 states over alphabet_50
    P_50 = random_PFA(200, alphabet_50)

    # running the tests

    print("Probabilities\n")

    print("Algorithm 1")
    res = incremental_test(P_50, word_50, trials=TRIALS)
    gc.collect()
    print("\n\n")

    print("Improved Algorithm")
    res2 = fast_incremental_test(P_50, word_50, trials=TRIALS)
    gc.collect()
    print("\n\n")

    print("Online version")
    res3 = online_incremental_test(P_50, word_50, trials=TRIALS)
    gc.collect()

    print("\n==================================\n")

    print("Timings\n")

    print("Algorithm 1")
    print(res)
    print("Total: %f seconds\n" % sum(res))
    print("Improved version")
    print(res2)
    print("Total: %f seconds\n" % sum(res2))
    print("Online version")
    print(res3)
    print("Total: %f seconds" % sum(res3))