import torch
import pandas as pd
import os
import glob
import random
from visualize import createHTML
from ast import literal_eval
import optparse

def read_test_df(path):
    """Load ``test_df.csv`` from the directory *path* and return it as a DataFrame."""
    csv_file = os.path.join(path, 'test_df.csv')
    return pd.read_csv(csv_file)

def process_dfs(path):
    """Load the prediction frames and the test frame stored under *path*.

    The frames read from files ending in ``predicted.csv`` are combined into
    one sorted frame, which is also written to ``sorted_frame.csv`` inside
    *path* (with the DataFrame index included, as ``to_csv`` does by default).

    Returns:
        A tuple ``(dfs, test_df, sorted_frame)``: the list of prediction
        frames, the test frame, and the combined sorted frame.
    """
    prediction_frames = get_all_dfs(path, 'predicted')
    test_frame = read_test_df(path)
    combined = get_sorted_frame(prediction_frames)
    target_csv = os.path.join(path, 'sorted_frame.csv')
    combined.to_csv(target_csv)
    return prediction_frames, test_frame, combined

def get_all_dfs(path,key):
    """Read every CSV directly inside *path* whose name ends with ``key + '.csv'``.

    The matching file paths are sorted as plain strings before being read, so
    the returned list of DataFrames follows lexicographic path order.
    NOTE(review): callers appear to rely on this order reflecting the
    "revealed" extent -- confirm the file naming scheme guarantees that.
    """
    pattern = path + '/*' + key + '.csv'
    return [
        pd.read_csv(csv_path, index_col=None, header=0)
        for csv_path in sorted(glob.glob(pattern))
    ]

def apply_inplace(df, field, func):
    """Return a new frame in which *func* has been applied to the *field* column.

    Despite the name, this function does not assign into *df*: the column is
    dropped, transformed with ``apply`` and concatenated back on, so it ends up
    as the last column of the result.
    """
    remainder = df.drop(field, axis=1)
    transformed = df[field].apply(func)
    return pd.concat([remainder, transformed], axis=1)

def get_sorted_frame(dfs):
    """Concatenate *dfs* row-wise and order the rows by ``ids``, then ``revealed``.

    The original row labels of each input frame are kept, so the result's
    index may contain duplicates.
    """
    return pd.concat(dfs).sort_values(by=['ids', 'revealed'])


def write_to_file(index, dfs, origin, outfile):
    """Write the predictions for sample *index* to the open text file *outfile*.

    A header line ``"<index>: <origin>"`` is written first, followed by one
    ``"<preverbs>: <predictions>"`` line per frame in *dfs*, each taken from
    the row labelled *index* of that frame.
    """
    header = "".join((str(index), ': ', origin, '\n'))
    outfile.write(header)
    for frame in dfs:
        outfile.write(frame.preverbs[index] + ': ' + frame.predictions[index] + '\n')

def visualize_attention(df_path, index, filename):
    """Collect the rows of one sample and pass them to ``createHTML``.

    Reads ``test_df.csv`` and ``sorted_frame.csv`` from the directory
    *df_path*, selects every ``sorted_frame`` row whose ``ids`` equals
    *index* (kept in file order) and forwards the per-row values, together
    with the original sentence from ``test_df.csv``, to ``createHTML``.

    Args:
        df_path: Directory containing ``test_df.csv`` and ``sorted_frame.csv``.
        index: Sample id to select in ``sorted_frame.csv``; also used as the
            row label when looking the sample up in ``test_df.csv``.
        filename: Forwarded unchanged to ``createHTML`` (presumably the path
            of the HTML file to produce -- confirm against ``visualize``).

    Returns:
        None.
    """
    # Join with os.path so a df_path without a trailing separator resolves to
    # the same files that process_dfs reads and writes.
    test_df = pd.read_csv(os.path.join(df_path, 'test_df.csv'))
    df = pd.read_csv(os.path.join(df_path, 'sorted_frame.csv'))
    sample_df = df[df.ids == index]

    text_list = []
    weights_list = []
    revealed_list = []
    prediction_list = []
    target_list = []
    conf_list = []
    # Walk the labels the filter actually kept. Deriving them as
    # len(sample_df) * index offsets only works when ids are exactly 0..n-1
    # and every id owns the same number of rows; otherwise it raises KeyError.
    for row in sample_df.index:
        words = sample_df.preverbs[row].split()
        text_list.append(" ".join('(UNK)' if w == '<UNK>' else w for w in words))
        # The attention weights are stored in the CSV as a Python literal string.
        weights_list.append(literal_eval(sample_df.attn[row]))
        revealed_list.append(sample_df.revealed[row])
        prediction_list.append(sample_df.predictions[row])
        target_list.append(sample_df.targets[row])
        conf_list.append(round(sample_df.confidences[row], 2))

    # Full original sentence: the stored preverb token list plus the target.
    original = literal_eval(test_df.preverbs[index]) + [test_df.targets[index]]
    createHTML(text_list, weights_list, revealed_list, prediction_list, target_list, conf_list,
               original, filename)

def get_visualization(df_path, out_path, num_samples, index=None):
    """Create attention maps and text dumps for one or several test samples.

    For each chosen sample an ``attn_map_<index>.html`` visualization is
    requested via ``visualize_attention`` and the sample's predictions are
    appended to ``sample_output.txt`` in *out_path*.

    Args:
        df_path: Directory with the input CSV files, as read by ``process_dfs``.
        out_path: Existing directory that receives the HTML files and
            ``sample_output.txt`` (opened in append mode).
        num_samples: Number of random samples to draw when *index* is None;
            ignored otherwise.
        index: Row label of one specific sample in ``test_df.csv``. When
            given, only that sample is processed.

    Returns:
        None.
    """
    dfs, test_df, _ = process_dfs(df_path)

    if index is not None:
        indexes = [index]
    else:
        # random.randint includes both end points, so the upper bound must be
        # len(test_df) - 1; len(test_df) itself is not a valid row.
        last_row = len(test_df) - 1
        indexes = [random.randint(0, last_row) for _ in range(num_samples)]

    with open(os.path.join(out_path, 'sample_output.txt'), 'a') as outfile:
        for sample_index in indexes:
            html_path = os.path.join(out_path, 'attn_map_%d.html' % (sample_index))
            visualize_attention(df_path, sample_index, html_path)
            origin = (" ".join(literal_eval(test_df.preverbs[sample_index]))
                      + ' ' + test_df.targets[sample_index])
            write_to_file(sample_index, dfs, origin, outfile)



if __name__ == "__main__":
    # Command-line entry point: parse the options, make sure the output
    # directory exists, then generate the visualizations.
    parser = optparse.OptionParser()

    parser.add_option("-i", "--df_path", dest="df_path", help="input data path", type="str")
    parser.add_option("-o", "--out_path",  dest="out_path", help="output html path", type="str")
    parser.add_option("-n","--num", dest="num_samples", help= "number of samples want to generate",type="int")
    parser.add_option("-d", "--index", dest="index", help="index of specific sample", type="int", default=None)
    (opt, args) = parser.parse_args()

    # Report missing options as a usage error instead of letting None reach
    # os.path.exists() or range() and fail there with a TypeError.
    if opt.df_path is None or opt.out_path is None:
        parser.error("both --df_path and --out_path are required")
    if opt.index is None and opt.num_samples is None:
        parser.error("either --num or --index must be given")

    if not os.path.exists(opt.out_path):
        os.makedirs(opt.out_path)

    get_visualization(opt.df_path, opt.out_path, opt.num_samples, opt.index)