import math
import matplotlib.pyplot as plt
import os
import seaborn as sns
import numpy
from sklearn.manifold import TSNE
import os
import random
import pandas as pd
import fitlog
import numpy as np

from tqdm import tqdm
from transformers import (XLMRobertaTokenizer, BertTokenizer)
from argparse import ArgumentParser

from utils.preprocess.Sampler import Sampler
from metric import Evaluator
import torch

from utils.tool import Batch
from models.contrastive_learning import ContrastiveLearning as CLModel
from models.base_model import BaseModel as BaseModel

class Visualizer:
    """Export model representations as 2-D t-SNE coordinates and plot them.

    Behaviour is driven by ``args``:

    * ``args.draw`` false: ``__init__`` loads the data and the model, and
      :meth:`visual` writes a CSV of t-SNE coordinates into
      ``args.vision_dir``.
    * ``args.draw`` true: no data or model is loaded; only :meth:`draw`,
      which renders an already exported CSV, is usable.
    """

    def __init__(self, args):
        """Seed the RNGs, build tokenizer and sampler and, unless drawing, the model.

        Args:
            args: Parsed command-line namespace; the fields read here are
                ``seed``, ``lang``, ``saved_model_name``, ``base_model``,
                ``model_arch``, ``batch_size``, ``lr``, ``max_seq_length``,
                ``use_cosda``, ``cosda_rate``, ``draw``, ``data_dir``,
                ``epochs`` and ``contrastive_learning``.
        """
        self.args = args
        seed = self.args.seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        # NOTE(review): the trailing space inside this name looks accidental.
        # It is kept byte-for-byte because the attribute is not used in this
        # class and other code may rely on the exact string -- confirm.
        self.saved_model_name = f"{self.args.lang}_{self.args.saved_model_name}_{self.args.base_model}_{self.args.model_arch}_batch{self.args.batch_size}_lr{self.args.lr}_seqlen{self.args.max_seq_length}_seed{self.args.seed} "
        self.saved_encoder_name = f"en_{self.args.saved_model_name}_{self.args.model_arch}_batch{self.args.batch_size}_lr{self.args.lr}_seqlen{self.args.max_seq_length}_seed{self.args.seed}"
        self.max_seq_length = self.args.max_seq_length
        if self.args.use_cosda:
            self.saved_model_name += f"_cosda{self.args.cosda_rate}"
            self.saved_encoder_name += f"_cosda{self.args.cosda_rate}"

        # Languages handed to the sampler; empty for any other ``args.lang``.
        cosda_lang = []
        if self.args.lang == "en":
            cosda_lang = ["de", "es", "fr", "hi", "ja", "pt", "tr", "zh"]
        elif self.args.lang == "zh":
            cosda_lang = ["hi", "ja", "tr", "zh"]

        base_tokenizers = {
            'bert': BertTokenizer,
            'XLMRoberta': XLMRobertaTokenizer,
            'MLP': BertTokenizer,
            'BiLSTM': BertTokenizer,
        }
        base = {
            'bert': 'bert-base-multilingual-uncased',
            'XLMRoberta': 'sentence-transformers/paraphrase-xlm-r-multilingual-v1',
            'MLP': 'bert-base-multilingual-uncased',
            'BiLSTM': 'bert-base-multilingual-uncased',
        }
        tokenizer = base_tokenizers[self.args.base_model].from_pretrained(base[self.args.base_model])
        # Kept on the instance so the word-level path tokenizes with the same
        # vocabulary as the selected base model.
        self.tokenizer = tokenizer
        self.train_sampler = Sampler(args, cosda_lang, tokenizer)
        self.languages = ["EN", "DE", "ES", "FR", "HI", "JA", "PT", "TR", "ZH"]

        if self.args.draw:
            # Drawing only reads an existing CSV; no data or model needed.
            return

        # =========train dataset===========
        # The English split is loaded before the model is built because the
        # model sizes below are read from the sampler afterwards.
        self.train_sampler.load_data(self.args.data_dir, "EN", "train", shuffle=True)
        self.train_features = self.train_sampler.convert_examples_to_features()

        self.train_steps = self.args.epochs * len(self.train_features)
        model_cls = CLModel if self.args.contrastive_learning else BaseModel
        self.model = model_cls(args=self.args,
                               num_slots=len(self.train_sampler.slot2idx),
                               num_intents=len(self.train_sampler.intents))

        self.load_model()

        # Features of the "train" split for every language, keyed by language code.
        self.train_features_dic = {}
        for lang in self.languages:
            self.train_sampler.load_data(self.args.data_dir, lang, "train")
            self.train_features_dic[lang] = self.train_sampler.convert_examples_to_features()
        self.idx2slot = {v: k for k, v in self.train_sampler.slot2idx.items()}

    def load_model(self):
        """Load checkpoint weights into ``self.model`` when ``args.load_weights`` is set.

        The checkpoint path is ``args.saved_model_dir`` concatenated with
        ``args.restore_from`` if given, otherwise with ``args.load_model_name``.
        The file must hold a mapping with a ``'state_dict'`` entry; it is
        applied with ``strict=False``, so mismatching keys are tolerated.
        """
        if not self.args.load_weights:
            return

        if self.args.restore_from is None:
            checkpoint_path = self.args.saved_model_dir + self.args.load_model_name
        else:
            checkpoint_path = self.args.saved_model_dir + self.args.restore_from

        # Deserialize onto the CPU so a checkpoint written on a GPU machine
        # also loads on a CPU-only host; load_state_dict copies the values
        # into the model's existing parameters.
        model_CKPT = torch.load(checkpoint_path, map_location="cpu")
        self.model.load_state_dict(model_CKPT['state_dict'], strict=False)

    def _labels_for(self, color, raw_labels=False):
        """Translate colour ids into the values of the CSV ``label`` column.

        Args:
            color: One entry per sample, either a bare id or a list/tuple
                whose first element is the id.
            raw_labels: When true, return the ids untouched.

        Returns:
            A list with language codes (``use_color == "language"``), entries
            of ``train_sampler.intents`` (``use_color == "intent"``) or the
            raw ids (any other setting, or ``raw_labels``).
        """
        ids = [entry[0] if isinstance(entry, (list, tuple)) else entry for entry in color]
        if raw_labels:
            return ids
        if self.args.use_color == "language":
            return [self.languages[i] for i in ids]
        if self.args.use_color == "intent":
            return [self.train_sampler.intents[i] for i in ids]
        return ids

    def visualize(self, x, color, name, raw_labels=False):
        """Project ``x`` to 2-D with t-SNE and write the points to a CSV file.

        The file is ``<name><args.saved_model_name>_<args.tSNE_lr>.csv`` inside
        ``args.vision_dir`` with the columns ``x``, ``y`` and ``label``.

        Args:
            x: Tensor with one row per sample.
            color: Colour id per sample (see :meth:`_labels_for`).
            name: Prefix of the CSV file name.
            raw_labels: Store the ids themselves as labels instead of mapping
                them through ``args.use_color``.
        """
        x = x.detach().cpu().numpy()

        print('t-SNE start')
        tsne = TSNE(n_components=2, init='pca', learning_rate=self.args.tSNE_lr).fit_transform(x)
        print('t-SNE finished')

        label = self._labels_for(color, raw_labels)
        dataframe = pd.DataFrame({'x': tsne[:, 0], 'y': tsne[:, 1], 'label': label})
        csv_name = name + str(self.args.saved_model_name) + "_" + str(self.args.tSNE_lr) + ".csv"
        dataframe.to_csv(os.path.join(self.args.vision_dir, csv_name), index=False, sep=',')

    def draw(self, name):
        """Render ``<name>.csv`` from ``args.vision_dir`` as a scatter plot.

        The figure is saved next to the CSV as ``<name>.pdf``, shown, and
        closed.

        Args:
            name: Path of the CSV relative to ``args.vision_dir``, without
                the ``.csv`` extension.
        """
        print("drawing...")
        plt.figure(figsize=(6, 6.5), dpi=512)
        plt.axis([-100, 100, -100, 100])
        data = pd.read_csv(os.path.join(self.args.vision_dir, name + ".csv"))
        # size="y" with sizes=(10, 10) gives every marker the same size.
        g = sns.scatterplot(x="x", y="y", alpha=0.5, size="y", data=data, sizes=(10, 10), hue='label')
        # NOTE(review): the style is set after the axes already exist, so it
        # may not affect this figure; order kept to preserve the output.
        sns.set(style='whitegrid', )
        leg = g.legend()
        leg.set_bbox_to_anchor([1, 0.5])  # anchor point of the legend box
        leg._loc = 2  # legend location code 2 ("upper left")

        plt.subplots_adjust(top=0.7, right=0.7, bottom=0.1)

        print('saving...')
        plt.savefig(os.path.join(self.args.vision_dir, name + ".pdf"))
        plt.show()
        plt.close()
        print('finished')

    def _select_representation(self, out):
        """Pick the tensor to visualise from a model output per ``args.use_metric``.

        ``"intent"`` -> ``out.intent_logits``; ``"cls"`` -> ``out.cls``;
        ``"slot"`` -> ``out.slot_logits`` averaged over dim 1 and flattened
        per sample; anything else -> intent logits concatenated with that
        slot summary along dim 1.
        """
        metric = self.args.use_metric
        if metric == "intent":
            return out.intent_logits
        if metric == "cls":
            return out.cls
        slot = torch.mean(out.slot_logits, dim=1).reshape(out.slot_logits.shape[0], -1)
        if metric == "slot":
            return slot
        return torch.cat((out.intent_logits, slot), dim=1)

    def evaluate(self, lang, index=0):
        """Run the model over one language and collect representations.

        Args:
            lang: Key into ``self.train_features_dic``.
            index: Colour id used for this language. It is only advanced
                here (once per sample) when ``args.use_color`` is neither
                ``"intent"`` nor ``"language"``.

        Returns:
            ``(index, color, output)``: the possibly advanced index, one
            single-element list per sample holding its colour id, and the
            representations of all batches concatenated along dim 0
            (``None`` when there were no batches).
        """
        self.model.eval()
        color = []
        chunks = []
        with torch.no_grad():
            iterator = tqdm(Batch.to_list(self.train_features_dic[lang], self.args.batch_size))
            for batch in iterator:
                for data in batch:
                    if self.args.use_color == "intent":
                        color.append([data['intent_id']])
                    elif self.args.use_color == "language":
                        color.append([index])
                    else:
                        color.append([index])
                        index = index + 1
                out = self.model.forward(batch, evaluate=True)
                chunks.append(self._select_representation(out))
        # Concatenate once at the end instead of re-copying the growing
        # tensor after every batch.
        output = torch.cat(chunks, dim=0) if chunks else None
        return index, color, output

    def visual(self):
        """Compute representations for ``args.method`` and export them via t-SNE.

        * ``'word'``: takes entries of ``word_dict['de']`` that occur as keys
          in every dictionary returned by ``train_sampler.get_dict()``
          (stopping once more than 50 were found), embeds each such word and
          the first listed entry for it in every dictionary, and labels all
          of them with the word's position in ``word_dict['de']``.
        * ``'sentence'``: runs :meth:`evaluate` for every language and
          exports the concatenated result.

        Any other method value does nothing.
        """
        if self.args.method == 'word':
            word_dict = self.train_sampler.get_dict()
            batch = []
            color = []
            matched = 0
            for index, word in enumerate(word_dict['de']):
                if matched > 50:
                    break
                if not all(word in word_dict[l] for l in word_dict):
                    continue
                matched += 1
                batch.append([word])
                color.append(index)
                for l in word_dict:
                    batch.append([word_dict[l][word][0]])
                    color.append(index)

            features = []
            for word in batch:
                pad = self.convert_tokens_to_ids_with_padding(word)
                features.append({'original': pad, 'positive': pad, 'intent_id': [0]})

            # Same inference set-up as the sentence path: evaluation mode and
            # no gradient tracking.
            self.model.eval()
            with torch.no_grad():
                out = self.model.forward(features, evaluate=True)
            # Position 1 is the first token after the leading special token.
            embed = out.embedded[:, 1, :].reshape(len(batch), -1)
            # The colour ids are word positions, not intent/language ids, so
            # they are stored unmapped.
            self.visualize(embed, color, 'word_', raw_labels=True)
        elif self.args.method == 'sentence':
            embedded = None
            index = 0
            c = []
            for lang in self.languages:
                index, color, output = self.evaluate(lang, index)
                index = index + 1
                if output is None:
                    # Language without any batch: nothing to append.
                    continue
                c.extend(color)
                if embedded is None:
                    embedded = output
                else:
                    embedded = torch.cat((embedded, output), dim=0)
            self.visualize(embedded, c, 'sentence_')

    def convert_tokens_to_ids_with_padding(self, words):
        """Turn a list of words into fixed-length model inputs.

        Each word is looked up in the tokenizer vocabulary as a single token.
        NOTE(review): words are not split into word pieces first -- confirm
        this is intended.

        Args:
            words: Iterable of word strings.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``token_type_ids``
            and ``label_ids``, each a list of length ``self.max_seq_length``.
            ``label_ids`` holds 0 at word positions and -100 everywhere else.

        Raises:
            AssertionError: if any produced list is not exactly
                ``self.max_seq_length`` long.
        """
        pad_label_id = -100
        is_xlmr = self.args.base_model == 'XLMRoberta'

        # Leading special token for the model family.
        input_ids = []
        if self.args.base_model in ('bert', 'MLP', 'BiLSTM'):
            input_ids.extend(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize('[CLS]')))
        elif is_xlmr:
            input_ids.extend(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize('<s>')))
        label_ids = [pad_label_id]

        for word in words:
            piece_ids = self.tokenizer.convert_tokens_to_ids([word])
            input_ids.extend(piece_ids)
            # First id of a word is labelled 0, any further ids are ignored.
            label_ids.extend([0] + [pad_label_id] * (len(piece_ids) - 1))

        # Truncate so the sequence (plus the closing token for XLM-R) fits.
        special_tokens_count = 2 if is_xlmr else 1
        limit = self.max_seq_length - special_tokens_count
        if len(input_ids) > limit:
            input_ids = input_ids[:limit]
            label_ids = label_ids[:limit]
        if is_xlmr:
            input_ids.extend(self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize('</s>')))
            label_ids = label_ids + [pad_label_id]

        segment_ids = [0] * len(input_ids)
        input_mask = [1] * len(input_ids)

        # Right-pad everything up to max_seq_length.
        padding_length = self.max_seq_length - len(input_ids)
        input_ids += [self.tokenizer.pad_token_id] * padding_length
        input_mask += [0] * padding_length
        segment_ids += [self.tokenizer.pad_token_type_id] * padding_length
        label_ids += [pad_label_id] * padding_length

        assert len(input_ids) == self.max_seq_length and len(input_mask) == self.max_seq_length and len(
            segment_ids) == self.max_seq_length and len(label_ids) == self.max_seq_length
        return {'input_ids': input_ids, 'attention_mask': input_mask, 'token_type_ids': segment_ids,
                'label_ids': label_ids}


if __name__ == "__main__":
    # Script entry point: parse the options, then either render an existing
    # CSV (--draw true) or compute representations and export a new one.
    fitlog.set_log_dir('./logs/')


    def str2bool(v):
        """Parse a command-line boolean.

        Returns True for 'yes', 'true', 't', 'y' or '1' (case-insensitive)
        and False for every other string, including unrecognised ones.
        """
        return v.lower() in ('yes', 'true', 't', 'y', '1')


    parser = ArgumentParser()
    parser.add_argument('--seed', type=int, default=11111)
    parser.add_argument('--draw', type=str2bool, default=False)
    # CSV (relative to --vision_dir, without ".csv") rendered when --draw is true.
    parser.add_argument('--draw_name', type=str, default='dev/sentence_cl_cls_600.0')
    parser.add_argument('--method', type=str, default='word')
    parser.add_argument('--use_metric', type=str, default='intent')
    parser.add_argument('--use_color', type=str, default='intent')
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--scale', type=float, default=30)

    parser.add_argument('--show_interval', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--negative_num', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1.8e-6)
    parser.add_argument('--tSNE_lr', type=float, default=1000)
    parser.add_argument('--slot_loss_coef', type=float, default=1)
    parser.add_argument('--cls_lr', type=float, default=None)
    parser.add_argument('--model_arch', type=str, default="joint_model",
                        choices=["joint_model", "stack_propagationv1", "stack_propagationv2", "stack_propagationbo"])
    parser.add_argument('--dropout_prob', type=float, default=0.1)
    parser.add_argument('--early_stop', type=str2bool, default=False)
    parser.add_argument('--patience', type=int, default=3)

    parser.add_argument('--saved_model_dir', type=str, default="saved_model/")
    parser.add_argument('--data_dir', type=str, default="./MultiATIS_toy.v0.1/data/ldc")  # TODO: default points at the toy data set
    parser.add_argument('--load_model_name', type=str, default="")
    parser.add_argument('--load_weights', type=str2bool, default=False)
    parser.add_argument('--save_encoder_weights', type=str2bool, default=False)
    parser.add_argument('--load_encoder_weights', type=str2bool, default=False)
    parser.add_argument('--load_encoder_dir', type=str, default=None)
    # type=bool would turn every non-empty string (even "False") into True,
    # so this flag is parsed like the other boolean options.
    parser.add_argument('--train', type=str2bool, default=True)
    parser.add_argument('--use_cosda', type=str2bool, default=False)
    parser.add_argument('--cosda_rate', type=float, default=0.5)
    parser.add_argument('--max_seq_length', type=int, default=128)
    parser.add_argument('--sample_seed', type=int, default=128)
    parser.add_argument('--show_bar', type=str2bool, default=True)
    parser.add_argument('--restore_from', type=str, default=None)
    parser.add_argument('--saved_model_name', type=str, default='hwnlu_')
    parser.add_argument('--lang', type=str, default="en")
    parser.add_argument('--base_model', type=str, default="MLP")
    parser.add_argument('--word_dict_dir', type=str, default="./MUSE_dict/")

    parser.add_argument('--contrastive_learning', type=str2bool, default=False)
    parser.add_argument('--temperature', type=float, default=1)
    parser.add_argument('--lambda1', type=float, default=1)
    parser.add_argument('--lambda2', type=float, default=1)
    parser.add_argument('--lambda3', type=float, default=1)
    parser.add_argument('--gpu', type=str2bool, default=True)
    parser.add_argument('--output_dir', type=str, default="./out/out.log")
    parser.add_argument('--vision_dir', type=str, default="./out")

    args = parser.parse_args()
    fitlog.add_hyper(args)
    print(f"args: {args}")
    v = Visualizer(args)
    if args.draw:
        v.draw(args.draw_name)
    else:
        v.visual()
