import os
import warnings

# Silence Lightning's warning about inferring `batch_size` from an ambiguous
# collection -- presumably triggered by the batch dicts built in
# QADataModule.collate_fn, which mix tensors with plain Python lists.
warnings.filterwarnings(
    "ignore", ".*Trying to infer the `batch_size` from an ambiguous collection.*"
)
# Lower TensorFlow's C++ log verbosity; this has to happen before anything
# imports TensorFlow, hence its placement ahead of the remaining imports.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # or any {'0', '1', '2'}

import argparse
import glob
import os
import json
import time
import logging
import random
import re
import math
from itertools import chain
from string import punctuation
from tqdm.auto import tqdm
import pickle
import string
import pandas as pd
import numpy as np
import torch
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
from termcolor import colored
import textwrap
from torch.optim import Optimizer
from transformers_2 import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    T5Config,
    get_linear_schedule_with_warmup
)

from pytorch_lightning.plugins import DDPPlugin
import sys
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers.data.metrics.squad_metrics import compute_exact, compute_f1, normalize_answer
from statistics import mean


from pytorch_lightning.trainer.supporters import CombinedLoader
import torch.distributed as dist

# Seed the global RNGs so runs are reproducible.
pl.seed_everything(42)

# Ask torch.distributed for detailed diagnostics (helpful when debugging DDP).
os.environ['TORCH_DISTRIBUTED_DEBUG'] = 'DETAIL'

# Run mode consumed by the `__main__` block: 'train', 'param comparison',
# 'generation_test', or any other value to evaluate with trainer.test().
mode = 'eval'
MODEL_NAME = "google/t5-large-ssm"
# Module-level tokenizer; the dataset, the model's decoding and
# generate_answer all reference this global directly.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
# Batch size (passed to QADataModule in the `__main__` block).
BATCH_SIZE = 64
# Trainer max_epochs.
N_EPOCHS = 80
# Trainer `gpus` argument.
N_GPUS = 2


# Token budgets for questions (source) and answers (target); max_target_len
# is also the `max_length` used for generation.
max_source_len = 25
max_target_len = 10

# LoRA / expert hyperparameters. QAModel.__init__ copies lora_rank,
# attn_lora_rank, lora_attn_alpha, lora_attn_attn_alpha, lora_dropout,
# lora_r_dropout and lora_expert_num onto the T5 config. The lora_moe_*
# values are not referenced by the code shown here.
lora_expert_num = 2
lora_rank = 256
attn_lora_rank = 0
lora_attn_alpha = 256*4
lora_attn_attn_alpha = 0
lora_dropout = 0.1
lora_r_dropout = 0.1
lora_moe_act ='linear'
lora_moe_lambda = 1.0
lora_moe_softmax = 1

def get_exact_match(prediction, groundtruth):
    """Return whether ``prediction`` exactly matches ``groundtruth`` after normalization.

    Args:
        prediction: Predicted answer string.
        groundtruth: A reference answer string, or a (possibly nested)
            list/tuple of acceptable reference answers (aliases).

    Returns:
        Truthy when the normalized prediction equals at least one normalized
        reference, falsy otherwise. An empty reference list yields ``0``.
    """
    if isinstance(groundtruth, (list, tuple)):
        if not groundtruth:
            return 0
        # any() short-circuits on the first matching alias.
        return any(get_exact_match(prediction, gt) for gt in groundtruth)
    return normalize_answer(prediction) == normalize_answer(groundtruth)



def get_f1_match(prediction, groundtruth):
    """Return the token-level F1 score between ``prediction`` and ``groundtruth``.

    Both strings are passed through ``normalize_answer`` and split on
    whitespace. Overlap is counted as a multiset, as in the SQuAD evaluation
    script: a token shared by both sides counts as many times as it occurs in
    both, not just once.

    Args:
        prediction: Predicted answer string.
        groundtruth: A reference answer string, or a (possibly nested)
            list/tuple of references. In that case the best F1 over the
            references is returned.

    Returns:
        F1 in ``[0, 1]``. It is ``0`` for an empty reference list. If either
        side normalizes to no tokens, the result is ``1`` when both are empty
        and ``0`` otherwise.
    """
    from collections import Counter

    if isinstance(groundtruth, (list, tuple)):
        if not groundtruth:
            return 0
        return max(get_f1_match(prediction, gt) for gt in groundtruth)
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(groundtruth).split()
    if not pred_tokens or not truth_tokens:
        return int(pred_tokens == truth_tokens)
    # Multiset intersection: min(count_in_pred, count_in_truth) per token.
    num_common = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if num_common == 0:
        return 0
    prec = num_common / len(pred_tokens)
    rec = num_common / len(truth_tokens)

    return 2 * (prec * rec) / (prec + rec)


def normalize_answer(s):
    """Normalize an answer string for comparison, SQuAD-style.

    Applies these steps in order:

    1. Lower-case the string.
    2. Drop ASCII punctuation.
    3. Replace the standalone articles "a", "an" and "the" with a space.
    4. Collapse whitespace runs into single spaces.

    Note: this definition shadows the ``normalize_answer`` imported from
    ``transformers.data.metrics.squad_metrics`` at the top of the file.
    """
    def remove_articles(text):
        # \b anchors ensure only whole-word articles are removed.
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def extract_questions_and_answers(json_dataset):
    """Turn raw QA records into a DataFrame with one row per record.

    Each record must carry ``question``, ``answer`` and ``aliases`` keys. The
    resulting columns are ``question``, ``aliases`` and ``answer_text``, in
    that order.
    """
    rows = []
    for record in json_dataset:
        q, ans, alias_list = record['question'], record['answer'], record["aliases"]
        rows.append({"question": q, 'aliases': alias_list, "answer_text": ans})
    return pd.DataFrame(rows)



class QADataset(Dataset):
    """Map-style dataset that tokenizes question/answer rows on access.

    Each item is a dict with these entries:

    - the raw ``question`` and ``answer_text`` strings;
    - ``aliases``, wrapped in a one-element list;
    - fixed-length 1-D tensors ``input_ids``, ``attention_mask`` and ``labels``.

    Padding positions in ``labels`` are replaced by -100 so the seq2seq loss
    ignores them.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: T5Tokenizer,
        source_max_token_len: int = max_source_len,
        target_max_token_len: int = max_target_len
    ):

        self.tokenizer = tokenizer
        self.data = data
        self.source_max_token_len = source_max_token_len
        self.target_max_token_len = target_max_token_len

    def __len__(self):
        return len(self.data)

    def _encode(self, text, max_len):
        """Tokenize ``text`` with this dataset's tokenizer, padded/truncated to ``max_len``."""
        return self.tokenizer(
            text,
            max_length=max_len,
            padding="max_length",
            truncation="only_first",
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )

    def __getitem__(self, index: int):
        data_row = self.data.iloc[index]

        # Use the tokenizer this dataset was constructed with, not the module global.
        source_encoding = self._encode(data_row["question"], self.source_max_token_len)
        target_encoding = self._encode(data_row["answer_text"], self.target_max_token_len)

        labels = target_encoding["input_ids"]
        # Mask padding with -100 (the loss ignore index). Fall back to 0 if the
        # tokenizer has no pad token; for T5 the pad id is 0.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0
        labels[labels == pad_id] = -100

        return dict(
            question=data_row["question"],
            aliases=[data_row['aliases']],
            answer_text=data_row["answer_text"],
            input_ids=source_encoding["input_ids"].flatten(),
            attention_mask=source_encoding["attention_mask"].flatten(),
            labels=labels.flatten()
        )


class QADataModule(pl.LightningDataModule):
    """Lightning data module over pre-split question/answer DataFrames.

    The DataFrames are used as follows:

    - ``train_df`` feeds training.
    - Validation iterates ``val_df`` (key ``"a"``) and ``test_df`` (key
      ``"b"``) together through a ``CombinedLoader``.
    - ``test_df`` alone feeds ``test_dataloader``.
    """

    def __init__(
            self,
            train_df: pd.DataFrame,
            val_df: pd.DataFrame,
            test_df: pd.DataFrame,
            tokenizer: T5Tokenizer,
            batch_size: int = 8,
            source_max_token_len: int = max_source_len,
            target_max_token_len: int = max_target_len,
            num_workers: int = 24,
        ):
        """
        Args:
            train_df, val_df, test_df: DataFrames with ``question``,
                ``aliases`` and ``answer_text`` columns.
            tokenizer: Tokenizer handed to each ``QADataset``.
            batch_size: Batch size used by every dataloader.
            source_max_token_len, target_max_token_len: Padding/truncation
                lengths for questions and answers.
            num_workers: DataLoader worker processes (previously hard-coded to 24).
        """
        super().__init__()
        self.batch_size = batch_size
        self.train_df = train_df
        self.val_df = val_df
        self.test_df = test_df
        self.tokenizer = tokenizer
        self.source_max_token_len = source_max_token_len
        self.target_max_token_len = target_max_token_len
        self.num_workers = num_workers

    def _make_dataset(self, df):
        """Wrap ``df`` in a ``QADataset`` using this module's tokenizer and lengths."""
        return QADataset(
            df,
            self.tokenizer,
            self.source_max_token_len,
            self.target_max_token_len
        )

    def setup(self, stage=None):
        """Build the train/validation/test datasets.

        ``stage`` is accepted and ignored. Lightning passes it when it calls
        this hook itself; a bare ``setup()`` call still works.
        """
        self.train_dataset = self._make_dataset(self.train_df)
        self.val_dataset = self._make_dataset(self.val_df)
        self.test_dataset = self._make_dataset(self.test_df)

    def collate_fn(self, batch):
        """Merge a list of ``QADataset`` items into one batch dict.

        String fields stay Python lists. Tensor fields are stacked along a new
        leading batch dimension. The answer strings appear under the key
        ``"answers"``.
        """
        return {
            "question": [b['question'] for b in batch],
            "aliases": [b['aliases'] for b in batch],
            "answers": [b['answer_text'] for b in batch],
            "input_ids": torch.stack([b['input_ids'] for b in batch]),
            "attention_mask": torch.stack([b['attention_mask'] for b in batch]),
            "labels": torch.stack([b['labels'] for b in batch]),
        }

    def _loader(self, dataset, shuffle=False):
        """Build a DataLoader using this module's batch size, workers and collate_fn."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            drop_last=False,
            collate_fn=self.collate_fn
        )

    def train_dataloader(self):
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return a CombinedLoader yielding ``{"a": val_batch, "b": test_batch}``.

        In ``"max_size_cycle"`` mode the shorter loader restarts until the
        longer one is exhausted. Examples from the smaller split can therefore
        be seen more than once per validation epoch.
        """
        loaders = {
            "a": self._loader(self.val_dataset),
            "b": self._loader(self.test_dataset),
        }
        return CombinedLoader(loaders, "max_size_cycle")

    def test_dataloader(self):
        return self._loader(self.test_dataset, shuffle=False)


class QAModel(pl.LightningModule):
    """T5 seq2seq QA model with LoRA experts and an embedding-memory router.

    Training: every example gets expert label 1. After the first epoch, the
    L2-normalized ``avg_embedding`` outputs are stored in
    ``embedding_memory``. They are gathered across DDP ranks when a process
    group is initialized.

    Validation/test:

    1. Each example's embedding from ``self.model.expert_prepare`` is
       normalized.
    2. It is scored against ``embedding_memory`` by dot product (cosine
       similarity when the memory rows are unit-norm, as built in
       ``training_epoch_end``).
    3. Examples whose best score reaches ``ROUTER_THRESHOLD`` get
       ``task_A=1``; the rest get ``task_A=0``.
    4. The flag is forwarded to ``self.model.generate``. How it is used
       inside the custom T5 implementation is not visible here.
    """

    # Minimum similarity to the embedding memory for routing an example to task A.
    ROUTER_THRESHOLD = 0.9

    def __init__(self):
        super().__init__()
        # Copy the module-level LoRA hyperparameters onto the model config
        # consumed by the custom ``transformers_2`` T5 implementation.
        self.config = T5Config.from_pretrained(MODEL_NAME, return_dict=True)
        self.config.is_lora = lora_rank
        self.config.lora_attn_alpha = lora_attn_alpha
        self.config.lora_attn_attn_alpha = lora_attn_attn_alpha
        self.config.lora_dropout = lora_dropout
        self.config.lora_r_dropout = lora_r_dropout
        self.config.lora_expert_num = lora_expert_num
        self.config.attn_is_lora = attn_lora_rank
        self.config.dropout_rate = 0
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, config=self.config)
        # Optional optimizer parameter groups (see set_grouped_parameters).
        self.grouped_parameters = None
        # 2-D tensor of normalized embeddings built in training_epoch_end, or None.
        self.embedding_memory = None
        # Per-validation-epoch metric histories (pickled by the training script).
        self.task_A_accs = []
        self.task_B_accs = []
        self.avg_accs = []

    def forward(self, input_ids, attention_mask, labels=None, expert_labels=None):
        """Run the wrapped model with ``router_phase=1``; return ``(loss, logits, avg_embedding)``."""
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            router_phase=1,
            expert_labels=expert_labels,
        )
        return output.loss, output.logits, output.avg_embedding

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and return the embeddings used for the memory."""
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        # All training examples are tagged with expert label 1.
        types = torch.ones(input_ids.shape[0], dtype=torch.long, device=self.device)

        loss, _, avg_embedding = self(input_ids, attention_mask, labels, expert_labels=types)
        loss = torch.mean(loss)

        # Log with the actual batch size: the last batch may be smaller than
        # BATCH_SIZE, and epoch-level averaging is weighted by this value.
        self.log("train_loss", loss, prog_bar=True, logger=True, batch_size=input_ids.shape[0])
        return {"loss": loss, 'types': types, "pred": avg_embedding}

    def training_step_end(self, batch_parts):
        """Pass step outputs through unchanged."""
        return batch_parts

    def training_epoch_end(self, training_step_outputs):
        """After epoch 0, build ``embedding_memory`` from the training embeddings.

        Steps:

        1. L2-normalize each step's ``pred`` embeddings row-wise and
           concatenate them.
        2. All-gather them across ranks when ``torch.distributed`` is
           initialized.
        3. Pickle the resulting tensor, as-is, to a file named after this
           module's device.
        """
        if self.current_epoch != 0:
            return

        local = torch.cat([
            out['pred'].div(out['pred'].norm(p=2, dim=1, keepdim=True))
            for out in training_step_outputs
        ])

        if dist.is_available() and dist.is_initialized():
            # all_gather needs equal shapes on every rank. This presumably
            # holds because the distributed sampler pads ranks evenly; confirm.
            gathered = [torch.ones_like(local) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered, local)
            local = torch.cat(gathered)

        self.embedding_memory = local
        with open("embedding_memory_zsREwith_new_train" + str(self.device), "wb") as f:
            pickle.dump(self.embedding_memory, f)

    def _route_and_generate(self, input_ids, attention_mask, labels):
        """Choose the per-example ``task_A`` flag and return decoded generations.

        Without an embedding memory, every example keeps ``task_A=1``.
        """
        task_A = torch.ones(input_ids.shape[0], dtype=torch.long, device=self.device)
        _, pred, _ = self.model.expert_prepare(input_ids, attention_mask, labels, expert_labels=task_A)

        if self.embedding_memory is not None:
            pred = pred.div(pred.norm(p=2, dim=1, keepdim=True))
            # The memory may have been unpickled onto a different device.
            memory = self.embedding_memory.to(pred.device)
            score, _ = torch.max(torch.matmul(pred, memory.transpose(1, 0)), dim=1)
            task_A = torch.where(score >= self.ROUTER_THRESHOLD, 1, 0)

        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_beams=1,
            max_length=max_target_len,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True,
            use_cache=True,
            task_A=task_A,
        )
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    def validation_step(self, batch_ab, batch_idx):
        """Generate for both combined-loader batches.

        Returns ``[(preds_a, aliases_a), (preds_b, aliases_b)]``.
        """
        output = []
        for key in ('a', 'b'):
            batch = batch_ab[key]
            predictions = self._route_and_generate(batch["input_ids"], batch["attention_mask"], batch["labels"])
            output.append((predictions, batch['aliases']))
        return output

    @staticmethod
    def _count_exact(predictions, answers):
        """Return ``(num_correct, num_examples)`` for one batch.

        Each entry of ``answers`` is a one-element list wrapping that example's
        aliases. Empty predictions count as wrong.
        """
        correct = total = 0
        for pred, ans in zip(predictions, answers):
            total += 1
            if pred != "" and get_exact_match(pred, ans[0]):
                correct += 1
        return correct, total

    def validation_epoch_end(self, validation_step_outputs):
        """Log normalized exact-match accuracy for tasks A and B.

        ``task_avg_acc`` is their harmonic mean, which the checkpoint callback
        monitors.
        """
        correct_A = total_A = correct_B = total_B = 0
        for output in validation_step_outputs:
            c, n = self._count_exact(*output[0])
            correct_A += c
            total_A += n
            c, n = self._count_exact(*output[1])
            correct_B += c
            total_B += n

        # Guard against empty splits instead of raising ZeroDivisionError.
        acc_A = correct_A / total_A if total_A else 0.0
        acc_B = correct_B / total_B if total_B else 0.0
        total = total_A + total_B
        harmonic = 2 * acc_A * acc_B / (acc_A + acc_B) if (acc_A + acc_B) else 0.0

        self.task_A_accs.append(acc_A)
        self.task_B_accs.append(acc_B)
        self.avg_accs.append((correct_A + correct_B) / total if total else 0.0)

        self.log("task_A_acc", acc_A, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("task_B_acc", acc_B, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("task_avg_acc", harmonic, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        with open('performance_only+lora_zsre.txt', 'a', encoding='utf-8') as f_w:
            print(str(acc_A) + '\t' + str(acc_B), file=f_w)

    def test_step(self, batch, batch_idx):
        """Generate for one test batch.

        Returns ``[(input_ids, predictions, aliases)]``.
        """
        input_ids = batch["input_ids"]
        predictions = self._route_and_generate(input_ids, batch["attention_mask"], batch["labels"])
        return [(input_ids, predictions, batch['aliases'])]

    def test_epoch_end(self, test_step_outputs):
        """Compute test accuracy, normalized accuracy and F1; append predictions to a file."""
        correct_A = 0
        correct_normal = 0
        total_num_A = 0
        correct_f1 = 0
        with open('predictions.txt', 'a', encoding='utf-8') as f_w:
            for output in tqdm(test_step_outputs):
                _, predictions, answers = output[0]
                for pred, ans in zip(predictions, answers):
                    total_num_A += 1
                    if pred == "":
                        continue
                    # Membership test against the alias list (raw, un-normalized).
                    if pred in ans[0]:
                        correct_A += 1
                    correct = 1 if get_exact_match(pred, ans[0]) else 0
                    correct_normal += correct
                    correct_f1 += get_f1_match(pred, ans[0])
                    print(pred + '\t' + "!@#".join(ans[0]) + '\t' + str(correct), file=f_w)

        denom = max(total_num_A, 1)  # avoid ZeroDivisionError on an empty test set
        self.log("task_A_acc", correct_A / denom, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("task_A_acc normalize", correct_normal / denom, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("task_A_acc F1", correct_f1 / denom, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        print("Task A acc:", correct_A / denom)
        print("Task A acc normalize:", correct_normal / denom)
        print("TOTAL A", total_num_A)
        print("CORRECT_A", correct_normal)

    def set_grouped_parameters(self, grouped_parameters):
        """Use ``grouped_parameters`` (AdamW param groups) in configure_optimizers."""
        self.grouped_parameters = grouped_parameters

    def configure_optimizers(self):
        """Return AdamW over all parameters (lr 5e-5) unless custom groups were set."""
        if self.grouped_parameters is None:
            return AdamW(self.parameters(), lr=5e-5)
        print("Wanted Optimizer")
        return AdamW(self.grouped_parameters, betas=(0.9, 0.999), eps=1e-6,
                     correct_bias=True)


def generate_answer(example, trained_model, task_A=None):
    """Generate an answer string for a single example.

    Args:
        example: Mapping with a ``"question"`` string.
        trained_model: A ``QAModel``; its ``.model`` and ``.device`` are used.
        task_A: Task flag forwarded to ``generate``. Defaults to ``[0]``. A
            fresh list is created per call instead of sharing a mutable
            default.

    Returns:
        The decoded generated sequence(s) concatenated into one string.
    """
    if task_A is None:
        task_A = [0]

    source_encoding = tokenizer(
        example["question"],
        max_length=max_source_len,
        padding="max_length",
        truncation="only_first",
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt"
    )

    generated_ids = trained_model.model.generate(
        input_ids=source_encoding["input_ids"].to(trained_model.device),
        attention_mask=source_encoding["attention_mask"].to(trained_model.device),
        num_beams=1,
        max_length=max_target_len,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
        use_cache=True,
        task_A=task_A
    )
    preds = [
        tokenizer.decode(generated_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for generated_id in generated_ids
    ]

    return "".join(preds)


def f1(truths, preds):
    """Average SQuAD token F1 over aligned (truth, prediction) pairs.

    Pairs stop at the shorter input. Raises ``statistics.StatisticsError``
    when there are no pairs.
    """
    per_pair = list(map(compute_f1, truths, preds))
    return mean(per_pair)

def exact(truths, preds):
    """Average SQuAD exact-match score over aligned (truth, prediction) pairs.

    Pairs stop at the shorter input. Raises ``statistics.StatisticsError``
    when there are no pairs.
    """
    per_pair = list(map(compute_exact, truths, preds))
    return mean(per_pair)



if __name__ == "__main__":
    # Script entry point. The module-level `mode` selects the branch:
    # 'train', 'param comparison', 'generation_test', or (default)
    # evaluation with trainer.test().

    def _load_json(json_path):
        """Load a UTF-8 JSON dataset file and print its record count."""
        with open(json_path, 'r', encoding='utf-8') as f:
            records = json.load(f)
        print(len(records))
        return records

    train_dataset = _load_json('new_dataset/zeroshot_train_B_4to4.json')
    dev_dataset_taskA = _load_json('new_dataset/zeroshot_train_A.json')
    # NOTE(review): task-B evaluation reads the same file as the training set; confirm intended.
    dev_dataset_taskB = _load_json('new_dataset/zeroshot_train_B_4to4.json')

    train_df = extract_questions_and_answers(train_dataset)
    val_taskA_df = extract_questions_and_answers(dev_dataset_taskA)
    val_taskB_df = extract_questions_and_answers(dev_dataset_taskB)

    # Validation loader "a": a fixed 1000-example sample of task A.
    # Validation loader "b" and the test loader: all of task B.
    task_a_list = val_taskA_df.sample(n=1000, random_state=42, replace=False)

    data_module = QADataModule(train_df, task_a_list, val_taskB_df, tokenizer, batch_size=BATCH_SIZE)

    data_module.setup()
    logger = pl.loggers.TensorBoardLogger("zsREmodels/OURS_LoRA_zsRE_with_4to2", name=None)

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(logger.log_dir, 'FFFinal_version'),
        filename="model-{epoch:02d}-{task_A_acc:.3f}-{task_B_acc:.3f}",
        save_top_k=3,
        verbose=True,
        monitor="task_avg_acc",
        mode="max",
    )

    trainer = pl.Trainer(
        callbacks=[checkpoint_callback],
        max_epochs=N_EPOCHS,
        gpus=N_GPUS,
        accelerator="ddp",
        check_val_every_n_epoch=4,  # frequency to check val
        plugins=DDPPlugin(find_unused_parameters=False)
    )

    if mode == 'train':

        checkpoint_path = "zsRE_models/pretrain/best-checkpoint.ckpt"
        model = QAModel.load_from_checkpoint(checkpoint_path, strict=False)
        # Freeze everything, then re-enable the LoRA weights except those whose
        # name contains '.0.weight' (presumably the first LoRA matrix; confirm).
        model.freeze()
        no_decay = ['bias', 'layer_norm.weight']

        for n, p in model.named_parameters():
            if 'lora' in n and '.0.weight' not in n:
                p.requires_grad = True

        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if ('lora' in n) and not any(nd in n for nd in no_decay)],
                'lr': 4e-4,
                'weight_decay': 0.01
            },
            {
                "params": [p for n, p in model.named_parameters() if ('lora' in n) and any(nd in n for nd in no_decay)],
                'lr': 4e-4,
                'weight_decay': 0
            }
        ]

        model.set_grouped_parameters(optimizer_grouped_parameters)
        model.train()

        trainer.fit(model, data_module)
        with open("embedding_memory_with_4to2", "wb") as f:
            pickle.dump(model.embedding_memory, f)

        with open("task_A_accs_4e-4_our_zsRELoRA_with_new_train", "wb") as f:
            pickle.dump(model.task_A_accs, f)

        with open("task_B_accs_4e-4_our_zsRELoRA_with_new_train", "wb") as f:
            pickle.dump(model.task_B_accs, f)

        with open("task_avg_accs_4e-4_our_LoRA_with_new_train", "wb") as f:
            pickle.dump(model.avg_accs, f)

    elif mode == 'param comparison':
        # Report every parameter that differs from (or is missing in) a freshly initialized model.
        checkpoint_path = "models/REC/best-checkpoint-v5.ckpt"
        model = QAModel.load_from_checkpoint(checkpoint_path)
        orig_model = QAModel()
        params_dict = {n: p for n, p in orig_model.named_parameters()}
        count = 0
        for n, p in model.named_parameters():
            ref = params_dict.get(n)
            if ref is None or not torch.equal(p, ref):
                print(n, "is different")
                count += 1
        print(count)

    elif mode == 'generation_test':
        # Debugging aid: run the first test example through forward() and print raw outputs.
        checkpoint_path = "zsREmodels/OURS_LoRA_zsRE_with_4to4/version_15/QV/model-epoch=59-task_A_acc=0.958-task_B_acc=0.943.ckpt"
        model = QAModel.load_from_checkpoint(checkpoint_path)
        # NOTE: pickle.load executes arbitrary code; only load locally produced files.
        with open("embedding_memory_with_4to2", "rb") as f:
            model.embedding_memory = pickle.load(f)

        for data in data_module.test_dataset:
            print(data['answer_text'])
            print("LABELS")
            print(data['labels'])
            loss, logits, _ = model(input_ids=data['input_ids'].view(1, -1), attention_mask=data['attention_mask'].view(1, -1), labels=data['labels'].view(1, -1), expert_labels=torch.tensor([1], device=model.device))
            print(data['labels'])
            print(loss)
            print(logits[0][1])
            print(torch.argmax(logits, dim=-1))
            break

    else:
        checkpoint_path = "zsREmodels/OURS_LoRA_zsRE_with_4to2/version_1/FFFinal_version/model-epoch=27-task_A_acc=0.915-task_B_acc=0.954.ckpt"
        model = QAModel.load_from_checkpoint(checkpoint_path)
        # NOTE: pickle.load executes arbitrary code; only load locally produced files.
        with open("distribution_folder/embedding_memory_with_4to4_for_distribution", "rb") as f:
            model.embedding_memory = pickle.load(f)

        trainer.test(model, data_module.test_dataloader())
