import sys
import os
import argparse
import random 

import torch
import torch.optim as optim
import matplotlib.pyplot as plt
import nltk
from nltk.corpus import stopwords
from typing import Tuple, Dict, List, Any

from allennlp.data.dataset_readers.stanford_sentiment_tree_bank import \
    StanfordSentimentTreeBankDatasetReader
from allennlp.data import Instance, DataLoader
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model, BasicClassifier
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper, CnnEncoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders.embedding import _read_pretrained_embeddings_file
from allennlp.modules.token_embedders import Embedding, PretrainedTransformerEmbedder
from allennlp.nn.util import get_text_field_mask
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.training.trainer import Trainer
from allennlp.common.util import lazy_groups_of
from allennlp.data.token_indexers import SingleIdTokenIndexer, PretrainedTransformerMismatchedIndexer
from allennlp.nn.util import move_to_device
from allennlp.interpret.saliency_interpreters import SaliencyInterpreter, SimpleGradient
from allennlp.predictors import Predictor
from allennlp.data.batch import Batch
from allennlp.data.samplers import BucketBatchSampler
import torch.nn.functional as F
from allennlp.data.tokenizers import PretrainedTransformerTokenizer

from adversarial_grads.util.misc import compute_rank, get_stop_ids, create_labeled_instances
from adversarial_grads.util.model_data_helpers import get_model, get_sst_reader

# Kind of word embeddings to use.
# NOTE(review): not referenced anywhere else in this file; model construction is
# delegated to get_model(), so confirm whether this constant is still needed.
EMBEDDING_TYPE: str = "glove"

class HLoss(torch.nn.Module):
    """Summed ``p * log p`` over a batch, with ``p = softmax(x, dim=1)``.

    For each row of ``x`` this is the *negative* Shannon entropy of the softmax
    distribution along dim 1; the result is summed over every row and class,
    giving a 0-dim tensor. Minimizing it therefore increases entropy.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``sum(softmax(x, 1) * log_softmax(x, 1))`` as a 0-dim tensor."""
        probs = F.softmax(x, dim=1)
        log_probs = F.log_softmax(x, dim=1)
        elementwise = probs * log_probs
        return elementwise.sum()

class SST_Attacker:
    """Fine-tune an SST sentiment model so that its gradient saliency is manipulated.

    The training objective is a loss on the *input gradients* returned by
    ``SimpleGradient.sst_interpret_from_instances`` (requested with
    ``higher_order_grad=True`` during training so the loss can be
    back-propagated), plus ``lmbda`` times an ``HLoss`` term on the model's
    output probabilities. ``args.importance`` selects the target tokens:

    - ``'first_token'``: minimize ``-|grad[1]|`` (push attribution of token
      index 1 up).
    - ``'stop_token'``: minimize ``-sum(|grad[stop ids]|)`` (push stop-word
      attribution up).
    - ``'first_high_last_low'``: minimize ``-|grad[1]| + |grad[-2]|``.

    Dev-set metrics and training diagnostics are appended to text files in
    ``sst_attack_experiments/experiment_<exp_num>/``; model checkpoints are
    written to ``sst_attack_models/experiment_<exp_num>/``.
    """

    # Accepted values of ``args.importance`` (mirrors the CLI choices).
    _IMPORTANCE_OPTIONS = ('first_token', 'stop_token', 'first_high_last_low')
    # Batch size used when iterating the dev set in record_metrics().
    _DEV_BATCH_SIZE = 32
    _NUM_EPOCHS = 40
    # Dev metrics are recorded every this many optimizer updates (global count).
    _METRICS_EVERY = 50
    # A numbered checkpoint is written every this many batches within an epoch.
    _CHECKPOINT_EVERY = 200
    _MODEL_DIR = "sst_attack_models"

    def __init__(self, model, reader, train_data, dev_data, vocab, args):
        """Set up the predictor, interpreter, optimizer and experiment log files.

        Args:
            model: Model to fine-tune; all of its parameters are optimized with Adam.
            reader: Dataset reader handed to the ``text_classifier`` predictor.
            train_data: Training dataset; ``train_data.instances`` is split into
                batches of ``args.batch_size``.
            dev_data: Dev dataset evaluated by :meth:`record_metrics`.
            vocab: Vocabulary used to index every batch.
            args: Parsed CLI namespace (model_name, batch_size, learning_rate,
                lmbda, exp_num, loss, embedding_op, normalization,
                normalization2, cuda, importance).

        Raises:
            ValueError: if ``args.importance`` is not one of ``_IMPORTANCE_OPTIONS``.

        Side effects: creates ``sst_attack_experiments/experiment_<exp_num>/``,
        deletes log files a previous run left there, and writes ``exp.txt``.
        """
        # Without a valid importance mode attack() would only fail mid-run on an
        # unbound variable; reject it before touching the filesystem.
        if args.importance not in self._IMPORTANCE_OPTIONS:
            raise ValueError(
                "args.importance must be one of {}, got {!r}".format(
                    self._IMPORTANCE_OPTIONS, args.importance
                )
            )

        self.model = model
        self.reader = reader
        self.args = args
        self.predictor = Predictor.by_name('text_classifier')(self.model, self.reader)
        self.simple_gradient_interpreter = SimpleGradient(self.predictor)

        # Setup training instances
        self.train_data = train_data
        self.model_name = args.model_name
        self.batch_size = args.batch_size
        self.batched_training_instances = [
            train_data.instances[i:i + self.batch_size]
            for i in range(0, len(train_data), self.batch_size)
        ]
        self.batched_dev_instances = [
            dev_data.instances[i:i + self._DEV_BATCH_SIZE]
            for i in range(0, len(dev_data), self._DEV_BATCH_SIZE)
        ]
        self.dev_data = dev_data
        self.vocab = vocab
        self.loss = args.loss
        self.embedding_op = args.embedding_op
        self.normalization = args.normalization
        self.normalization2 = args.normalization2
        self.learning_rate = args.learning_rate
        self.lmbda = args.lmbda
        self.cuda = args.cuda
        self.importance = args.importance
        self.criterion = HLoss()
        self.exp_num = args.exp_num
        self.stop_words = set(stopwords.words('english'))

        if self.loss == "MSE":
            self.loss_function = torch.nn.MSELoss()

        self.optimizer = optim.Adam(model.parameters(), lr=self.learning_rate)

        exp_desc = """This experiment (number #{}) used the following hyper parameters:
        - Model: {}
        - Batch size: {}
        - Learning rate: {}
        - Lambda: {}
        - Loss function: {}
        - Embedding Operator: {}
        - Normalization: {}
        - Normalization2: {}
        - Cuda enabled: {}
        - Importance: {}
    """.format(
            self.exp_num,
            self.model_name,
            self.batch_size,
            self.learning_rate,
            self.lmbda, self.loss,
            self.embedding_op,
            self.normalization,
            self.normalization2,
            self.cuda,
            self.importance
        )

        outdir = "sst_attack_experiments"

        if not os.path.exists(outdir):
            print("Creating directory with name:", outdir)
            os.mkdir(outdir)

        exp_dir = os.path.join(outdir, "experiment_{}".format(self.exp_num))
        if not os.path.exists(exp_dir):
            print("Creating directory with name:", exp_dir)
            os.makedirs(exp_dir)

        # contains info about the hyper parameters for this experiment
        self.exp_file_name = os.path.join(exp_dir, "exp.txt")
        # normalized gradients vs. number of updates
        self.grad_file_name = os.path.join(exp_dir, "grad.txt")
        # target-token gradient rank vs. number of updates
        self.grad_rank_file_name = os.path.join(exp_dir, "grad_rank.txt")

        # first token attribution on the dev set vs. number of updates
        self.first_token_attribution_dev_file_name = os.path.join(exp_dir, "first_token_attribution_dev.txt")
        # avg first token gradient rank on the dev set vs. number of updates
        self.avg_first_token_grad_rank_dev_file_name = os.path.join(exp_dir, "avg_first_token_grad_rank_dev.txt")
        # avg first token grad value vs. number of updates
        self.avg_first_token_grad_value_dev_file_name = os.path.join(exp_dir, "avg_first_token_grad_value_dev.txt")

        # last token attribution on the dev set vs. number of updates
        self.last_token_attribution_dev_file_name = os.path.join(exp_dir, "last_token_attribution_dev.txt")
        # avg last token gradient rank on the dev set vs. number of updates
        self.avg_last_token_grad_rank_dev_file_name = os.path.join(exp_dir, "avg_last_token_grad_rank_dev.txt")
        # avg last token grad value vs. number of updates
        self.avg_last_token_grad_value_dev_file_name = os.path.join(exp_dir, "avg_last_token_grad_value_dev.txt")

        # HLoss value on the dev set vs. number of updates
        self.entropy_dev_file_name = os.path.join(exp_dir, "entropy_dev.txt")
        # entropy loss vs. number of updates
        self.entropy_loss_file_name = os.path.join(exp_dir, "entropy_loss.txt")
        # mean per-example gradient loss vs. number of updates
        self.grad_loss_file_name = os.path.join(exp_dir, "grad_loss.txt")
        # total loss vs. number of updates
        self.total_loss_file_name = os.path.join(exp_dir, "total_loss.txt")
        # output probs vs. number of updates
        self.output_probs_file_name = os.path.join(exp_dir, "output_probs.txt")
        # output logits vs. number of updates
        self.output_logits_file_name = os.path.join(exp_dir, "output_logits.txt")
        # raw gradients (embedding dimension reduced) vs. number of updates
        self.raw_grads_file_name = os.path.join(exp_dir, "raw_gradients.txt")
        # stopword attribution on the dev set vs. number of updates
        self.stop_word_attribution_dev_file_name = os.path.join(exp_dir, "stop_word_attribution_dev.txt")

        # Remove files left by a previous run of the same experiment number;
        # every writer below opens in append mode.
        files = [
            self.exp_file_name,
            self.grad_file_name,
            self.grad_rank_file_name,
            self.first_token_attribution_dev_file_name,
            self.avg_first_token_grad_rank_dev_file_name,
            self.avg_first_token_grad_value_dev_file_name,
            self.last_token_attribution_dev_file_name,
            self.avg_last_token_grad_rank_dev_file_name,
            self.avg_last_token_grad_value_dev_file_name,
            self.entropy_dev_file_name,
            self.entropy_loss_file_name,
            self.grad_loss_file_name,
            self.total_loss_file_name,
            self.output_probs_file_name,
            self.output_logits_file_name,
            self.raw_grads_file_name,
            self.stop_word_attribution_dev_file_name
        ]

        for f in files:
            if os.path.exists(f):
                os.remove(f)

        with open(self.exp_file_name, "w") as f:
            f.write(exp_desc)

    def _checkpoint_dir(self):
        """Return ``sst_attack_models/experiment_<exp_num>``, creating it if missing."""
        model_dir = self._MODEL_DIR
        if not os.path.exists(model_dir):
            print("Creating directory with name:", model_dir)
        exp_dir = os.path.join(model_dir, "experiment_{}".format(self.exp_num))
        if not os.path.exists(exp_dir):
            print("Creating directory with name:", exp_dir)
        os.makedirs(exp_dir, exist_ok=True)
        return exp_dir

    def _target_ids(self, instances):
        """Return one collection of targeted token positions per instance."""
        if self.importance == 'stop_token':
            return [get_stop_ids(instance, self.stop_words) for instance in instances]
        if self.importance == 'first_token':
            # The same, never-mutated set is shared by every instance.
            return [{1}] * len(instances)
        # 'first_high_last_low': position 1 and the second-to-last position.
        targets = []
        for instance in instances:
            print("last token", len(instance.fields['tokens']) - 2)
            targets.append({1, len(instance.fields['tokens']) - 2})
        return targets

    def _grad_loss(self, grad, target):
        """Per-example gradient loss for the configured importance mode.

        ``target`` is this example's entry from :meth:`_target_ids`; only the
        ``'stop_token'`` mode indexes ``grad`` with it.
        """
        if self.importance == 'first_token':
            grad_val = grad[1].unsqueeze(0)
            print("first token grad val 1", grad_val)
            return -1 * torch.abs(grad_val)
        if self.importance == 'stop_token':
            grad_val = torch.sum(torch.abs(grad[target])).unsqueeze(0)
            print("stop token grad val", grad_val)
            return -1 * torch.abs(grad_val)
        # 'first_high_last_low': reward |grad| at index 1, penalize it at index -2.
        first_grad = grad[1].unsqueeze(0)
        last_grad = grad[-2].unsqueeze(0)
        return -1 * torch.abs(first_grad) + torch.abs(last_grad)

    def attack(self):
        """Run the gradient-manipulation fine-tuning loop.

        Records dev metrics once before any update (label 0). Then, for
        ``_NUM_EPOCHS`` epochs over the (once-shuffled) training batches, each
        update:

        - computes the input gradients of the batch;
        - minimizes the mean of ``grad_loss + lmbda * entropy_loss``.

        Checkpointing and logging:

        - ``model.th`` (plus the vocabulary) is saved whenever the batch-mean
          gradient loss reaches a new minimum.
        - Dev metrics are recorded every ``_METRICS_EVERY`` updates.
        - ``model_iter<i>_epoch<e>.th`` is saved every ``_CHECKPOINT_EVERY``
          batches of an epoch.
        """
        self.model.train()

        self.record_metrics(0, None, None, None, None, None, None, None, None)
        # record_metrics() switches the model to eval mode; switch back.
        self.model.train()

        # Shuffled once, so every epoch visits the batches in the same order.
        random.shuffle(self.batched_training_instances)
        lowest_grad_loss = float("inf")
        # Global optimizer-update count, so metric labels do not restart per epoch.
        update_step = 0
        for epoch in range(self._NUM_EPOCHS):
            for i, training_instances in enumerate(self.batched_training_instances, 1):
                update_step += 1
                print("Iter #{}".format(i))
                if self.cuda:
                    print(torch.cuda.memory_summary(device=0, abbreviated=True))

                target_ids = self._target_ids(training_instances)

                data = Batch(training_instances)
                data.index_instances(self.vocab)
                model_input = data.as_tensor_dict()
                print("model input", model_input)
                model_input = move_to_device(model_input, cuda_device=0) if self.cuda else model_input
                outputs = self.model(**model_input)

                new_instances = create_labeled_instances(self.predictor, outputs, training_instances, self.cuda)

                # Normalize by the real batch length so the final, possibly
                # smaller, batch is not under-weighted.
                num_instances = len(training_instances)
                # NOTE(review): HLoss applies softmax internally, so feeding it
                # outputs['probs'] measures softmax(probs), not the model's own
                # distribution — confirm whether outputs['logits'] was intended.
                entropy_loss = (1 / num_instances) * self.criterion(outputs['probs'])
                print("entropy requires grad", entropy_loss.requires_grad)
                # higher_order_grad=True keeps the gradients differentiable so
                # the loss below can be back-propagated into the model.
                gradients, raw_gradients = self.simple_gradient_interpreter.sst_interpret_from_instances(
                    new_instances,
                    self.embedding_op,
                    self.normalization,
                    self.normalization2,
                    self.cuda,
                    higher_order_grad=True
                )

                loss = 0
                batch_rank = []
                grad_losses = []
                for idx, grad in enumerate(gradients):
                    grad_loss = self._grad_loss(grad, target_ids[idx])
                    grad_losses.append(grad_loss)
                    batch_rank.append(compute_rank(grad, set(target_ids[idx])))
                    loss += grad_loss + self.lmbda * entropy_loss
                loss /= num_instances

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Batch-mean gradient loss as a plain float: the original compared
                # only the last example's loss and kept a graph-bearing tensor alive.
                batch_grad_loss = sum(g.detach().sum().item() for g in grad_losses) / len(grad_losses)

                if batch_grad_loss < lowest_grad_loss:
                    print("saving model ...")
                    print("loss is", batch_grad_loss)
                    lowest_grad_loss = batch_grad_loss
                    exp_dir = self._checkpoint_dir()
                    with open(os.path.join(exp_dir, "model.th"), 'wb') as f:
                        torch.save(self.model.state_dict(), f)
                    self.vocab.save_to_files(os.path.join(exp_dir, "vocab"))

                if update_step % self._METRICS_EVERY == 0:
                    self.record_metrics(update_step, entropy_loss, batch_grad_loss, batch_rank, gradients, loss, outputs['probs'], outputs['logits'], raw_gradients)
                    self.model.train()

                if i % self._CHECKPOINT_EVERY == 0:
                    exp_dir = self._checkpoint_dir()
                    with open(os.path.join(exp_dir, "model_iter{}_epoch{}.th".format(i, epoch)), 'wb') as f:
                        torch.save(self.model.state_dict(), f)

    def record_metrics(
        self,
        iter: int,
        entropy_loss,
        grad_loss,
        rank: List[int],
        gradients,
        loss,
        output_probs,
        output_logits,
        raw_gradients
    ) -> None:
        """Evaluate attribution metrics on the dev set and append them to the log files.

        Args:
            iter: Label written as ``Iter #<iter>``. When it is 0, only the dev
                metrics are written and the remaining arguments are ignored.
            entropy_loss, grad_loss, rank, gradients, loss, output_probs,
            output_logits, raw_gradients: Training diagnostics from the current
                update. Each is appended verbatim to its own file.

        Dev metrics are the per-example averages over ``len(self.dev_data)``:

        - the HLoss value;
        - the stop-word attribution (``'stop_token'`` mode);
        - first/last-token attribution, gradient rank and raw gradient value
          (``'first_token'`` / ``'first_high_last_low'`` modes).

        Metrics for modes other than the active one stay 0.
        """
        # attack() calls this while training; switch to eval mode (the caller
        # switches back to train mode afterwards).
        self.model.eval()

        total_ent = 0

        total_stop_word_attribution = 0

        total_first_token_attribution = 0
        total_first_token_grad_rank = 0
        total_first_token_grad_value = 0

        total_last_token_attribution = 0
        total_last_token_grad_rank = 0
        total_last_token_grad_value = 0

        for i, batch in enumerate(self.batched_dev_instances):
            print(i)
            if self.cuda:
                print(torch.cuda.memory_summary(device=0, abbreviated=True))
            data = Batch(batch)
            data.index_instances(self.vocab)
            model_input = data.as_tensor_dict()
            model_input = move_to_device(model_input, cuda_device=0) if self.cuda else model_input
            with torch.no_grad():
                outputs = self.model(**model_input)

            new_instances = create_labeled_instances(self.predictor, outputs, batch, self.cuda)
            grads, raw_grads = self.simple_gradient_interpreter.sst_interpret_from_instances(
                new_instances,
                self.embedding_op,
                self.normalization,
                self.normalization2,
                self.cuda,
                higher_order_grad=False
            )

            if self.importance == 'stop_token':
                # attribution of the stop tokens in every sentence of the batch
                stop_ids = [get_stop_ids(instance, self.stop_words) for instance in new_instances]
                for j, grad in enumerate(grads):
                    total_stop_word_attribution += torch.sum(torch.abs(grad[stop_ids[j]])).detach()

            if self.importance in ('first_token', 'first_high_last_low'):
                for j, grad in enumerate(grads):
                    total_first_token_attribution += torch.abs(torch.sum(grad[1]).detach())
                    total_first_token_grad_rank += compute_rank(grad, {1})[0]
                    total_first_token_grad_value += torch.abs(raw_grads[j][1])

            if self.importance == 'first_high_last_low':
                for j, grad in enumerate(grads):
                    total_last_token_attribution += torch.abs(torch.sum(grad[-2]).detach())
                    total_last_token_grad_rank += compute_rank(grad, {len(grad) - 2})[0]
                    total_last_token_grad_value += torch.abs(raw_grads[j][-2])

            # NOTE(review): outputs['probs'] is passed through another softmax
            # inside HLoss — confirm this is intended.
            total_ent += self.criterion(outputs['probs'])

        num_dev = len(self.dev_data)
        avg_entropy = total_ent / num_dev

        avg_stop_word_attribution = total_stop_word_attribution / num_dev

        avg_first_token_attribution = total_first_token_attribution / num_dev
        avg_first_token_grad_rank = total_first_token_grad_rank / num_dev
        avg_first_token_grad_value = total_first_token_grad_value / num_dev

        avg_last_token_attribution = total_last_token_attribution / num_dev
        avg_last_token_grad_rank = total_last_token_grad_rank / num_dev
        avg_last_token_grad_value = total_last_token_grad_value / num_dev

        with open(self.entropy_dev_file_name, "a") as f:
            f.write("Iter #{}: {}\n".format(iter, avg_entropy))

        # Stop word files
        with open(self.stop_word_attribution_dev_file_name, "a") as f:
            f.write("Iter #{}: {}\n".format(iter, avg_stop_word_attribution))

        # First token files
        with open(self.first_token_attribution_dev_file_name, "a") as f:
            f.write("Iter #{}: {}\n".format(iter, avg_first_token_attribution))
        with open(self.avg_first_token_grad_rank_dev_file_name, "a") as f:
            f.write("Iter #{}: {}\n".format(iter, avg_first_token_grad_rank))
        with open(self.avg_first_token_grad_value_dev_file_name, "a") as f:
            f.write("Iter #{}: {}\n".format(iter, avg_first_token_grad_value))

        # Last token files
        with open(self.last_token_attribution_dev_file_name, "a") as f:
            f.write("Iter #{}: {}\n".format(iter, avg_last_token_attribution))
        with open(self.avg_last_token_grad_rank_dev_file_name, "a") as f:
            f.write("Iter #{}: {}\n".format(iter, avg_last_token_grad_rank))
        with open(self.avg_last_token_grad_value_dev_file_name, "a") as f:
            f.write("Iter #{}: {}\n".format(iter, avg_last_token_grad_value))

        if iter != 0:
            with open(self.entropy_loss_file_name, "a") as f:
                f.write("Iter #{}: {}\n".format(iter, entropy_loss))
            with open(self.grad_loss_file_name, "a") as f:
                f.write("Iter #{}: {}\n".format(iter, grad_loss))
            with open(self.grad_rank_file_name, "a") as f:
                f.write("Iter #{}: {}\n".format(iter, rank))
            with open(self.grad_file_name, "a") as f:
                f.write("Iter #{}: {}\n".format(iter, gradients))
            with open(self.total_loss_file_name, "a") as f:
                f.write("Iter #{}: {}\n".format(iter, loss))
            with open(self.output_probs_file_name, "a") as f:
                f.write("Iter #{}: {}\n".format(iter, output_probs))
            with open(self.output_logits_file_name, "a") as f:
                f.write("Iter #{}: {}\n".format(iter, output_logits))
            with open(self.raw_grads_file_name, "a") as f:
                f.write("Iter #{}: {}\n".format(iter, raw_gradients))

def main():
    """Script entry point: parse args, load SST train/dev, build the model, run the attack."""
    args = argument_parsing()
    print(args)

    reader = get_sst_reader(args.model_name)
    train_url = 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/sst/train.txt'
    dev_url = 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/sst/dev.txt'
    # Tuple elements evaluate left to right: train is read before dev.
    train_data, dev_data = reader.read(train_url), reader.read(dev_url)

    # The vocabulary is built from training data only, then used to index both splits.
    vocab = Vocabulary.from_instances(train_data)
    for split in (train_data, dev_data):
        split.index_with(vocab)

    model = get_model(args.model_name, vocab, args.cuda, transformer_dim=256)

    attacker = SST_Attacker(model, reader, train_data, dev_data, vocab, args)
    attacker.attack()
    
def argument_parsing(argv=None):
    """Parse the command-line options for the SST attack experiment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: via argparse, on unknown, malformed, or missing required options.
    """
    parser = argparse.ArgumentParser(description='One argparser')
    parser.add_argument('--model_name', type=str, choices=['CNN', 'LSTM', 'BERT'], help='Which model to use')
    # batch_size, learning_rate, lmbda and importance have no usable fallback:
    # leaving any of them at None makes the attacker fail later (range() step,
    # Adam lr, loss weighting, target-token selection), so fail fast here.
    parser.add_argument('--batch_size', type=int, required=True, help='Batch size')
    parser.add_argument('--learning_rate', type=float, required=True, help='Learning rate')
    parser.add_argument('--lmbda', type=float, required=True, help='Lambda of regularized loss')
    parser.add_argument('--exp_num', type=int, help='The experiment number')
    parser.add_argument('--loss', type=str, help='Loss function')
    parser.add_argument('--embedding_op', type=str, choices=['dot', 'l2'], help='Dot product or l2 norm')
    parser.add_argument('--normalization', type=str, choices=['l1', 'l2', 'none'], help='L1 norm or l2 norm')
    parser.add_argument('--normalization2', type=str, choices=['l1', 'l2', 'none'], help='L1 norm or l2 norm')
    parser.add_argument('--cuda', dest='cuda', action='store_true', help='Cuda enabled')
    parser.add_argument('--no-cuda', dest='cuda', action='store_false', help='Cuda disabled')
    # State the default explicitly instead of relying on which of the two
    # actions sharing dest='cuda' was registered first (store_true -> False).
    parser.set_defaults(cuda=False)
    parser.add_argument('--importance', type=str, required=True, choices=['first_token', 'stop_token', 'first_high_last_low'], help='Where the gradients should be high')
    return parser.parse_args(argv)

if __name__ == '__main__':
    # Run the experiment only when executed as a script, not when imported.
    main()
