from typing import Tuple, Dict, List, Any
import sys
import argparse 
import os.path
import argparse
import nltk
from nltk.corpus import stopwords
import torch
import random 
import matplotlib.pyplot as plt
import torch.optim as optim

from allennlp.data.dataset_readers.stanford_sentiment_tree_bank import \
    StanfordSentimentTreeBankDatasetReader
from allennlp.data import Instance, DataLoader
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model, BasicClassifier
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper, CnnEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import _read_pretrained_embeddings_file
from allennlp.modules.token_embedders import Embedding
from allennlp.nn.util import get_text_field_mask
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.training.trainer import Trainer,GradientDescentTrainer
from allennlp.common.util import lazy_groups_of
from allennlp.data.token_indexers import SingleIdTokenIndexer, PretrainedTransformerMismatchedIndexer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.nn.util import move_to_device
from allennlp.interpret.saliency_interpreters import SaliencyInterpreter, SimpleGradient
from allennlp.predictors import Predictor
from allennlp.data.batch import Batch
from allennlp.data.samplers import BucketBatchSampler
import torch.nn.functional as F

from adversarial_grads.util.misc import compute_rank, get_stop_ids, create_labeled_instances
from adversarial_grads.util.model_data_helpers import get_sst_reader, get_model

# Which kind of word embeddings to use.
# NOTE(review): nothing in the visible part of this file reads this constant;
# confirm whether it is still needed.
EMBEDDING_TYPE = "glove"

class HLoss(torch.nn.Module):
    """Module that returns ``sum(softmax(x) * log_softmax(x))`` as a scalar.

    Softmax and log-softmax are both taken along ``dim=1`` and the
    element-wise products are summed over every element of the result, so a
    batched input is reduced to a single value.  No sign flip is applied:
    the returned value is the sum of ``p * log(p)`` terms as-is.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the summed ``p * log(p)`` of ``x`` after softmax over dim 1.

        Args:
            x: tensor with at least two dimensions (softmax is taken over
                ``dim=1``).

        Returns:
            A zero-dimensional tensor holding the total.
        """
        probabilities = F.softmax(x, dim=1)
        log_probabilities = F.log_softmax(x, dim=1)
        return torch.sum(probabilities * log_probabilities)

class SST_Attacker:
    """Measure and log gradient-attribution statistics for an SST classifier.

    :meth:`attack` runs the model on each training batch, asks the gradient
    interpreter for per-token gradients and derives losses and ranks for the
    first token and for stop-word tokens.  :meth:`record_metrics` evaluates
    the dev set and appends every measurement to text files under
    ``sst_baseline_experiments/experiment_<exp_num>/``.

    NOTE: :meth:`attack` never calls ``optimizer.step()``; losses are
    computed and logged but the model parameters are not updated here.
    """

    def __init__(self, model, reader, train_data, dev_data, vocab, args):
        """Store the experiment configuration and prepare the output files.

        Args:
            model: classifier to analyse; also handed to the predictor.
            reader: dataset reader handed to the ``text_classifier`` predictor.
            train_data: training dataset; must support ``len()`` and expose
                ``.instances``.
            dev_data: dev dataset with the same requirements as ``train_data``.
            vocab: vocabulary used to index every batch.
            args: namespace providing ``model_name``, ``batch_size``, ``loss``,
                ``embedding_op``, ``normalization``, ``normalization2``,
                ``learning_rate``, ``lmbda``, ``softmax``, ``cuda`` and
                ``exp_num``.

        Side effects:
            Creates the experiment directory if needed, deletes any metric
            files left there by a previous run with the same ``exp_num`` and
            writes the hyper-parameter description to ``exp.txt``.
        """
        self.model = model
        self.reader = reader
        self.args = args
        self.predictor = Predictor.by_name('text_classifier')(self.model, self.reader)
        self.simple_gradient_interpreter = SimpleGradient(self.predictor)

        # Setup training instances
        self.train_data = train_data
        self.model_name = args.model_name
        self.batch_size = args.batch_size
        self.batched_training_instances = [
            train_data.instances[i:i + self.batch_size]
            for i in range(0, len(train_data), self.batch_size)
        ]
        # Dev batches always use a fixed size of 32, independent of batch_size.
        self.batched_dev_instances = [dev_data.instances[i:i + 32] for i in range(0, len(dev_data), 32)]
        self.dev_data = dev_data
        self.vocab = vocab
        self.loss = args.loss
        self.embedding_op = args.embedding_op
        self.normalization = args.normalization
        self.normalization2 = args.normalization2
        self.learning_rate = args.learning_rate
        self.lmbda = args.lmbda
        self.softmax = args.softmax
        self.cuda = args.cuda
        self.criterion = HLoss()
        self.exp_num = args.exp_num
        self.stop_words = set(stopwords.words('english'))

        # Only MSE is implemented.  Always define the attribute so that an
        # unsupported choice is reported clearly by attack() instead of
        # surfacing later as an AttributeError.
        self.loss_function = torch.nn.MSELoss() if self.loss == "MSE" else None

        self.optimizer = optim.Adam(model.parameters(), lr=self.learning_rate)

        exp_desc = """This experiment (number #{}) used the following hyper parameters:
        - Model: {}
        - Batch size: {}
        - Learning rate: {}
        - Lambda: {}
        - Loss function: {}
        - Embedding Operator: {}
        - Normalization: {}
        - Normalization2: {}
        - Softmax enabled: {}
        - Cuda enabled: {}
    """.format(self.exp_num, self.model_name, self.batch_size, self.learning_rate, self.lmbda, self.loss, self.embedding_op, self.normalization, self.normalization2, self.softmax, self.cuda)

        outdir = "sst_baseline_experiments"

        if not os.path.exists(outdir):
            print("Creating directory with name:", outdir)
            # exist_ok tolerates another process creating it in the meantime.
            os.makedirs(outdir, exist_ok=True)

        exp_dir = os.path.join(outdir, "experiment_{}".format(self.exp_num))
        if not os.path.exists(exp_dir):
            print("Creating directory with name:", exp_dir)
            os.makedirs(exp_dir, exist_ok=True)

        # contains info about the hyper parameters for this experiment
        self.exp_file_name = os.path.join(exp_dir, "exp.txt")
        # normalized gradients vs. number of updates
        self.grad_file_name = os.path.join(exp_dir, "grad.txt")
        # first token gradient rank vs. number of updates
        self.first_token_grad_rank_file_name = os.path.join(exp_dir, "first_token_grad_rank.txt")
        # stop word gradient rank vs. number of updates
        self.stop_word_grad_rank_file_name = os.path.join(exp_dir, "stop_word_grad_rank.txt")
        # avg gradient rank on the dev set vs. number of updates
        self.avg_first_token_grad_rank_dev_file_name = os.path.join(exp_dir, "avg_first_token_grad_rank_dev.txt")
        # entropy vs. number of updates
        self.entropy_dev_file_name = os.path.join(exp_dir, "entropy_dev.txt")
        # entropy loss vs. number of updates
        self.entropy_loss_file_name = os.path.join(exp_dir, "entropy_loss.txt")
        # first token gradient loss vs. number of updates
        self.first_token_grad_loss_file_name = os.path.join(exp_dir, "first_token_grad_loss.txt")
        # stop word gradient loss vs. number of updates
        self.stop_word_grad_loss_file_name = os.path.join(exp_dir, "stop_word_grad_loss.txt")
        # first token total loss vs. number of updates
        self.first_token_total_loss_file_name = os.path.join(exp_dir, "first_token_total_loss.txt")
        # stop word total loss vs. number of updates
        self.stop_word_total_loss_file_name = os.path.join(exp_dir, "stop_word_total_loss.txt")
        # output probs vs. number of updates
        self.output_probs_file_name = os.path.join(exp_dir, "output_probs.txt")
        # output logits vs. number of updates
        self.output_logits_file_name = os.path.join(exp_dir, "output_logits.txt")
        # raw gradients (embedding dimension already reduced away) vs. number of updates
        self.raw_grads_file_name = os.path.join(exp_dir, "raw_gradients.txt")
        # stopword attribution on the dev set vs. number of updates
        self.stop_word_attribution_dev_file_name = os.path.join(exp_dir, "stop_word_attribution_dev.txt")
        # first token attribution on the dev set vs. number of updates
        self.first_token_attribution_dev_file_name = os.path.join(exp_dir, "first_token_attribution_dev.txt")
        # accuracy vs. number of updates
        self.acc_file_name = os.path.join(exp_dir, "acc.txt")

        # Remove any existing files for this directory, because every metric
        # file is opened in append mode later on.
        files = [
            self.exp_file_name,
            self.grad_file_name,
            self.first_token_grad_rank_file_name,
            self.stop_word_grad_rank_file_name,
            self.avg_first_token_grad_rank_dev_file_name,
            self.entropy_dev_file_name,
            self.entropy_loss_file_name,
            self.first_token_grad_loss_file_name,
            self.stop_word_grad_loss_file_name,
            self.first_token_total_loss_file_name,
            self.stop_word_total_loss_file_name,
            self.output_probs_file_name,
            self.output_logits_file_name,
            self.raw_grads_file_name,
            self.stop_word_attribution_dev_file_name,
            self.first_token_attribution_dev_file_name,
            self.acc_file_name,
        ]

        for f in files:
            if os.path.exists(f):
                os.remove(f)

        with open(self.exp_file_name, "w") as f:
            f.write(exp_desc)

    def attack(self):
        """Run four passes over the training batches and log metrics.

        Dev metrics are recorded once before the first batch and then on
        every tenth batch of each pass.

        Raises:
            ValueError: if the configured loss is not ``"MSE"``, the only
                loss function this class builds.
        """
        if self.loss_function is None:
            raise ValueError(
                "Unsupported loss {!r}: only 'MSE' is implemented".format(self.loss)
            )

        # indicate intention for model to train
        self.model.train()

        self.record_metrics(0, None, None, None, None, None, None, None, None, None, None, None)
        # record_metrics() leaves the model in eval mode.
        self.model.train()

        # shuffle the order of the batches (once, before the first pass)
        random.shuffle(self.batched_training_instances)

        # Regression target of the gradient losses.
        target = torch.ones(1).cuda() if self.cuda else torch.ones(1)

        for _epoch in range(4):
            # NOTE(review): `i` restarts at 1 on every pass, so the logged
            # iteration numbers repeat across passes -- confirm this is wanted.
            for i, training_instances in enumerate(self.batched_training_instances, 1):
                print("Iter #{}".format(i))
                if self.cuda:
                    # Only query the CUDA allocator when CUDA is actually in use.
                    print(torch.cuda.memory_summary(device=0, abbreviated=True))

                # NOTE: the gradient handling below only looks at the first
                # instance of the batch, so batch sizes above 1 are not
                # fully supported yet.
                stop_ids = [
                    get_stop_ids(instance, self.stop_words)
                    for instance in training_instances
                ]

                data = Batch(training_instances)
                data.index_instances(self.vocab)
                model_input = data.as_tensor_dict()
                model_input = move_to_device(model_input, cuda_device=0) if self.cuda else model_input
                outputs = self.model(**model_input)

                new_instances = create_labeled_instances(self.predictor, outputs, training_instances, self.cuda)

                # get gradients and add to the loss
                # NOTE(review): the criterion applies softmax itself, yet it is
                # fed outputs['probs'] rather than logits -- confirm intended.
                entropy_loss = self.criterion(outputs['probs'])
                print("entropy requires grad", entropy_loss.requires_grad)
                gradients, raw_gradients = self.simple_gradient_interpreter.sst_interpret_from_instances(
                    new_instances,
                    self.embedding_op,
                    self.normalization,
                    self.normalization2,
                    self.softmax,
                    self.cuda,
                    higher_order_grad=True
                )

                # NOTE: get rid of batch dimension, this should be done
                # differently for higher batch sizes
                gradients = gradients[0]
                raw_gradients = raw_gradients[0]
                print("zero element gradients", gradients[0].unsqueeze(0).requires_grad)
                print("grads", gradients)

                # the loss takes tensors, not scalars, hence the unsqueeze
                first_token_grad_val = gradients[0].unsqueeze(0)
                print("first token grad val", first_token_grad_val)

                # `gradients` now belongs to the first instance only, so index
                # it with that instance's stop-word positions.
                stop_word_grad_val = torch.sum(gradients[stop_ids[0]]).unsqueeze(0)
                print("stop token grad val", stop_word_grad_val)

                first_token_grad_loss = self.loss_function(first_token_grad_val, target)
                stop_word_grad_loss = self.loss_function(stop_word_grad_val, target)

                # compute rank
                stop_ids_set = set(stop_ids[0])
                stop_word_rank = compute_rank(gradients, stop_ids_set)
                first_token_rank = compute_rank(gradients, {0})

                # compute loss
                stop_word_total_loss = stop_word_grad_loss + self.lmbda * entropy_loss
                first_token_total_loss = first_token_grad_loss + self.lmbda * entropy_loss

                # No backward pass / optimizer step is taken here: the losses
                # above are only logged.

                if i % 10 == 0:
                    self.record_metrics(
                        i,
                        entropy_loss,
                        first_token_grad_loss,
                        stop_word_grad_loss,
                        first_token_rank,
                        stop_word_rank,
                        gradients,
                        first_token_total_loss,
                        stop_word_total_loss,
                        outputs['probs'],
                        outputs['logits'],
                        raw_gradients
                    )
                    # record_metrics() switched to eval mode; go back to
                    # training mode for the following batches, as is done
                    # after the initial call above.
                    self.model.train()

    @staticmethod
    def _append_metric(file_name, step, value):
        """Append one ``Iter #<step>: <value>`` line to ``file_name``."""
        with open(file_name, "a") as f:
            f.write("Iter #{}: {}\n".format(step, value))

    def record_metrics(
        self,
        iter: int,
        entropy_loss,
        first_token_grad_loss,
        stop_word_grad_loss,
        first_token_rank: List[int],
        stop_word_rank: List[int],
        gradients,
        first_token_total_loss,
        stop_word_total_loss,
        output_probs,
        output_logits,
        raw_gradients
    ) -> None:
        """Evaluate the dev set and append all metrics for step ``iter``.

        The dev-set averages (entropy, stop-word attribution, first-token
        attribution) and the model accuracy are always written.  The
        remaining arguments describe the current training batch and are only
        written when ``iter != 0``; callers pass ``None`` for them at step 0.

        The model is left in eval mode when this method returns.
        """
        self.model.eval()

        total_ent = 0
        total_stop_word_attribution = 0
        total_first_token_attribution = 0

        total_first_token_grad_rank = 0

        for i, batch in enumerate(self.batched_dev_instances):
            print(i)
            if self.cuda:
                # Only query the CUDA allocator when CUDA is actually in use.
                print(torch.cuda.memory_summary(device=0, abbreviated=True))
            data = Batch(batch)
            data.index_instances(self.vocab)
            model_input = data.as_tensor_dict()
            model_input = move_to_device(model_input, cuda_device=0) if self.cuda else model_input
            with torch.no_grad():
                outputs = self.model(**model_input)

            new_instances = create_labeled_instances(self.predictor, outputs, batch, self.cuda)
            # Same positional arguments as the call in attack(); the softmax
            # flag used to be missing here, which shifted `cuda` into its slot.
            grads, _ = self.simple_gradient_interpreter.sst_interpret_from_instances(
                new_instances,
                self.embedding_op,
                self.normalization,
                self.normalization2,
                self.softmax,
                self.cuda,
                higher_order_grad=False
            )
            # calculate attribution of stop tokens in all sentences
            # of the batch
            stop_ids = [
                get_stop_ids(instance, self.stop_words)
                for instance in new_instances
            ]

            for j, grad in enumerate(grads):
                total_stop_word_attribution += torch.sum(grad[stop_ids[j]]).detach()
                total_first_token_attribution += torch.sum(grad[0]).detach()
                total_first_token_grad_rank += compute_rank(grad, {0})[0]

            # NOTE(review): fed with probabilities although the criterion
            # applies softmax itself -- confirm intended.
            total_ent += self.criterion(outputs['probs'])

        num_dev = len(self.dev_data)
        avg_entropy = total_ent / num_dev
        avg_stop_word_attribution = total_stop_word_attribution / num_dev
        avg_first_token_attribution = total_first_token_attribution / num_dev
        avg_first_token_grad_rank = total_first_token_grad_rank / num_dev

        self._append_metric(self.entropy_dev_file_name, iter, avg_entropy)
        self._append_metric(self.stop_word_attribution_dev_file_name, iter, avg_stop_word_attribution)
        self._append_metric(self.first_token_attribution_dev_file_name, iter, avg_first_token_attribution)
        # NOTE(review): get_metrics() is read without a reset, so this value
        # presumably also reflects forward passes made outside the dev loop
        # above -- confirm against the model's metric implementation.
        self._append_metric(self.acc_file_name, iter, self.model.get_metrics()['accuracy'])

        if iter == 0:
            # No training batch has been processed yet; nothing else to log.
            return

        per_batch_metrics = (
            (self.entropy_loss_file_name, entropy_loss),
            (self.first_token_grad_loss_file_name, first_token_grad_loss),
            (self.stop_word_grad_loss_file_name, stop_word_grad_loss),
            (self.first_token_grad_rank_file_name, first_token_rank),
            (self.stop_word_grad_rank_file_name, stop_word_rank),
            (self.avg_first_token_grad_rank_dev_file_name, avg_first_token_grad_rank),
            (self.grad_file_name, gradients),
            (self.first_token_total_loss_file_name, first_token_total_loss),
            (self.stop_word_total_loss_file_name, stop_word_total_loss),
            (self.output_probs_file_name, output_probs),
            (self.output_logits_file_name, output_logits),
            (self.raw_grads_file_name, raw_gradients),
        )
        for file_name, value in per_batch_metrics:
            self._append_metric(file_name, iter, value)

def save_model_details(model, vocab, exp_num):
    """Persist a model's weights and its vocabulary for one experiment.

    Everything is written below ``sst_baseline_models/experiment_<exp_num>/``
    (relative to the current working directory): the model's ``state_dict``
    goes to ``model.th`` and the vocabulary is saved into ``vocab``.
    Directories that do not exist yet are created and announced on stdout.

    Args:
        model: object whose ``state_dict()`` is serialized with ``torch.save``.
        vocab: object whose ``save_to_files(path)`` is called.
        exp_num: experiment number used in the directory name.
    """
    root_dir = "sst_baseline_models"
    target_dir = os.path.join(root_dir, "experiment_{}".format(exp_num))

    # Create the parent first, then the experiment directory, announcing
    # each one that is actually created.
    for directory in (root_dir, target_dir):
        if os.path.exists(directory):
            continue
        print("Creating directory with name:", directory)
        os.makedirs(directory)

    weights_path = os.path.join(target_dir, "model.th")
    with open(weights_path, 'wb') as weights_file:
        torch.save(model.state_dict(), weights_file)

    vocab_path = os.path.join(target_dir, "vocab")
    vocab.save_to_files(vocab_path)

def main():
    """Build the SST data pipeline and model, and optionally train the baseline.

    Command-line arguments are parsed and printed, the SST train/dev files
    are read, a vocabulary is built from the training instances and the
    model is constructed.  When ``--train`` is set, the model is trained with
    Adam (learning rate ``2e-5`` for ``BERT``, ``1e-3`` otherwise) for up to
    8 epochs with a patience of 1, then saved together with the vocabulary.
    Without ``--train`` the function returns after building the model.
    """
    args = argument_parsing()
    print(args)

    # Read the SST train and dev splits (subtrees included).
    sst_reader = get_sst_reader(args.model_name, use_subtrees=True)
    train_data = sst_reader.read('https://s3-us-west-2.amazonaws.com/allennlp/datasets/sst/train.txt')
    dev_data = sst_reader.read('https://s3-us-west-2.amazonaws.com/allennlp/datasets/sst/dev.txt')

    # The vocabulary comes from the training split only; both splits are
    # indexed with it.
    vocab = Vocabulary.from_instances(train_data)
    for dataset in (train_data, dev_data):
        dataset.index_with(vocab)

    # One bucketed loader per split, batches of 32 sorted by token count.
    loaders = []
    for dataset in (train_data, dev_data):
        bucket_sampler = BucketBatchSampler(dataset, batch_size=32, sorting_keys=["tokens"])
        loaders.append(DataLoader(dataset, batch_sampler=bucket_sampler))
    train_data_loader, dev_data_loader = loaders

    model = get_model(args.model_name, vocab, args.cuda)

    if not args.train:
        return

    if args.model_name == 'BERT':
        learning_rate = 2e-5
    else:
        learning_rate = 1e-3
    device_index = 0 if args.cuda else -1

    trainer = GradientDescentTrainer(
        model=model,
        optimizer=optim.Adam(model.parameters(), lr=learning_rate),
        data_loader=train_data_loader,
        validation_data_loader=dev_data_loader,
        num_epochs=8,
        patience=1,
        cuda_device=device_index,
    )
    trainer.train()

    save_model_details(model, vocab, args.exp_num)
    
def argument_parsing():
    """Parse the experiment's command-line options from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with ``model_name``, ``batch_size``,
        ``learning_rate``, ``lmbda``, ``exp_num``, ``loss``, ``embedding_op``,
        ``normalization``, ``normalization2`` (each ``None`` when omitted) and
        the booleans ``softmax``, ``cuda`` and ``train`` (each ``False`` when
        neither the ``--x`` nor the ``--no-x`` flag is given).
    """
    parser = argparse.ArgumentParser(description='One argparser')

    # Options that take a value: (flag, add_argument keyword arguments).
    valued_options = (
        ('--model_name', dict(type=str, choices=['CNN', 'LSTM', 'BERT'], help='Which model to use')),
        ('--batch_size', dict(type=int, help='Batch size')),
        ('--learning_rate', dict(type=float, help='Learning rate')),
        ('--lmbda', dict(type=float, help='Lambda of regularized loss')),
        ('--exp_num', dict(type=int, help='The experiment number')),
        ('--loss', dict(type=str, help='Loss function')),
        ('--embedding_op', dict(type=str, choices=['dot', 'l2'], help='Dot product or l2 norm')),
        ('--normalization', dict(type=str, choices=['l1', 'l2', 'none'], help='L1 norm or l2 norm')),
        ('--normalization2', dict(type=str, choices=['l1', 'l2', 'none'], help='L1 norm or l2 norm')),
    )
    for flag, options in valued_options:
        parser.add_argument(flag, **options)

    # Boolean switches, each registered as a --x / --no-x pair sharing one
    # destination: (dest, help for --x, help for --no-x).
    switches = (
        ('softmax', 'Use softmax', 'No softmax'),
        ('cuda', 'Cuda enabled', 'Cuda disabled'),
        ('train', 'Baseline will be trained', 'Baseline will not be trained'),
    )
    for dest, enable_help, disable_help in switches:
        parser.add_argument('--' + dest, dest=dest, action='store_true', help=enable_help)
        parser.add_argument('--no-' + dest, dest=dest, action='store_false', help=disable_help)

    return parser.parse_args()

# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
