from typing import Tuple, Dict, List, Any
import sys
import argparse 
import os
import argparse
import nltk
import random

import torch
import torch.optim as optim 
import matplotlib.pyplot as plt
from nltk.corpus import stopwords

from allennlp.data.dataset_readers.stanford_sentiment_tree_bank import \
    StanfordSentimentTreeBankDatasetReader
from allennlp.data import Instance
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model, BasicClassifier
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper, CnnEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import _read_pretrained_embeddings_file
from allennlp.modules.token_embedders import Embedding, PretrainedTransformerEmbedder
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.nn.util import get_text_field_mask
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.training.trainer import Trainer
from allennlp.common.util import lazy_groups_of
from allennlp.data.token_indexers import SingleIdTokenIndexer, PretrainedTransformerMismatchedIndexer
from allennlp.nn.util import move_to_device
from allennlp.interpret.saliency_interpreters import SaliencyInterpreter, SimpleGradient
from allennlp.predictors import Predictor
from allennlp.data.batch import Batch
import torch.nn.functional as F

from adversarial_grads.util.misc import compute_rank, get_stop_ids, create_labeled_instances
from adversarial_grads.util.combine_model import merge_models
from adversarial_grads.util.model_creator import get_bert_model, get_model

# NOTE(review): this constant is not referenced anywhere else in this file —
# confirm whether another module reads it or whether it can be retired.
EMBEDDING_TYPE = "glove" # what type of word embeddings to use

def get_file_config(outdir, args):
    """
    Prepare the output folder of one experiment and return its metric file paths.

    Creates ``<outdir>/experiment_<args.exp_num>`` (including any missing
    parent directories), removes the files a previous run of the same
    experiment left behind, and writes ``exp.txt`` with a description of the
    hyper parameters found on ``args``.

    Args:
        outdir: Root directory that holds all experiment folders. It may be a
            nested path; missing parents are created.
        args: Parsed command-line namespace. The attributes read are
            ``exp_num``, ``model_name``, ``embedding_op``, ``normalization``,
            ``normalization2``, ``softmax``, ``cuda``,
            ``sharp_grad_model_file``, ``sharp_pred_model_file`` and
            ``vocab_folder``.

    Returns:
        An object exposing the metric file paths as the attributes
        ``avg_first_token_grad_rank_dev_file_name``,
        ``stop_word_attribution_dev_file_name``,
        ``first_token_attribution_dev_file_name``,
        ``avg_first_token_grad_value_dev_file_name`` and ``acc_file_name``.
        The metric files themselves are not created here.
    """
    exp_desc = """This experiment (number #{}) used the following hyper parameters:
        - Model: {}
        - Embedding Operator: {}
        - Normalization: {}
        - Normalization2: {}
        - Softmax enabled: {}
        - Cuda enabled: {}
        - Attack Model File: {}
        - Predictive Model File: {}
        - Vocab Folder: {}
    """.format(
        args.exp_num,
        args.model_name,
        args.embedding_op,
        args.normalization,
        args.normalization2,
        args.softmax,
        args.cuda,
        args.sharp_grad_model_file,
        args.sharp_pred_model_file,
        args.vocab_folder
    )

    class FileConfig:
        pass
    file_config = FileConfig()

    exp_dir = os.path.join(outdir, "experiment_{}".format(args.exp_num))
    for directory in (outdir, exp_dir):
        if not os.path.exists(directory):
            print("Creating directory with name:", directory)
            # makedirs (rather than mkdir) lets ``outdir`` be a nested path;
            # exist_ok tolerates the directory appearing between the check
            # above and this call.
            os.makedirs(directory, exist_ok=True)

    # contains info about the hyper parameters for this experiment
    exp_file_name = os.path.join(exp_dir, "exp.txt")

    # Each stem names one "<stem>.txt" metric file and the matching
    # "<stem>_file_name" attribute on the returned object.
    metric_stems = (
        # avg gradient rank on the dev set
        "avg_first_token_grad_rank_dev",
        # stopword attribution on the dev set
        "stop_word_attribution_dev",
        # first token attribution on the dev set
        "first_token_attribution_dev",
        # avg first token grad value on the dev set
        "avg_first_token_grad_value_dev",
        # accuracy
        "acc",
    )

    files = [exp_file_name]
    for stem in metric_stems:
        path = os.path.join(exp_dir, stem + ".txt")
        setattr(file_config, stem + "_file_name", path)
        files.append(path)

    # The metric files are later opened in append mode, so anything left over
    # from an earlier run of this experiment number has to go first.
    for stale in files:
        try:
            os.remove(stale)
        except FileNotFoundError:
            pass

    with open(exp_file_name, "w") as f:
        f.write(exp_desc)

    return file_config

def record_metrics(
    combined_model,
    sharp_grad_model,
    sharp_pred_model,
    reader, 
    dev_data,
    vocab,
    args,
    file_config 
) -> None:
    """
    Run the merged model over the dev set and append averaged metrics to the
    experiment's metric files.

    For every batch of 16 dev instances the model is run without gradients,
    the instances are relabeled via ``create_labeled_instances`` and
    simple-gradient attributions are obtained from
    ``SimpleGradient.sst_interpret_from_instances``.
    NOTE(review): that method is not part of upstream AllenNLP as far as this
    file shows — presumably it comes from a patched install; confirm.

    Four running totals are accumulated per instance and divided by
    ``len(dev_data)``: the summed attribution on stop-word positions, the
    absolute attribution at index 1, the rank of index 1 as reported by
    ``compute_rank`` and the absolute raw gradient at index 1. Each average,
    plus ``combined_model.get_metrics()['accuracy']``, is appended as one line
    to its file.
    NOTE(review): index 1 is treated as the "first token" throughout —
    presumably index 0 is a special token; confirm against the indexer.

    Args:
        combined_model: The merged model to evaluate; it is put in eval mode.
        sharp_grad_model: Currently unused; kept so existing callers still work.
        sharp_pred_model: Currently unused; kept so existing callers still work.
        reader: Dataset reader handed to the ``text_classifier`` predictor.
        dev_data: Dev dataset; must support ``len()`` and expose ``instances``.
        vocab: Vocabulary used to index every batch.
        args: Parsed command-line namespace; reads ``cuda``, ``embedding_op``,
            ``normalization``, ``normalization2`` and ``softmax``.
        file_config: Object returned by ``get_file_config`` holding the
            metric file paths.
    """
    combined_model.eval() # model should be in eval() already, but just in case

    combined_predictor = Predictor.by_name('text_classifier')(combined_model, reader)
    combined_simple_gradient_interpreter = SimpleGradient(combined_predictor)

    stop_words = set(stopwords.words('english'))

    batch_size = 16
    batched_dev_instances = [
        dev_data.instances[i:i + batch_size]
        for i in range(0, len(dev_data), batch_size)
    ]

    total_stop_word_attribution = 0
    total_first_token_attribution = 0
    total_first_token_grad_rank = 0
    total_first_token_grad_value = 0

    for i, batch in enumerate(batched_dev_instances):
        print(i)
        # Only query CUDA memory statistics when CUDA was requested: the call
        # targets device 0 and is meaningless (and may fail) on a CPU-only run.
        if args.cuda:
            print(torch.cuda.memory_summary(device=0, abbreviated=True))

        data = Batch(batch)
        data.index_instances(vocab)
        model_input = data.as_tensor_dict()
        model_input = move_to_device(model_input, cuda_device=0) if args.cuda else model_input
        with torch.no_grad():
            combined_outputs = combined_model(**model_input)

        combined_new_instances = create_labeled_instances(combined_predictor, combined_outputs, batch, args.cuda)

        combined_grads, raw_combined_grads = combined_simple_gradient_interpreter.sst_interpret_from_instances(
            combined_new_instances, 
            args.embedding_op, 
            args.normalization, 
            args.normalization2, 
            args.softmax, 
            args.cuda, 
            higher_order_grad=False
        )

        # stop-word positions for every sentence of the batch
        stop_ids = [get_stop_ids(instance, stop_words) for instance in combined_new_instances]

        for j, grad in enumerate(combined_grads):
            total_stop_word_attribution += torch.sum(grad[stop_ids[j]]).detach()
            total_first_token_attribution += torch.abs(torch.sum(grad[1]).detach())
            total_first_token_grad_rank += compute_rank(grad, {1})[0]
            # detach so the running total never keeps an autograd graph alive
            total_first_token_grad_value += torch.abs(raw_combined_grads[j][1]).detach()

    num_instances = len(dev_data)
    avg_stop_word_attribution = total_stop_word_attribution/num_instances
    avg_first_token_attribution = total_first_token_attribution/num_instances
    avg_grad_rank = total_first_token_grad_rank/num_instances
    avg_first_token_grad_value = total_first_token_grad_value/num_instances

    metric_outputs = [
        (file_config.stop_word_attribution_dev_file_name, avg_stop_word_attribution),
        (file_config.first_token_attribution_dev_file_name, avg_first_token_attribution),
        (file_config.avg_first_token_grad_rank_dev_file_name, avg_grad_rank),
        (file_config.avg_first_token_grad_value_dev_file_name, avg_first_token_grad_value),
        (file_config.acc_file_name, combined_model.get_metrics()['accuracy']),
    ]
    for file_name, value in metric_outputs:
        # The files are opened in append mode, so terminate each record with a
        # newline; otherwise successive values would run together.
        with open(file_name, "a") as f:
            f.write("{}\n".format(value))

def get_reader(model_name: str) -> StanfordSentimentTreeBankDatasetReader:
    """
    Build the binary ("2-class") SST dataset reader that matches `model_name`.

    `'BERT'` is given a `PretrainedTransformerMismatchedIndexer` for
    `bert-base-uncased`; every other model name is given a lowercasing
    `SingleIdTokenIndexer`. In both cases the indexer is registered under the
    `"tokens"` key.
    """
    if model_name == 'BERT':
        token_indexer = PretrainedTransformerMismatchedIndexer('bert-base-uncased')
    else:
        # plain word-level ids, lowercased
        token_indexer = SingleIdTokenIndexer(lowercase_tokens=True)

    return StanfordSentimentTreeBankDatasetReader(
        granularity="2-class",
        token_indexers={"tokens": token_indexer}
    )

def save_model_details(model, vocab, exp_num):
    """
    Persist a model's weights and its vocabulary for one experiment.

    The state dict is written to
    ``sst_combined_models/experiment_<exp_num>/model.th`` and the vocabulary is
    saved via ``vocab.save_to_files`` into the ``vocab`` entry of the same
    folder. Missing directories are created (and announced on stdout) first.
    """
    root_dir = "sst_combined_models"
    target_dir = os.path.join(root_dir, "experiment_{}".format(exp_num))

    # (directory, function that creates it), checked in this order
    for directory, create in ((root_dir, os.mkdir), (target_dir, os.makedirs)):
        if os.path.exists(directory):
            continue
        print("Creating directory with name:", directory)
        create(directory)

    weights_path = os.path.join(target_dir, "model.th")
    with open(weights_path, 'wb') as weights_file:
        torch.save(model.state_dict(), weights_file)

    vocab_path = os.path.join(target_dir, "vocab")
    vocab.save_to_files(vocab_path)

def main():
    """
    Merge the two pretrained SST models and record gradient metrics on the dev set.

    Steps: parse the command line, read the SST dev split, load the saved
    vocabulary from ``args.vocab_folder``, build both models with it, load
    their weights, merge them with ``merge_models`` and hand the merged model
    to ``record_metrics``. Results go under ``sst_combined_experiments``.
    """
    args = argument_parsing()
    print(args)

    reader = get_reader(args.model_name)
    dev_data = reader.read('https://s3-us-west-2.amazonaws.com/allennlp/datasets/sst/dev.txt')

    # Load the saved vocabulary *before* building the models, so the models,
    # the indexed data and the later evaluation all share one vocabulary
    # instead of mixing it with one rebuilt from the training set.
    vocab = Vocabulary.from_files(args.vocab_folder)
    dev_data.index_with(vocab)

    sharp_pred_model = get_model(args.model_name, vocab, args.cuda)
    sharp_grad_model = get_model(args.model_name, vocab, args.cuda)
    sharp_grad_model.eval()
    sharp_pred_model.eval()

    # NOTE(review): `load_model` is neither defined nor imported in the visible
    # part of this file — confirm where it lives and import it, otherwise
    # these calls raise NameError.
    load_model(sharp_pred_model, args.sharp_pred_model_file)
    load_model(sharp_grad_model, args.sharp_grad_model_file)

    combined_model = merge_models(sharp_grad_model, sharp_pred_model)
    print("model loaded")

    if args.cuda:
        combined_model.cuda()

    file_config = get_file_config("sst_combined_experiments", args)
    record_metrics(combined_model, sharp_grad_model, sharp_pred_model, reader, dev_data, vocab, args, file_config)
    
def argument_parsing():
    """
    Declare the experiment's command-line options and parse ``sys.argv``.

    No option is required: value options default to ``None`` and the
    ``softmax`` / ``cuda`` switches default to ``False``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    parser = argparse.ArgumentParser(description='One argparser')
    add = parser.add_argument

    add('--model_name', type=str, choices=['CNN', 'LSTM', 'BERT'], help='Which model to use')
    add('--exp_num', type=int, help='The experiment number')
    add('--embedding_op', type=str, choices=['dot', 'l2'], help='Dot product or l2 norm')

    # both normalization options accept the same values and share a help text
    for option in ('--normalization', '--normalization2'):
        add(option, type=str, choices=['l1', 'l2', 'none'], help='L1 norm or l2 norm')

    # on/off switch pairs that write to a single destination
    switches = (
        ('softmax', 'Use softmax', 'No softmax'),
        ('cuda', 'Cuda enabled', 'Cuda disabled'),
    )
    for dest, on_help, off_help in switches:
        add('--' + dest, dest=dest, action='store_true', help=on_help)
        add('--no-' + dest, dest=dest, action='store_false', help=off_help)

    path_options = (
        ('--sharp_grad_model_file', 'Path to bad gradient model folder'),
        ('--sharp_pred_model_file', 'Path to good predictive model folder'),
        ('--vocab_folder', 'Where the vocab folder is loaded from'),
    )
    for option, help_text in path_options:
        add(option, type=str, help=help_text)

    return parser.parse_args()

# Run the experiment only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
