import os
from generate_colorize import prepare_colorize, latex_colorize

def table_2(output_dir):
    '''Write ``table_2.tex``: qualitative examples for the first token.

    The generated LaTeX ``table`` holds four task sections (Sentiment
    Analysis, NLI, Reading Comprehension (Question), Bios). Every section
    lists three attribution methods, and each method contributes two table
    rows: the tokens rendered with the original scores and the same tokens
    rendered with the merged scores.

    Args:
        output_dir: Directory in which ``table_2.tex`` is created. It must
            already exist; the file is overwritten if present.
    '''
    # Spelled out line by line so the trailing space after "\small" and the
    # double space after "Color Legend:" stay visible and intact.
    header = (
        '\n'
        '\\begin{table}[t]\n'
        '\\small \n'
        '\\centering\n'
        '\\setlength{\\tabcolsep}{2pt}\n'
        'Color Legend:  \\mybox{color1}{Slight Impact}\\quad \\mybox{color7}{Heavy Impact}\n'
        '\n'
        '\\begin{tabular}{ll}\n'
        '\\toprule\n'
    )
    footer = (
        '\n'
        '\\end{tabular}\n'
        '\n'
        '\\caption{Qualitative examples first token. RC question was renormalized across the question.}\n'
        '\\label{tab:qualitative_examples_first_token}\n'
        '\\end{table}\n'
    )

    text_1 = ["a", "very", "well", "-", "made", ",", "and", "entertaining", "picture", ".", "[SEP]"]
    text_2 = ["two", "men", "are", "shouting", ".", "[SEP]", "two", "men", "are", "quiet", ".", "[SEP]"]
    text_3 = ["Who", "stars", "in", "The", "Matrix", "?", "[SEP]"]
    text_4 = ["in", "brazil", "she", "did", "her", "first", "steps", "in", "surgery", ".", "[SEP]"]

    # One entry per task: (section title, rows), where each row is
    # (method macro, tokens, original scores, merged scores).
    sections = [
        ('Sentiment Analysis', [
            ('\\VanillaGrad', text_1, [0.09446, 0.07005, 0.01163, 0.06667, 0.05446, 0.12122, 0.01241, 0.27540, 0.04737, 0.09688, 0.13586], [0.94757, 0.00397, 0.00041, 0.00373, 0.00339, 0.00671, 0.00096, 0.01601, 0.00313, 0.00549, 0.00773]),
            ('\\SmoothGrad', text_1, [0.07989, 0.20937, 0.08815, 0.03857, 0.06612, 0.06612, 0.03306, 0.26446, 0.08815, 0.04408, 0.00000], [0.40954, 0.00596, 0.06811, 0.00681, 0.03065, 0.02725, 0.01022, 0.14304, 0.22009, 0.02725, 0.01022]),
            ('\\IntegratedGrad', text_1, [0.014, 0.015, 0.004, 0.003, 0.014, 0.043, 0.025, 0.003, 0.017, 0.009, 0.849], [0.015, 0.016, 0.002, 0.002, 0.013, 0.045, 0.025, 0.001, 0.018, 0.011, 0.846]),
        ]),
        ('NLI', [
            ('\\VanillaGrad', text_2, [0.08869, 0.01355, 0.00762, 0.42750, 0.00324, 0.03869, 0.03408, 0.03378, 0.01784, 0.25169, 0.06468, 0.00564], [0.93144, 0.00141, 0.00086, 0.03048, 0.00065, 0.00265, 0.00316, 0.00242, 0.00117, 0.01931, 0.00456, 0.00022]),
            ('\\SmoothGrad', text_2, [0.02797, 0.09790, 0.00000, 0.00000, 0.02797, 0.06294, 0.04895, 0.04196, 0.02797, 0.55944, 0.01399, 0.00000], [0.54621, 0.00552, 0.00828, 0.07724, 0.05517, 0.00552, 0.00000, 0.01103, 0.01241, 0.22069, 0.04414, 0.01103]),
            ('\\IntegratedGrad', text_2, [0.065, 0.021, 0.054, 0.170, 0.047, 0.026, 0.095, 0.065, 0.111, 0.239, 0.034, 0.048], [0.053, 0.038, 0.066, 0.204, 0.019, 0.052, 0.085, 0.041, 0.051, 0.248, 0.059, 0.068]),
        ]),
        ('Reading Comprehension (Question)', [
            ('\\VanillaGrad', text_3, [0.311, 0.377, 0.035, 0.091, 0.083, 0.072, 0.017], [0.869, 0.072, 0.007, 0.017, 0.016, 0.014, 0.003]),
            ('\\SmoothGrad', text_3, [0.259, 0.296, 0.093, 0.037, 0.148, 0.074, 0.046], [0.715, 0.1  , 0.05 , 0.05 , 0.052, 0.025, 0.   ]),
            ('\\IntegratedGrad', text_3, [0.418, 0.091, 0.098, 0.062, 0.022, 0.284, 0.004], [0.404, 0.096, 0.1  , 0.063, 0.022, 0.289, 0.004]),
        ]),
        ('Bios', [
            ('\\VanillaGrad', text_4, [0.10337, 0.09871, 0.10628, 0.05294, 0.00313, 0.08227, 0.14705, 0.02063, 0.33622, 0.01582, 0.03336], [0.42869, 0.06263, 0.06828, 0.03373, 0.00215, 0.05222, 0.09353, 0.01299, 0.21463, 0.00999, 0.02101]),
            ('\\SmoothGrad', text_4, [0.07843, 0.11765, 0.13725, 0.07843, 0.07843, 0.09804, 0.00000, 0.03922, 0.05882, 0.15686, 0.07843], [0.20592, 0.12355, 0.08237, 0.01030, 0.02574, 0.02059, 0.08237, 0.04376, 0.16474, 0.08237, 0.08237]),
            ('\\IntegratedGrad', text_4, [0.021, 0.017, 0.098, 0.030, 0.222, 0.054, 0.019, 0.040, 0.083, 0.010, 0.401], [0.041, 0.009, 0.079, 0.022, 0.167, 0.037, 0.020, 0.030, 0.063, 0.020, 0.506]),
        ]),
    ]

    # The context manager guarantees the file is flushed and closed even if
    # latex_colorize raises part-way through.
    with open(os.path.join(output_dir, 'table_2.tex'), 'w') as outfile:
        outfile.write(header)
        for section_index, (title, rows) in enumerate(sections):
            # Sections are separated by a rule; none precedes the first one.
            if section_index > 0:
                outfile.write('\\midrule\n')
            outfile.write('\\multicolumn{2}{l}{\\bf ' + title + '} \\\\\n')
            for method, text, original_scores, merged_scores in rows:
                # Extra vertical space before every method except the one
                # that opens the section.
                if method != '\\VanillaGrad':
                    outfile.write('\\addlinespace \n')
                outfile.write('\\multicolumn{2}{l}{\\em' + method + '} \\\\\n')
                outfile.write('\\fop & {} \\\\\n'.format(latex_colorize(text, original_scores)))
                outfile.write('\\fmergedfirst & {} \\\\\n'.format(latex_colorize(text, merged_scores)))
        outfile.write('\\bottomrule\n')
        outfile.write(footer)


def table_3(output_dir):
    '''Write ``table_3.tex``: qualitative examples for the stop token.

    The generated LaTeX ``table*`` holds four task sections (Sentiment
    Analysis, NLI, Reading Comprehension (Question), Bios). Every section
    lists three attribution methods, and each method contributes two table
    rows: the tokens rendered with the original scores and the same tokens
    rendered with the merged scores.

    Args:
        output_dir: Directory in which ``table_3.tex`` is created. It must
            already exist; the file is overwritten if present.
    '''
    # Spelled out line by line so the trailing space after "\small" and the
    # double space after "Color Legend:" stay visible and intact.
    header = (
        '\n'
        '\\begin{table*}[t]\n'
        '\\small \n'
        '\\centering\n'
        '\\setlength{\\tabcolsep}{2pt}\n'
        'Color Legend:  \\mybox{color1}{Slight Impact}\\quad \\mybox{color7}{Heavy Impact}\n'
        '\n'
        '\\begin{tabular}{ll}\n'
        '\\toprule\n'
    )
    footer = (
        '\n'
        '\\end{tabular}\n'
        '\n'
        '\\caption{Qualitative examples stop token. RC question was renormalized across the question.}\n'
        '\\label{tab:qualitative_examples_stop_token}\n'
        '\\end{table*}\n'
    )

    text_1 = ["visually", "imaginative", "and", "thoroughly", "delightful", ",", "it", "takes", "us", "on", "a", "roller", "-", "coaster", "ride", "from", "innocence", "to", "experience", ".", "[SEP]"]
    text_2 = ["a", "large", ",", "gray", "elephant", "walked", "beside", "a", "herd", "of", "zebra", "##s", ".", "[SEP]", "the", "elephant", "was", "lost", ".", "[SEP]"]
    text_3 = ["Who", "caught", "the", "touchdown", "pass", "?", "[SEP]"]
    text_4 = ["she", "has", "had", "many", "years", "of", "experience", "and", "did", "thousands", "of", "operations", ".", "[SEP]"]

    # One entry per task: (section title, rows), where each row is
    # (method macro, tokens, original scores, merged scores).
    sections = [
        ('Sentiment Analysis', [
            ('\\VanillaGrad', text_1, [0.04670, 0.00096, 0.02841, 0.08283, 0.07075, 0.03372, 0.05412, 0.00280, 0.02269, 0.01371, 0.03908, 0.16688, 0.01869, 0.18155, 0.03322, 0.01516, 0.03640, 0.00633, 0.06018, 0.04100, 0.04211], [0.00157, 0.00002, 0.40137, 0.00211, 0.00273, 0.00108, 0.00134, 0.00031, 0.00065, 0.00010, 0.50683, 0.00582, 0.00047, 0.00612, 0.00096, 0.00081, 0.00088, 0.06250, 0.00139, 0.00139, 0.00132]),
            ('\\SmoothGrad', text_1, [0.08807, 0.11743, 0.02385, 0.01468, 0.01835, 0.01468, 0.08073, 0.08073, 0.00734, 0.02569, 0.02569, 0.10275, 0.02202, 0.14679, 0.01468, 0.01835, 0.02202, 0.00000, 0.08440, 0.00734, 0.05505], [0.00321, 0.01526, 0.41124, 0.00803, 0.03534, 0.00482, 0.00643, 0.00482, 0.00321, 0.00482, 0.09960, 0.02892, 0.00803, 0.01285, 0.01285, 0.01285, 0.00482, 0.29639, 0.00000, 0.00321, 0.00884]),
            ('\\IntegratedGrad', text_1, [0.013, 0.083, 0.025, 0.048, 0.033, 0.078, 0.056, 0.036, 0.010, 0.041, 0.002, 0.044, 0.079, 0.018, 0.009, 0.023, 0.088, 0.104, 0.064, 0.023, 0.092], [0.013, 0.081, 0.023, 0.045, 0.035, 0.078, 0.052, 0.035, 0.010, 0.042, 0.007, 0.045, 0.080, 0.021, 0.005, 0.024, 0.085, 0.104, 0.061, 0.025, 0.096]),
        ]),
        ('NLI', [
            ('\\VanillaGrad', text_2, [0.00019, 0.01002, 0.03683, 0.04548, 0.13126, 0.03090, 0.01635, 0.00946, 0.00678, 0.01708, 0.10861, 0.00355, 0.01860, 0.04160, 0.06159, 0.18931, 0.00019, 0.15150, 0.05420, 0.05075], [0.37033, 0.00175, 0.00658, 0.00801, 0.02334, 0.00544, 0.00302, 0.29194, 0.00098, 0.00230, 0.01916, 0.00064, 0.00334, 0.00740, 0.17321, 0.03342, 0.00029, 0.02715, 0.00959, 0.00900]),
            ('\\SmoothGrad', text_2, [0.00357, 0.08556, 0.00357, 0.02852, 0.01426, 0.08556, 0.01070, 0.01070, 0.01426, 0.01070, 0.08556, 0.02496, 0.01604, 0.00000, 0.02139, 0.17112, 0.11408, 0.22816, 0.01426, 0.04278], [0.06003, 0.01656, 0.00983, 0.05796, 0.00414, 0.01656, 0.01100, 0.25514, 0.03933, 0.01449, 0.03312, 0.00828, 0.00414, 0.01604, 0.14284, 0.06624, 0.13249, 0.06624, 0.00414, 0.00414]),
            ('\\IntegratedGrad', text_2, [0.019, 0.032, 0.018, 0.018, 0.000, 0.062, 0.043, 0.005, 0.040, 0.032, 0.017, 0.031, 0.018, 0.086, 0.001, 0.104, 0.019, 0.374, 0.022, 0.021], [0.000, 0.010, 0.020, 0.007, 0.029, 0.043, 0.028, 0.013, 0.023, 0.026, 0.001, 0.025, 0.034, 0.346, 0.002, 0.087, 0.013, 0.227, 0.023, 0.008]),
        ]),
        ('Reading Comprehension (Question)', [
            ('\\VanillaGrad', text_3, [0.014, 0.222, 0.075, 0.46 , 0.131, 0.052, 0.029], [0.005, 0.042, 0.73 , 0.145, 0.049, 0.014, 0.01 ]),
            ('\\SmoothGrad', text_3, [0.056, 0.393, 0.009, 0.374, 0.009, 0.028, 0.056], [0.   , 0.14 , 0.419, 0.   , 0.048, 0.079, 0.245]),
            ('\\IntegratedGrad', text_3, [0.61 , 0.014, 0.009, 0.023, 0.011, 0.225, 0.108], [0.594, 0.023, 0.034, 0.025, 0.006, 0.217, 0.099]),
        ]),
        ('Bios', [
            ('\\VanillaGrad', text_4, [0.26159, 0.03713, 0.06811, 0.02078, 0.10693, 0.00844, 0.00159, 0.00714, 0.00989, 0.07130, 0.07281, 0.18179, 0.00973, 0.14041], [0.05629, 0.00728, 0.01448, 0.00404, 0.02162, 0.07727, 0.00112, 0.44827, 0.00200, 0.01517, 0.28093, 0.03789, 0.00262, 0.02928]),
            ('\\SmoothGrad', text_4, [0.14493, 0.04348, 0.05797, 0.11594, 0.11594, 0.02899, 0.11594, 0.08696, 0.05797, 0.11594, 0.00000, 0.00000, 0.04348, 0.01449], [0.09865, 0.02691, 0.02242, 0.05381, 0.01794, 0.04484, 0.07175, 0.15247, 0.01794, 0.17937, 0.05381, 0.14350, 0.04484, 0.00897]),
            ('\\IntegratedGrad', text_4, [0.492, 0.015, 0.041, 0.039, 0.013, 0.010, 0.022, 0.014, 0.001, 0.092, 0.025, 0.014, 0.043, 0.162], [0.588, 0.040, 0.049, 0.015, 0.011, 0.019, 0.033, 0.027, 0.005, 0.054, 0.030, 0.012, 0.013, 0.100]),
        ]),
    ]

    # The context manager guarantees the file is flushed and closed even if
    # latex_colorize raises part-way through.
    with open(os.path.join(output_dir, 'table_3.tex'), 'w') as outfile:
        outfile.write(header)
        for section_index, (title, rows) in enumerate(sections):
            # Sections are separated by a rule; none precedes the first one.
            if section_index > 0:
                outfile.write('\\midrule\n')
            outfile.write('\\multicolumn{2}{l}{\\bf ' + title + '} \\\\\n')
            for method, text, original_scores, merged_scores in rows:
                # Extra vertical space before every method except the one
                # that opens the section.
                if method != '\\VanillaGrad':
                    outfile.write('\\addlinespace \n')
                outfile.write('\\multicolumn{2}{l}{\\em' + method + '} \\\\\n')
                outfile.write('\\fop & {} \\\\\n'.format(latex_colorize(text, original_scores)))
                # NOTE(review): this is the stop-token table, yet the merged
                # row is labelled with the "first" macro -- confirm whether a
                # stop-token macro was intended. Kept as-is to preserve output.
                outfile.write('\\fmergedfirst & {} \\\\\n'.format(latex_colorize(text, merged_scores)))
        outfile.write('\\bottomrule\n')
        outfile.write(footer)

def colorize_example(output_dir):
    '''Write ``illustrative_example.tex``: one sentence colorized three ways.

    The file receives three lines, in this order: the sentence rendered with
    the original scores, with the first-token scores, and with the stop-token
    scores.

    Args:
        output_dir: Directory in which ``illustrative_example.tex`` is
            created. It must already exist; the file is overwritten if
            present.
    '''
    text = ['why', 'make', 'a', 'documentary', 'about', 'these', 'marginal', 'historical', 'figures', '?', '[SEP]']
    original_scores = [0.34230, 0.02116, 0.04192, 0.11603, 0.10397, 0.00020, 0.12693, 0.01987, 0.03363, 0.07968, 0.09080]
    first_token_scores = [0.70125, 0.00809, 0.01908, 0.05369, 0.04803, 0.00016, 0.05715, 0.00958, 0.01500, 0.03585, 0.04144]
    stop_token_scores = [0.04769, 0.00280, 0.86601, 0.01532, 0.01522, 0.00006, 0.01761, 0.00318, 0.00520, 0.01198, 0.01261]

    # The context manager guarantees the file is flushed and closed even if
    # latex_colorize raises.
    with open(os.path.join(output_dir, 'illustrative_example.tex'), 'w') as outfile:
        for scores in (original_scores, first_token_scores, stop_token_scores):
            outfile.write('{}\n'.format(latex_colorize(text, scores)))

if __name__ == '__main__':
    prepare_colorize('./')

    output_dir = 'tables'
    # open(..., 'w') does not create missing directories, so make sure the
    # target exists instead of failing on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)

    # The table generators are currently disabled; uncomment to regenerate
    # table_2.tex / table_3.tex alongside the illustrative example.
    # table_2(output_dir)
    # table_3(output_dir)
    colorize_example(output_dir)