from config import STATE_DICT_KEY, OPTIMIZER_STATE_DICT_KEY
from .verb import ManualVerbalizer
from .utils import *
from .loggers import *
from .base import *
from utils import init_scaler, init_optimizer, init_scheduler

import re
import os.path as osp
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import json
import numpy as np
from abc import *
from pathlib import Path

import bitsandbytes as bnb
from transformers.trainer import *
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from peft import PeftModel


# def compute_metrics_for_ks(ks, verbalizer):
#     def compute_metrics(eval_pred):
#         logits, labels = eval_pred
#         logits = torch.tensor(logits)
#         labels = torch.tensor(labels).view(-1)
#         scores = verbalizer.process_logits(logits)
#         metrics = absolute_recall_mrr_ndcg_for_ks(scores, labels, ks)
#         return metrics
#     return compute_metrics
def compute_metrics_for_kp(bos_token=1, candidate_ids=None):
    """Build a ``compute_metrics`` callback scoring set-overlap P/R/F1 after the last BOS.

    Args:
        bos_token: token id whose LAST occurrence in each label row marks the
            start of the scored span (default ``1``, the previously hard-coded id).
        candidate_ids: token ids the prediction argmax is restricted to; defaults
            to the previously hard-coded id list.

    Returns:
        ``compute_metrics(pred)``: takes an object exposing ``label_ids`` (rows of
        token ids) and ``predictions`` (per-position scores over the vocabulary,
        last axis indexed by token id) and returns the mean ``f1``, ``precision``
        and ``recall`` over all rows that contain ``bos_token``. Rows without
        ``bos_token`` are skipped; if no row is scored every metric is ``0.0``.
    """
    if candidate_ids is None:
        candidate_ids = [2, 319, 350, 315, 360, 382, 383, 402, 379, 306, 435, 476, 365, 341,
                         405, 438, 349, 660, 390, 317, 323, 501, 478, 399, 1060, 612, 796]
    candidate_ids = list(candidate_ids)

    def _mean(values):
        # np.mean([]) is NaN (with a RuntimeWarning); report 0 when nothing was scored.
        return float(np.mean(values)) if values else 0.0

    def compute_metrics(pred):
        labels = pred.label_ids
        # Argmax over the candidate columns only, then map column index -> token id.
        preds = np.take(candidate_ids, pred.predictions[..., candidate_ids].argmax(-1))

        precision, recall, f1 = [], [], []
        for label, p in zip(labels, preds):
            bos_positions = np.where(label == bos_token)[0]
            if bos_positions.size == 0:
                # No BOS means no scored span; skip instead of raising IndexError.
                continue
            bos_idx = bos_positions[-1]
            # Gold tokens: everything after the last BOS except the final token.
            label = label[bos_idx + 1:-1]
            # Same length as `label`, shifted one position left: p[j] lines up with label[j + 1].
            p = p[bos_idx:-2]

            hits = len(set(p).intersection(set(label)))
            prec = hits / len(p) if len(p) > 0 else 0
            rec = hits / len(label) if len(label) > 0 else 0
            precision.append(prec)
            recall.append(rec)
            f1.append(2 * prec * rec / (prec + rec) if prec + rec > 0 else 0)

        return {
            'f1': _mean(f1),
            'precision': _mean(precision),
            'recall': _mean(recall),
        }

    return compute_metrics

    

class LLMTrainer:
    """Hand-written training loop whose loss and metrics only look at the logits
    of the tokenizer's ``'T'`` and ``'F'`` tokens (``self.valid_label_ids``).

    Every other vocabulary entry is masked to ``-inf`` before the loss or the
    greedy decoding, so the model is trained and evaluated as a T/F predictor at
    the labelled positions.
    """

    def __init__(
            self,
            args,
            model,
            train_loader,
            val_loader,
            test_loader,
            tokenizer,
            config,
            export_root,
            use_wandb,
            log,
            **kwargs
        ):
        """Set up scaler/optimizer/scheduler state and optionally resume.

        Args:
            args: run configuration namespace. Read here: ``llm_max_text_len``,
                ``lora_num_epochs``, ``gradient_accumulation_steps``,
                ``warmup_steps``, ``resume``. ``args.total_steps`` is written
                before ``init_scheduler`` is called.
            model: model to train; ``save_pretrained`` is called on it (or on
                ``model.module`` when ``args.distributed``) when checkpointing.
            train_loader, val_loader: data loaders; their samplers are kept.
            test_loader: currently unused.
            tokenizer: used to look up the token ids of ``'T'`` and ``'F'``.
            config: model config; ``config.vocab_size`` is used in validation.
            export_root: directory checkpoints are written to / resumed from.
            use_wandb: stored only; not used by this class.
            log: logger exposing ``console(msg)``.
        """
        self.args = args
        self.log = log
        self.ignore_index = -100
        # reduction='none' keeps per-position losses so weight_loss can split them by class.
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction='none')

        self.original_args = args
        self.export_root = export_root
        self.use_wandb = use_wandb
        self.llm_max_text_len = args.llm_max_text_len

        self.train_loader = train_loader
        self.train_sampler = train_loader.sampler
        self.val_loader = val_loader
        self.val_sampler = val_loader.sampler
        self.tokenizer = tokenizer
        self.config = config
        # [id of 'T', id of 'F']: the only two logits ever scored.
        self.valid_label_ids = [self.tokenizer('T', add_special_tokens=False)['input_ids'][0],
                                self.tokenizer('F', add_special_tokens=False)['input_ids'][0]]

        self.model = model

        ### Calculate steps ###
        # Optimizer steps over the whole run: one per accumulation window.
        args.total_steps = int(len(self.train_loader) * args.lora_num_epochs // args.gradient_accumulation_steps)
        log.console(f"warmup steps: {args.warmup_steps}, total steps: {args.total_steps}")

        ### scaler / optimizer / scheduler ###
        self.scaler = init_scaler(args)
        self.optimizer = init_optimizer(args, self.model)
        self.scheduler = init_scheduler(args, self.optimizer)

        self.best_valid_loss = float("inf")
        self.best_valid_f1 = float("-inf")
        self.start_epoch = 0
        self.tolerance = 0
        self.global_step = 0

        ### Resume training ###
        ckpt_model_path = osp.join(self.export_root, "best-f1")
        if args.resume and osp.exists(ckpt_model_path):
            log.console(f"Resuming model checkpoint from {ckpt_model_path}...")
            # NOTE(review): only optimizer/scheduler/training state are restored;
            # the adapter weights saved in ckpt_model_path are not reloaded here
            # (PeftModel.from_pretrained was disabled). Confirm the caller does it.
            self.optimizer.load_state_dict(torch.load(osp.join(ckpt_model_path, "optimizer.pt"), map_location='cpu'))
            self.scheduler.load_state_dict(torch.load(osp.join(ckpt_model_path, "scheduler.pt"), map_location='cpu'))
            ckpt = torch.load(osp.join(ckpt_model_path, "training_state.pt"), map_location='cpu')
            self.best_valid_loss = ckpt["loss"]
            self.best_valid_f1 = ckpt["f1"]
            self.start_epoch = ckpt["epoch"]
            self.global_step = ckpt["steps"]

            log.console(f"Validation loss was {ckpt['loss']:.4f}")

    def compute_loss(self, model, inputs, return_outputs=False):
        """Forward ``inputs`` through ``model`` and return the T/F-restricted loss.

        Args:
            model: callable taking ``input_ids``, ``attention_mask`` and
                ``return_dict``; its output exposes ``logits`` as a dict key or
                as its first element.
            inputs: dict with ``input_ids``, ``attention_mask``, ``target_mask``,
                ``labels`` and ``targets_number`` tensors.
            return_outputs: also return the raw model outputs when True.

        Returns:
            ``loss`` or ``(loss, outputs)``.
        """
        torch.cuda.empty_cache()

        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        target_mask = inputs['target_mask']
        labels = inputs['labels']
        targets_number = inputs['targets_number']

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0]

        # Keep only the 'T'/'F' logits. A (vocab,) mask broadcasts over the
        # leading dims instead of allocating a full logits-sized bool tensor.
        vocab_mask = torch.zeros(logits.size(-1), dtype=torch.bool, device=logits.device)
        vocab_mask[self.valid_label_ids] = True
        logits = logits.masked_fill(~vocab_mask, float('-inf'))

        # Positions outside the target span get ignore_index, so only targets contribute.
        labels = torch.where(target_mask == 1, labels, self.ignore_index)
        # Logit i is paired with label i + 1; the last two logits and the last label are dropped.
        shift_logits = logits[..., :-2, :].contiguous()
        shift_labels = labels[..., 1:-1].contiguous()
        loss = self.weight_loss(shift_logits, shift_labels, targets_number)
        return (loss, outputs) if return_outputs else loss

    def weight_loss(self, shift_logits, shift_labels, targets_number):
        """Combine the mean cross-entropy over 'T' positions and over 'F' positions.

        Args:
            shift_logits: (batch, seq, vocab) logits aligned with ``shift_labels``
                (layout required by the ``transpose(1, 2)`` below).
            shift_labels: (batch, seq) token ids; ``ignore_index`` marks
                positions that do not count.
            targets_number: per-sample recall denominator, used only when
                ``args.weight_loss`` is set.

        Returns:
            Scalar loss: per sample ``loss_t + loss_f`` (or, with
            ``args.weight_loss``, ``loss_t * (1 - w) + loss_f * w`` where
            ``w = sigmoid(R / P)``), averaged over the batch.
        """
        t_id, f_id = self.valid_label_ids
        t_mask = shift_labels == t_id
        f_mask = shift_labels == f_id
        loss_all = self.loss_fn(shift_logits.transpose(1, 2), shift_labels)
        # Per-sample mean CE per class; the 1e-20 keeps samples lacking a class at 0.
        loss_t = (loss_all * t_mask.float()).sum(-1) / (t_mask.sum(-1) + 1e-20)
        loss_f = (loss_all * f_mask.float()).sum(-1) / (f_mask.sum(-1) + 1e-20)

        if not self.args.weight_loss:
            return (loss_t + loss_f).mean()

        # Count 'T' predictions only at labelled positions. Ignored positions
        # still have a T/F argmax (logits are masked) and would otherwise
        # inflate the precision denominator, unlike calculate_scores.
        labelled = shift_labels != self.ignore_index
        pred_t = (shift_logits.argmax(-1) == t_id) & labelled
        t_num = (pred_t & t_mask).sum(-1)
        R = t_num / (targets_number + 1e-20)
        P = t_num / (pred_t.sum(-1) + 1e-20)
        weight_f = torch.sigmoid(R / (P + 1e-20))
        weight_t = 1 - weight_f
        loss = (loss_t * weight_t + loss_f * weight_f).mean()
        if torch.isnan(loss):
            print(f"loss is nan\n weight_f: {weight_f}\n")
        return loss

    def train(self):
        """Train from ``start_epoch`` for up to ``args.lora_num_epochs`` epochs,
        stopping once the patience counter reaches ``args.max_tolerance``."""
        for epoch in range(self.start_epoch, self.args.lora_num_epochs):
            if self.args.distributed:
                # Presumably a DistributedSampler: it needs the epoch to reshuffle.
                self.train_sampler.set_epoch(epoch)
            avg_train_loss = self.__epoch_train(epoch)
            self.log.console(f"epoch: {epoch+1}, " +
                        f"steps: {self.global_step}, " +
                        f"current lr: {self.optimizer.param_groups[0]['lr']:.8f}, " +
                        f"train loss: {avg_train_loss:.4f}")
            if self.tolerance == self.args.max_tolerance:
                break

    def __epoch_train(self, epoch):
        """Run one training epoch.

        Every micro-batch is backpropagated; an optimizer step is taken every
        ``args.gradient_accumulation_steps`` batches. Every
        ``args.lora_val_iterations`` optimizer steps the model is validated and
        checkpoints are refreshed. Stops early when ``self.tolerance`` reaches
        ``args.max_tolerance``.

        Args:
            epoch: current epoch index (stored in saved training states).

        Returns:
            Mean training loss over the batches processed in this epoch.
        """
        self.model.train()
        accum_steps = self.args.gradient_accumulation_steps
        train_loss = 0.0
        num_batches = 0
        total = len(self.train_loader)

        # Drop any gradients left from an unfinished accumulation window.
        self.optimizer.zero_grad()
        with tqdm(desc="Training", total=total, ncols=100) as pbar:
            for step, inputs in enumerate(self.train_loader, 1):
                for k, v in inputs.items():
                    inputs[k] = v.cuda(self.args.gpu, non_blocking=True)

                ### Forward pass ###
                with torch.cuda.amp.autocast(enabled=self.args.use_amp):
                    outputs = self.compute_loss(self.model, inputs, return_outputs=True)
                    loss = outputs[0]

                train_loss += loss.item()
                num_batches += 1

                ### Backward pass ###
                # Backpropagate EVERY micro-batch so gradients accumulate over the
                # window; the 1/accum_steps scale turns the sum into an average.
                self.scaler.scale(loss / accum_steps).backward()

                if step % accum_steps == 0:
                    self._optimizer_step()

                    if self.global_step == 1 or self.global_step % self.args.logging_steps == 0:
                        self.log.console(f"\ncurrent lr: {self.optimizer.param_groups[0]['lr']:.8f}, " +
                                    f"steps: {self.global_step}, " +
                                    f"train loss: {(train_loss / num_batches):.4f}")

                    if self.global_step % self.args.lora_val_iterations == 0:
                        self._validate_and_checkpoint(epoch)
                        # Validation put the model in eval mode; switch back.
                        self.model.train()

                if self.tolerance == self.args.max_tolerance:
                    self.log.console(f"Has not increased for {self.tolerance} checkpoints, early stop training.")
                    break

                pbar.update(1)
                del outputs, loss

        return train_loss / max(num_batches, 1)

    def _optimizer_step(self):
        """Apply one optimizer update from the accumulated (scaled) gradients."""
        self.scaler.unscale_(self.optimizer)
        if self.args.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        # ReduceLROnPlateau needs a metric; it is stepped after validation instead.
        if not isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step()
        self.optimizer.zero_grad()
        self.global_step += 1

    def _all_reduce_mean(self, value):
        """Average a Python scalar across all ranks and return it as a Python number."""
        tensor = torch.tensor(value, device=self.args.device)
        torch.distributed.all_reduce(tensor, torch.distributed.ReduceOp.AVG)
        return tensor.item()

    def _validate_and_checkpoint(self, epoch):
        """Validate, log, update the patience counter and save the best checkpoints."""
        avg_valid_loss, valid_score_dict = self.__epoch_valid()

        if self.args.distributed:
            # Average per-rank results so every rank makes the same decisions below.
            avg_valid_loss = self._all_reduce_mean(avg_valid_loss)
            valid_score_dict = {k: self._all_reduce_mean(v) for k, v in valid_score_dict.items()}

        self.log.console(f"steps: {self.global_step}, " +
                    f"valid loss: {(avg_valid_loss):.4f}, " +
                    f"best valid loss: {self.best_valid_loss:.4f}")
        self.log.console(f"P: {valid_score_dict['P']:.5f}, " +
                    f"R: {valid_score_dict['R']:.5f}, " +
                    f"f1: {valid_score_dict['F1']:.5f}, " +
                    f"Acc: {valid_score_dict['Acc']:.5f}, " +
                    f"candidate_T_num: {valid_score_dict['candidate_T_num']:.5f}, " +
                    f"pred_T_num: {valid_score_dict['pred_T_num']:.5f}, " +
                    f"pred_num: {valid_score_dict['pred_num']:.5f}")

        # A plateau scheduler consumes the validation loss. Per-step schedulers
        # must not get it: it would be taken as the (deprecated) `epoch` argument.
        if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(avg_valid_loss)

        if avg_valid_loss < self.best_valid_loss:
            self.tolerance = 0
            self.best_valid_loss = avg_valid_loss
            self._save_checkpoint("best-loss", "lowest valid loss", {
                'epoch': epoch,
                'steps': self.global_step,
                'loss': avg_valid_loss,
            })
        else:
            self.tolerance += 1
            self.log.console(f"Valid loss does not drop, patience: {self.tolerance}/{self.args.max_tolerance}")

        # A new best F1 also resets the patience counter.
        if valid_score_dict['F1'] > self.best_valid_f1:
            self.tolerance = 0
            self.best_valid_f1 = valid_score_dict['F1']
            self._save_checkpoint("best-f1", "best valid F1@M", {
                'epoch': epoch,
                'steps': self.global_step,
                'loss': avg_valid_loss,
                'f1': valid_score_dict['F1'],
            })

    def _save_checkpoint(self, dirname, description, training_state):
        """Save weights, optimizer, scheduler and ``training_state`` to
        ``export_root/dirname``. In distributed runs only ranks with
        ``rank % n_gpu == 0`` write."""
        if self.args.distributed and self.args.rank % self.args.n_gpu != 0:
            return
        save_path = osp.join(self.export_root, dirname)
        self.log.console(f"Saving {description} checkpoint to {save_path}...")
        # In distributed mode the model is wrapped; the inner `.module` is saved.
        model = self.model.module if self.args.distributed else self.model
        model.save_pretrained(save_path)
        torch.save(self.optimizer.state_dict(), osp.join(save_path, "optimizer.pt"))
        torch.save(self.scheduler.state_dict(), osp.join(save_path, "scheduler.pt"))
        torch.save(training_state, osp.join(save_path, "training_state.pt"))

    def pad_fn(self, lst, padding=0):
        """Right-pad tensors along the last dim to a common length and merge them.

        1-D tensors are stacked into ``(len(lst), max_len)``; 2-D tensors are
        concatenated along dim 0. Empty input, or tensors of any other rank, are
        returned unchanged. ``lst`` itself is not modified.
        """
        if len(lst) == 0:
            return lst
        ndim = lst[0].dim()
        if ndim not in (1, 2):
            return lst
        max_len = max(x.shape[-1] for x in lst)
        padded = [F.pad(x, (0, max_len - x.shape[-1]), "constant", padding) for x in lst]
        return torch.stack(padded, dim=0) if ndim == 1 else torch.cat(padded, dim=0)

    @torch.no_grad()
    def __epoch_valid(self):
        """Greedy-decode the validation set; return ``(mean batch loss, score dict)``.

        ``args.max_valid_steps`` (optional; default: no limit) caps the number
        of batches evaluated.
        """
        torch.cuda.empty_cache()
        self.model.eval()
        valid_loss = 0.0
        num_batches = 0
        valid_preds, valid_labels, valid_target_nums = [], [], []
        total = len(self.val_loader)
        max_steps = getattr(self.args, 'max_valid_steps', None)

        # True only for the 'T'/'F' token ids.
        vocab_mask = torch.zeros(self.config.vocab_size, dtype=torch.bool, device=self.args.device)
        vocab_mask[self.valid_label_ids] = True
        with tqdm(desc="Validating", total=total, ncols=100) as pbar:
            for step, inputs in enumerate(self.val_loader, 1):
                for k, v in inputs.items():
                    inputs[k] = v.cuda(self.args.gpu, non_blocking=True)

                # Decode max(candidates_number) positions for the whole batch.
                max_target_number = inputs['candidates_number'].max()
                logits = self.eval_step(inputs, vocab_mask, max_length=max_target_number)
                preds = logits.argmax(dim=-1)

                # Drop the last label position and map EOS to ignore_index.
                labels = inputs['labels'][:, :-1].to('cpu')
                labels = torch.where(labels == self.tokenizer.eos_token_id, self.ignore_index, labels)
                # NOTE(review): candidates_number (not targets_number) is the recall
                # denominator here, unlike compute_loss — confirm this is intended.
                candidates_number = inputs['candidates_number'].to('cpu')
                loss = self.weight_loss(logits, labels, candidates_number)
                target_num = inputs['targets_number'].to('cpu')

                valid_loss += loss.item()
                num_batches += 1
                valid_preds.append(preds)
                valid_labels.append(labels)
                valid_target_nums.append(target_num)

                pbar.update(1)
                del logits, preds, labels, loss
                if max_steps is not None and step >= max_steps:
                    break

        valid_preds = self.pad_fn(valid_preds, padding=-100)
        valid_labels = self.pad_fn(valid_labels, padding=-100)
        # Concatenate rather than pad+stack: padding a smaller final batch would
        # add entries that no longer line up with the concatenated pred/label rows.
        valid_target_nums = torch.cat(valid_target_nums).view(-1)
        score_dict = self.calculate_scores(valid_preds, valid_labels, valid_target_nums)

        return valid_loss / max(num_batches, 1), score_dict

    @torch.no_grad()
    def eval_step(self, inputs, vocab_mask, max_length):
        """Greedily decode ``max_length`` tokens, restricted to ``vocab_mask``.

        Args:
            inputs: dict with the prompt's ``input_ids`` and ``attention_mask``.
            vocab_mask: bool tensor over the vocabulary, True for allowed tokens.
            max_length: number of decoding steps.

        Returns:
            CPU tensor (batch, steps, vocab) of per-step masked logits
            (disallowed tokens are ``-inf``), or None if no step was run.
        """
        torch.cuda.empty_cache()
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        step_logits = []
        past_key_values = None
        cur_len = 0
        while cur_len < max_length:
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask,
                                 past_key_values=past_key_values, use_cache=True,
                                 return_dict=True)
            next_logits = outputs["logits"][:, -1, :]
            valid_logits = next_logits.masked_fill(~vocab_mask, float('-inf'))
            # With the cache returned by the model, only the new token is fed next.
            input_ids = valid_logits.argmax(dim=-1).unsqueeze(-1)
            past_key_values = outputs.past_key_values
            attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
            step_logits.append(valid_logits.unsqueeze(1).to('cpu'))
            cur_len += 1

        return torch.cat(step_logits, dim=1) if step_logits else None

    def calculate_scores(self, preds, labels, targets_number):
        """Per-sample precision/recall/F1/accuracy for 'T' predictions, averaged.

        Args:
            preds: (N, L) predicted token ids (may contain -100 padding).
            labels: (N, L) gold token ids; -100 marks unscored positions.
            targets_number: (N,) recall denominator per sample.

        Returns:
            dict with batch means ``P``, ``R``, ``F1``, ``Acc`` and the totals
            ``candidate_T_num``, ``pred_T_num``, ``pred_num``; None if ``preds``
            or ``labels`` is None.
        """
        if preds is None or labels is None:
            return None

        t_id = self.valid_label_ids[0]
        label_mask = labels != -100
        gold_t = labels == t_id
        pred_num = ((preds == t_id) & label_mask).sum(-1)
        pred_T = ((preds == labels) & label_mask & gold_t).sum(-1)
        precision = pred_T / (pred_num + 1e-20)
        recall = pred_T / (targets_number + 1e-20)
        f1 = 2 * precision * recall / (precision + recall + 1e-20)
        # A row without scored positions counts as 0 accuracy instead of 0/0 = NaN.
        acc = ((preds == labels) & label_mask).sum(-1) / label_mask.sum(-1).clamp(min=1)

        return {
            "P": precision.mean().item(),
            "R": recall.mean().item(),
            "F1": f1.mean().item(),
            'Acc': acc.mean().item(),
            'candidate_T_num': gold_t.sum().item(),
            'pred_T_num': pred_T.sum().item(),
            'pred_num': pred_num.sum().item(),
        }