import argparse
import os

import torch
import numpy as np
import math
import random
from tqdm import tqdm
from str2bool import str2bool
from tqdm import tqdm
import h5py
import json

from datetime import datetime

from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
from torch.utils.data import Dataset
from metrics import f1_metric, bleu_metric, distinct_metric, rouge_metric, entropy_metric
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
from torch.cuda.amp import autocast
from torch.cuda.amp import GradScaler
from utils import occupy_mem_new, save_hparams, get_batch_loader, init_para_frompretrained


END_OF_TEXT = '<|endoftext|>'


class RedditDataset(Dataset):
    """Map-style dataset of (context, response) pairs stored in an HDF5 file.

    The file must contain a ``dialogs`` dataset; each entry is one dialogue whose
    utterances are separated by ``[SEP]``. The last utterance is returned as the
    response, and the preceding utterances are joined with ``END_OF_TEXT`` to
    form the context string.
    """

    def __init__(self, data_path):
        # Keep the h5py file handle on the instance so it stays open for as long
        # as the dataset is used and can be released via ``close()``.
        self._reader = h5py.File(data_path, 'r')
        self._all_dialogs = self._reader['dialogs']
        self._n_data = len(self._all_dialogs)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        """Return ``(context, response)`` for dialogue ``i`` (both stripped strings)."""
        dialog = self._all_dialogs[i]
        # h5py >= 3 returns variable-length strings as bytes; decode them so the
        # str separator below can be used.
        if isinstance(dialog, bytes):
            dialog = dialog.decode('utf-8')
        utterances = [u.strip() for u in dialog.split('[SEP]')]
        context = END_OF_TEXT.join(utterances[:-1])
        response = utterances[-1]
        return context, response

    def close(self):
        """Close the underlying HDF5 file."""
        self._reader.close()

    @staticmethod
    def collate_fn(batch):
        """Split a list of ``(context, response)`` pairs into two parallel lists."""
        context_list = [item[0] for item in batch]
        response_list = [item[1] for item in batch]
        return context_list, response_list


class DailyDialogDataset(Dataset):
    """Map-style dataset of (context, response) pairs read from a JSON-lines file.

    Each non-blank line must be a JSON object with a ``context`` list of
    utterance strings and a ``response`` string. The context utterances are
    joined with ``END_OF_TEXT`` into a single string.
    """

    def __init__(self, data_path):
        self._context = []
        self._response = []
        with open(data_path, 'r', encoding='utf-8') as f:
            # Stream line by line instead of materialising the whole file with readlines().
            for line in f:
                if not line.strip():
                    # Tolerate stray empty lines instead of failing in json.loads.
                    continue
                example = json.loads(line)
                self._context.append(END_OF_TEXT.join(example['context']))
                self._response.append(example['response'])
        self._n_data = len(self._context)

    def __len__(self):
        return self._n_data

    def __getitem__(self, i):
        """Return ``(context, response)``; context looks like "U1<|endoftext|>U2<|endoftext|>U3"."""
        context = self._context[i]
        response = self._response[i]
        return context, response

    @staticmethod
    def collate_fn(batch):
        """Split a list of ``(context, response)`` pairs into two parallel lists."""
        context_list = [item[0] for item in batch]
        response_list = [item[1] for item in batch]
        return context_list, response_list


class DailyDialogBatcher(object):
    """Turn raw (context, response) strings into fixed-layout GPT-2 input tensors.

    Every dialogue is laid out as a sequence of utterance "slots", each exactly
    ``max_utterance_len`` tokens long. When a dialogue has fewer utterances than
    slots, fully padded slots are prepended so all examples in a batch have the
    same length.

    Args:
        max_utterance_num: number of slots per example in training mode
            (up to ``max_utterance_num - 1`` context utterances + the response).
            Inference uses ``max_utterance_num - 1`` context slots plus one BOS token.
        max_utterance_len: length of each slot in tokens (must be >= 1, since
            ``tokenize`` reads ``position_ids[-1]``).
        gpt2_config: name or path passed to ``GPT2Tokenizer.from_pretrained``.
        only_update_last: if True, in training mode the labels of every
            utterance except the last one (the response) are set to -1.
        cuda: build tensors on ``'cuda'`` if True, otherwise on ``'cpu'``.
    """

    def __init__(self, max_utterance_num, max_utterance_len, gpt2_config, only_update_last, cuda=True):
        self.max_utterance_num = max_utterance_num
        self.max_utterance_len = max_utterance_len
        self.only_update_last = only_update_last

        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_config)

        self.bos_id = self.tokenizer.bos_token_id
        self.eos_id = self.tokenizer.eos_token_id
        # NOTE(review): 0 is presumably an ordinary GPT-2 vocabulary id rather than a
        # dedicated pad token; padded positions also get attention_mask=0 and labels=-1.
        self.pad_id = 0

        self.device = torch.device('cuda' if cuda else 'cpu')

    def tokenize(self, text, offset=0):
        """Tokenize one utterance into a single padded slot.

        ``input_ids`` is ``[bos] + tokens`` and ``labels`` is ``tokens + [eos]``,
        so labels[j] is the next-token target for input_ids[j]. Both are
        truncated to ``max_utterance_len`` (for long utterances this drops the
        trailing EOS label) and padded to exactly that length: input_ids with
        ``pad_id``, labels with -1 and attention_mask with 0.

        Args:
            text: the utterance string.
            offset: position id assigned to the slot's first (BOS) token.

        Returns:
            ``(input_ids, labels, attention_mask, position_ids, newoffset)``.
            The first four are lists of length ``max_utterance_len``.
            ``position_ids`` count up from ``offset`` over the real tokens and
            repeat the last real position over the padding. ``newoffset`` is
            the position after the last real token, i.e. where the next slot starts.
        """
        results         = self.tokenizer(text, add_special_tokens=True)
        input_ids       = [self.bos_id] + results["input_ids"]
        labels          = results["input_ids"] + [self.eos_id]
        attention_mask  = [1] + results["attention_mask"]

        input_ids       = input_ids[:self.max_utterance_len]
        labels          = labels[:self.max_utterance_len]
        attention_mask  = attention_mask[:self.max_utterance_len]
        position_ids    = list(range(offset, len(input_ids) + offset))
        newoffset       = len(input_ids) + offset

        # -1 marks positions without a target; presumably the model's loss uses
        # ignore_index=-1 (HF's default is -100) — confirm in VAEModel.
        input_ids       = input_ids + [self.pad_id] * (self.max_utterance_len - len(input_ids))
        labels          = labels + [-1] * (self.max_utterance_len - len(labels))
        attention_mask  = attention_mask + [0] * (self.max_utterance_len - len(attention_mask))
        position_ids    = position_ids + [position_ids[-1]] * (self.max_utterance_len - len(position_ids))
        return input_ids, labels, attention_mask, position_ids, newoffset

    def __call__(self, context_list, response_list=None, training=False):
        """Build model inputs for a batch of dialogues.

        Args:
            context_list: context strings whose utterances are joined by
                ``END_OF_TEXT``; only the last ``max_utterance_num - 1`` are kept.
            response_list: response strings; required when ``training`` is True.
            training: select the training layout (context + response, with
                labels) or the generation layout (context only).

        Returns:
            Training: a dict with ``input_ids``, ``labels``, ``attention_mask``,
            ``position_ids``, ``shuffled_input_ids`` and ``shuffled_attention_mask``,
            each of shape ``[batch, max_utterance_num * max_utterance_len]``.
            The shuffled tensors hold the padding slots first, followed by the
            same real slots (response included) in a random order. This calls
            ``random.shuffle`` once per example.

            Inference: a dict with ``input_ids``, ``attention_mask`` and
            ``position_ids`` of shape
            ``[batch, (max_utterance_num - 1) * max_utterance_len + 1]``. The extra
            final column is a BOS token with attention 1 and a position one past
            the previous column's position; generation continues from it.
        """
        if training:
            input_ids_list, labels_list, attention_mask_list, position_ids_list = [], [], [], []
            shuffled_input_ids_list, shuffled_attention_mask_list = [], []
            for context, response in zip(context_list, response_list):
                input_ids, labels, attention_mask, position_ids = [], [], [], []
                # *_0 collect the padding slots, *_1 the real slots that get shuffled below.
                shuffled_input_ids_0, shuffled_attention_mask_0, shuffled_input_ids_1, shuffled_attention_mask_1 = [], [], [], []
                # Last (max_utterance_num - 1) context utterances followed by the response.
                # NOTE(review): with max_utterance_num == 1 this slice is [-0:], which keeps
                # every context utterance instead of none and yields inconsistent shapes.
                utterances = context.split(END_OF_TEXT)[-(self.max_utterance_num - 1):] + [response]
                # Position ids continue across slots and count only real tokens.
                offset = 0
                # Prepend empty slots so every example has exactly max_utterance_num slots.
                for _ in range(self.max_utterance_num - len(utterances)):  # for batch decoding
                    input_ids.append([self.pad_id] * self.max_utterance_len)
                    labels.append([-1] * self.max_utterance_len)
                    attention_mask.append([0] * self.max_utterance_len)
                    position_ids.append([0] * self.max_utterance_len)
                    shuffled_input_ids_0.append([self.pad_id] * self.max_utterance_len)
                    shuffled_attention_mask_0.append([0] * self.max_utterance_len)
                for i, u in enumerate(utterances):
                    ids, label, mask, pos, offset = self.tokenize(u.strip(), offset)
                    # Optionally train only on the response (the last utterance).
                    if self.only_update_last and i != len(utterances) - 1:
                        label = [-1] * len(label)
                    input_ids.append(ids)
                    labels.append(label)
                    attention_mask.append(mask)
                    position_ids.append(pos)
                    shuffled_input_ids_1.append(ids.copy())
                    shuffled_attention_mask_1.append(mask.copy())
                input_ids_list.append(input_ids)
                labels_list.append(labels)
                attention_mask_list.append(attention_mask)
                position_ids_list.append(position_ids)
                # Randomly permute the real slots; padding slots stay in front.
                shuffled_ids = list(range(len(shuffled_input_ids_1)))
                random.shuffle(shuffled_ids)
                shuffled_input_ids = shuffled_input_ids_0 + [shuffled_input_ids_1[i] for i in shuffled_ids]
                shuffled_attention_mask = shuffled_attention_mask_0 + [shuffled_attention_mask_1[i] for i in shuffled_ids]
                shuffled_input_ids_list.append(shuffled_input_ids)
                shuffled_attention_mask_list.append(shuffled_attention_mask)


            input_ids_list = torch.tensor(input_ids_list, device=self.device, dtype=torch.long)  # [batch, max_utterance_num, max_utterance_len]
            labels_list = torch.tensor(labels_list, device=self.device, dtype=torch.long)
            attention_mask_list = torch.tensor(attention_mask_list, device=self.device, dtype=torch.float)  # [batch, max_utterance_num, max_utterance_len]
            position_ids_list = torch.tensor(position_ids_list, device=self.device, dtype=torch.long)
            shuffled_input_ids_list = torch.tensor(shuffled_input_ids_list, device=self.device, dtype=torch.long)
            shuffled_attention_mask_list = torch.tensor(shuffled_attention_mask_list, device=self.device, dtype=torch.float)
            # Flatten the slot axis: [batch, max_utterance_num * max_utterance_len].
            return {
                "input_ids": input_ids_list.view(-1, self.max_utterance_num * self.max_utterance_len),
                "labels": labels_list.view(-1, self.max_utterance_num * self.max_utterance_len).contiguous(),
                "attention_mask": attention_mask_list.view(-1, self.max_utterance_num * self.max_utterance_len),
                "position_ids": position_ids_list.view(-1, self.max_utterance_num * self.max_utterance_len),
                "shuffled_input_ids": shuffled_input_ids_list.view(-1, self.max_utterance_num * self.max_utterance_len),
                "shuffled_attention_mask": shuffled_attention_mask_list.view(-1, self.max_utterance_num * self.max_utterance_len),
            }
        else:
            # token_type_ids is collected nowhere; only the three lists below are used.
            input_ids_list, attention_mask_list, position_ids_list, token_type_ids_list = [], [], [], []
            for context in context_list:
                input_ids, attention_mask, position_ids, token_type_ids = [], [], [], []
                # NOTE(review): same [-0:] edge case as the training branch when max_utterance_num == 1.
                utterances = context.split(END_OF_TEXT)[-(self.max_utterance_num - 1):]
                offset = 0
                # Prepend empty slots so every example has max_utterance_num - 1 context slots.
                for _ in range(self.max_utterance_num - 1 - len(utterances)):  # for batch decoding
                    input_ids.append([self.pad_id] * self.max_utterance_len)
                    attention_mask.append([0] * self.max_utterance_len)
                    position_ids.append([0] * self.max_utterance_len)
                for i, u in enumerate(utterances):
                    # Labels are not needed for generation.
                    ids, label, mask, pos, offset = self.tokenize(u.strip(), offset)
                    input_ids.append(ids)
                    attention_mask.append(mask)
                    position_ids.append(pos)

                input_ids_list.append(input_ids)
                attention_mask_list.append(attention_mask)
                position_ids_list.append(position_ids)

            input_ids_list = torch.tensor(input_ids_list, device=self.device, dtype=torch.long)  # [batch, max_utterance_num, max_utterance_len]
            attention_mask_list = torch.tensor(attention_mask_list, device=self.device, dtype=torch.float)  # [batch, max_utterance_num, max_utterance_len]
            position_ids_list = torch.tensor(position_ids_list, device=self.device, dtype=torch.long)

            # Flatten the context slots and append a BOS column that starts the response.
            input_ids = input_ids_list.view(-1, (self.max_utterance_num - 1) * self.max_utterance_len)
            input_ids = torch.cat([input_ids, input_ids.new_ones((input_ids.shape[0], 1)) * self.bos_id], dim=1)
            attention_mask = attention_mask_list.view(-1, (self.max_utterance_num - 1) * self.max_utterance_len)
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
            position_ids = position_ids_list.view(-1, (self.max_utterance_num - 1) * self.max_utterance_len)
            # The last context column holds the last real position (padding repeats it), so +1
            # continues the numbering exactly like a response slot would in training.
            position_ids = torch.cat([position_ids, torch.add(position_ids[:, -1:], position_ids.new_ones((position_ids.shape[0], 1)))], dim=-1)

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            }


def _timestamp():
    """Return the current local time formatted as ``'%Y-%m-%d %H:%M:%S'`` for log lines."""
    return datetime.now().strftime('%Y-%m-%d %H:%M:%S')


def _build_optimizer_and_scheduler(model, args, num_train_examples):
    """Create the AdamW optimizer and the linear warmup/decay LR scheduler.

    Biases and ``LayerNorm.weight`` parameters get no weight decay; every other
    parameter uses ``args.weight_decay``. All named parameters are registered
    regardless of ``requires_grad``. Presumably this is so that parameters
    unfrozen later (``tune_all_parameters()``) are optimised without rebuilding
    the optimizer.

    Args:
        model: the (possibly DataParallel-wrapped) model.
        args: parsed command-line namespace. Reads ``lr``, ``adam_epsilon``,
            ``weight_decay``, ``warmup_steps``, ``num_epochs``, ``batch_size``
            and ``accum_steps``.
        num_train_examples: number of examples in the training set.

    Returns:
        ``(optimizer, scheduler)``.
    """
    no_decay = ["bias", "LayerNorm.weight"]
    named_params = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in named_params if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in named_params if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon)
    # NOTE(review): the schedule length comes from num_epochs, while the training loop runs
    # args.num_steps optimizer steps; the LR reaches 0 at total_steps — confirm this is intended.
    total_steps = args.num_epochs * (num_train_examples / (args.batch_size * args.accum_steps))
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=total_steps)
    return optimizer, scheduler


def main(args):
    """Fine-tune a pretrained VAEModel on DailyDialog and evaluate it periodically.

    Steps:
        1. Print the hyper-parameters and call ``occupy_mem_new`` to select GPUs.
        2. Create ``<args.log>/<args.exp_name>`` and its ``checkpoints`` folder.
        3. Build the data loaders and the batcher, then load VAEModel weights
           from ``<args.pretrain_file>/pytorch_model.bin``.
        4. Run ``args.num_steps`` optimizer steps with cyclical KL (beta) annealing.
        5. Every ``args.valid_every`` steps, evaluate; whenever ROUGE-1 improves,
           save the model to ``<checkpoints>/model-best``.

    Args:
        args: parsed ``argparse.Namespace``. ``args.cuda`` is set here as a side effect.
    """
    print("\nParameters:")
    for attr, value in sorted(vars(args).items()):
        print("{}={}".format(attr.upper(), value))
    print("")

    # Selecting which GPU to use (project helper).
    occupy_mem_new(args.gpu_list.split(','), ratio=args.gpu_ratio, num_devices=args.n_device)

    args.cuda = torch.cuda.is_available() and not args.no_cuda

    # Output directory for models and summaries
    out_dir = os.path.join(args.log, args.exp_name)
    os.makedirs(out_dir, exist_ok=True)
    print('Writing to {}\n'.format(out_dir))
    save_hparams(args, os.path.join(out_dir, 'hparams'))

    # Checkpoint directory
    checkpoint_dir = os.path.join(out_dir, 'checkpoints')
    checkpoint_prefix = os.path.join(checkpoint_dir, 'model')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Build dataset
    print("Create training dataset begin... | %s " % _timestamp())

    train_dataset = DailyDialogDataset(args.train_file)
    valid_dataset = DailyDialogDataset(args.valid_file)

    train_loader = get_batch_loader(train_dataset, collate_fn=DailyDialogDataset.collate_fn, batch_size=args.batch_size, is_test=False)
    valid_loader = get_batch_loader(valid_dataset, collate_fn=DailyDialogDataset.collate_fn, batch_size=args.eval_batch_size, is_test=True)

    print("Create training dataset end... | %s " % _timestamp())

    batcher = DailyDialogBatcher(
        max_utterance_num=args.max_utterance_num,
        max_utterance_len=args.max_utterance_len,
        gpt2_config=args.gpt2_cache_dir,
        only_update_last=False,
        cuda=args.cuda,
    )

    from models.modeling_vae import VAEModel

    print("Initialize parameters from pretrained begin... | %s " % _timestamp())
    gpt2_config = GPT2Config.from_pretrained(args.gpt2_cache_dir)
    model = VAEModel(
        config=gpt2_config,
        max_utterance_num=args.max_utterance_num,
        max_utterance_len=args.max_utterance_len,
        num_category=args.num_category,
        add_input=args.add_input,
        add_attn=args.add_attn,
        add_softmax=args.add_softmax,
        attn_proj_vary=args.attn_proj_vary,
        learn_prior=args.learn_prior,
    )
    # Load onto CPU so a GPU-saved checkpoint also works with --no_cuda; the model
    # is moved to the GPU further below when args.cuda is set.
    state_dict = torch.load('{}/pytorch_model.bin'.format(args.pretrain_file), map_location='cpu')
    model.load_state_dict(state_dict, strict=True)

    print("Initialize parameters from pretrained end... | %s " % _timestamp())

    # fix_pretrained_parameters() now; tune_all_parameters() once step > args.tuning_all_after_iters.
    tuning_all = False
    model.fix_pretrained_parameters()

    # DataParallel requires the module's parameters on device_ids[0] (a GPU), so only
    # wrap when actually running on CUDA.
    use_data_parallel = args.cuda and torch.cuda.device_count() > 1
    if use_data_parallel:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

    if args.cuda:
        model.cuda()

    # The unwrapped model, used for generate / tune_all_parameters / save_pretrained.
    base_model = model.module if isinstance(model, torch.nn.DataParallel) else model

    optimizer, scheduler = _build_optimizer_and_scheduler(model, args, len(train_dataset))

    beta = args.beta_0
    LAMBDA = args.LAMBDA_0

    loss_keys = (
        'lm_loss', 'context_sensitive_kl_loss', 'context_independent_kl_loss',
        'category_kl_loss', 'SCC_loss', 'DFP_loss', 'MI_loss',
    )

    def train_step(global_step, beta=1.0, LAMBDA=1.0):
        """Run one optimizer step over ``args.accum_steps`` accumulated micro-batches.

        Objective: ``lm + beta * (kl_s + kl_i + kl_c) + LAMBDA * (SCC + DFP + MI)``.
        The summed components are printed every ``args.print_every`` steps.
        """
        totals = dict.fromkeys(loss_keys, 0.0)

        for _ in range(args.accum_steps):
            context_list, response_list = next(train_loader)

            model.train()
            fwd_args = batcher(context_list, response_list, training=True)
            output = model(return_dict=True, **fwd_args)
            # .mean() reduces the per-GPU vector gathered by DataParallel (no-op on a scalar);
            # dividing by accum_steps makes the accumulated gradient an average.
            losses = {key: output[key].mean() / args.accum_steps for key in loss_keys}
            loss = (
                losses['lm_loss']
                + beta * (losses['context_sensitive_kl_loss'] + losses['context_independent_kl_loss'] + losses['category_kl_loss'])
                + LAMBDA * (losses['SCC_loss'] + losses['DFP_loss'] + losses['MI_loss'])
            )
            loss.backward()

            for key, value in losses.items():
                totals[key] += value.item()

        # Only parameters with requires_grad can carry gradients to clip.
        grad_norm = torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], args.clip)
        if grad_norm >= 1e2:
            print('WARNING : Exploding Gradients {:.2f}'.format(grad_norm))
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if global_step % args.print_every == 0 and global_step != 0:
            print("Step: %d \t| lm_loss: %.3f \t| kl_s: %.3f \t| kl_i: %.3f \t| kl_c: %.3f \t| SCC: %.3f \t| DFP: %.3f \t| MI: %.3f \t| lr: %.8f \t| %s" % (
                global_step, totals['lm_loss'], totals['context_sensitive_kl_loss'], totals['context_independent_kl_loss'],
                totals['category_kl_loss'], totals['SCC_loss'], totals['DFP_loss'], totals['MI_loss'],
                scheduler.get_last_lr()[0], _timestamp()))

    def test_step(split, global_step):
        """Evaluate on ``valid_loader`` and return ``{"f1": ROUGE-1}``.

        First computes the mean LM loss and perplexity. Then decodes a response
        for every context (``do_sample=False``, ``args.num_beams`` beams) and
        writes ``"hyp ||| ref"`` lines to ``<out_dir>/<split>-decoded-iter-<step>.txt``.
        Finally prints BLEU, ROUGE, Distinct and Entropy scores. Note that the
        returned key is "f1" but its value is ROUGE-1.
        """
        # NOTE(review): the loss pass runs in train mode (dropout active); presumably VAEModel
        # needs training mode to produce 'lm_loss' from labels — confirm.
        model.train()
        test_loss = []
        with torch.no_grad():
            for context_list, response_list in tqdm(valid_loader):
                fwd_args = batcher(context_list, response_list, training=True)
                loss = model(return_dict=True, **fwd_args)['lm_loss']
                # Under DataParallel this is a per-GPU vector; .item() alone would fail on it.
                test_loss.append(loss.mean().item())
        model.eval()
        mean_loss = np.mean(test_loss)
        print("**********************************")
        print("{} results..........".format(split))
        print('losses: ', len(test_loss))
        print("loss: {:.4f}".format(mean_loss))
        print("ppl: {:.4f}".format(math.exp(mean_loss)))
        print("**********************************")

        test_hyp, test_ref = [], []
        with torch.no_grad():
            for context_list, response_list in tqdm(valid_loader):
                fwd_args = batcher(context_list, training=False)
                prompt_len = fwd_args["input_ids"].size(1)
                fwd_args.update(
                    max_length=prompt_len + args.max_length,
                    min_length=prompt_len + args.min_length,
                    do_sample=False,
                    num_beams=args.num_beams,
                    pad_token_id=batcher.eos_id,
                    eos_token_id=batcher.eos_id,
                    repetition_penalty=args.repetition_penalty,
                    use_cache=True,
                )

                dec_out_list = base_model.generate(**fwd_args)
                # Keep only the newly generated tokens (drop the prompt).
                dec_out_list = dec_out_list[:, prompt_len:].tolist()
                for dec_out, response in zip(dec_out_list, response_list):
                    hyp = batcher.tokenizer.decode(dec_out, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                    test_hyp.append(hyp)
                    test_ref.append(response)

        decoded_path = os.path.join(out_dir, '{}-decoded-iter-{}.txt'.format(split, global_step))
        with open(decoded_path, 'w', encoding='utf-8') as f:
            f.writelines("{} ||| {}\n".format(_hyp, _ref) for _hyp, _ref in zip(test_hyp, test_ref))

        b1, b2, b3, b4 = bleu_metric(test_hyp, test_ref)
        r1, r2, rl = rouge_metric(test_hyp, test_ref)

        d1, d2 = distinct_metric(test_hyp)
        e1, e2, e3, e4 = entropy_metric(test_hyp)
        print("**********************************")
        print("{} results..........".format(split))
        print('hypothesis: ', len(test_hyp))
        print("Step: %d \t|  %s" % (global_step, _timestamp()))
        print("BLEU-1/2/3/4: {:.4f}/{:.4f}/{:.4f}/{:.4f}".format(b1, b2, b3, b4))
        print("Rouge-1/2/L: {:.4f}/{:.4f}/{:.4f}".format(r1, r2, rl))
        print("Distinct-1/2: {:.4f}/{:.4f}".format(d1, d2))
        print("Entropy-1/2/3/4: {:.4f}/{:.4f}/{:.4f}/{:.4f}".format(e1, e2, e3, e4))
        print("**********************************")

        return {"f1": r1}

    # Cyclical KL annealing: while step % cycle > cycle // 2, beta grows by
    # (1 - beta_0) / (cycle // 4) per step (capped at 1.0). It is reset to beta_0
    # after every step that is a multiple of cycle. max(1, ...) avoids a division
    # by zero for cycle < 4.
    beta_increment = (1. - args.beta_0) / max(1, args.cycle // 4)

    best_f1 = 0.
    for i in range(args.num_steps):
        step = i + 1

        if step % args.cycle > args.cycle // 2:
            beta = min(1.0, beta + beta_increment)

        if not tuning_all and step > args.tuning_all_after_iters:
            base_model.tune_all_parameters()
            tuning_all = True

        train_step(step, beta, LAMBDA)

        if step % args.cycle == 0:
            beta = args.beta_0

        if step % args.valid_every == 0:
            valid_results = test_step("test", step)

            if valid_results["f1"] > best_f1:
                best_f1 = valid_results["f1"]

                save_path = '{}-best'.format(checkpoint_prefix)
                os.makedirs(save_path, exist_ok=True)
                base_model.save_pretrained(save_path)

                print("Saved model checkpoint to {}\n".format(save_path))


def _build_arg_parser():
    """Create the command-line parser for this training script.

    Every option is a plain ``--flag`` taking a ``type`` and a ``default``. They
    are listed in the table below and registered in that order.
    """
    parser = argparse.ArgumentParser(
        description='Pre-training for Knowledge-Grounded Conversation'
    )
    option_table = (
        # files
        ('--train_file', str, ''),
        ('--valid_file', str, ''),

        # training scheme
        ('--batch_size', int, 2),
        ('--eval_batch_size', int, 2),
        ('--num_steps', int, 1000000),
        ('--accum_steps', int, 32),
        ('--lr', float, 5e-5),
        ('--clip', float, 2.0),
        # NOTE(review): --model_type is not read by main() beyond printing/saving hparams.
        ('--model_type', str, 'v1'),
        ('--weight_decay', float, 0.01),
        ('--adam_epsilon', float, 1e-8),
        ('--warmup_steps', int, 5000),
        ('--num_epochs', int, 3),
        ('--tuning_all_after_iters', int, 1000),
        ('--num_category', int, 100),
        ('--print_every', int, 10),
        # NOTE(review): a default of 1 runs full validation + decoding after every step.
        ('--valid_every', int, 1),

        # save
        ('--exp_name', str, '0601_test'),
        ('--log', str, 'wizard_of_wikipedia/log'),
        ('--seed', int, 42),

        # model
        # NOTE(review): this default points at a bert-base-uncased directory, yet main()
        # loads it with GPT2Tokenizer / GPT2Config — confirm the intended path.
        ('--gpt2_cache_dir', str, "/home2/zhaoxl/Data/pretrain-models/bert-base-uncased"),
        ('--pretrain_file', str, ''),
        ('--max_utterance_len', int, 64),
        ('--max_utterance_num', int, 8),
        ('--add_input', str2bool, False),
        ('--add_attn', str2bool, False),
        ('--add_softmax', str2bool, False),
        ('--attn_proj_vary', str2bool, False),
        ('--learn_prior', str2bool, False),
        ('--beta_0', float, 1.0),
        ('--cycle', int, 10000),
        ('--LAMBDA_0', float, 1.0),

        # decoding (only_update_last is hard-coded to False in main())
        ('--max_length', int, 32),
        ('--min_length', int, 8),
        ('--num_beams', int, 5),
        ('--repetition_penalty', float, 1.0),

        # gpu
        ('--gpu_list', str, '0'),
        ('--gpu_ratio', float, 0.85),
        ('--n_device', int, 7),
        ('--no_cuda', str2bool, False),
    )
    for flag, value_type, default in option_table:
        parser.add_argument(flag, type=value_type, default=default)
    return parser


def _seed_everything(seed):
    """Seed torch (CPU and every CUDA device), numpy and ``random``, and make cuDNN deterministic."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':
    cli_args = _build_arg_parser().parse_args()
    _seed_everything(cli_args.seed)
    main(cli_args)
