from utils.IntentDataset import IntentDataset
from utils.Evaluator import EvaluatorBase
from utils.Logger import logger
from utils.commonVar import *
from utils.tools import mask_tokens, makeTrainExamples
import time
import torch
from torch.utils.data import DataLoader
import numpy as np
import copy
from sklearn.metrics import accuracy_score, r2_score
from torch.utils.tensorboard import SummaryWriter
import wandb
import pdb

##
# @brief  base class of trainer
class TrainerBase:
    ##
    # @brief  keep the wandb settings and initialise the state shared by every trainer
    # @param wandb         stored as self.wandb (subclasses test it for truthiness to enable wandb logging)
    # @param wandbProj     stored as self.wandbProjName
    # @param wandbConfig   stored as self.wandbConfig
    # @param wandbRunName  stored as self.runName
    def __init__(self, wandb, wandbProj, wandbConfig, wandbRunName):
        # wandb settings, kept verbatim for subclasses to consume
        self.wandb, self.wandbProjName = wandb, wandbProj
        self.wandbConfig, self.runName = wandbConfig, wandbRunName

        # generic trainer state
        self.roundN = 4                  # number of decimals used by round()
        self.bestModelStateDict = None   # filled in by a concrete train()
        self.finished = False

    ##
    # @brief  round a number to self.roundN decimals
    # @param floatNum  value to round
    # @return the rounded value
    def round(self, floatNum):
        digits = self.roundN
        return round(floatNum, digits)

    ##
    # @brief  training entry point; concrete trainers must override it
    def train(self):
        raise NotImplementedError("train() is not implemented.")

    ##
    # @brief  accessor for the cached best model weights
    # @return self.bestModelStateDict (None until a trainer has stored one)
    def getBestModelStateDict(self):
        bestState = self.bestModelStateDict
        return bestState
##
# @brief TransferTrainer used to do transfer-training. The training is performed in a supervised manner. All available data is used for training. By contrast, meta-training is performed by tasks.
class TransferTrainer(TrainerBase):
    ##
    # @brief  read the training hyper-parameters and keep references to the data, optimizer and evaluators
    # @param trainingParam  dict of settings; keys read here: "wandb", "wandbProj", "wandbConfig",
    #                       "wandbRunName", 'epoch', 'batch', 'validation', 'patience', 'tensorboard',
    #                       'mlm', 'lambda mlm', 'regression', "batchMonitor", 'beforeBatchNorm'
    # @param optimizer      optimizer stepped once per batch in train()
    # @param dataset        labeled data; train() reads it through getTokList() and getLabID()
    # @param unlabeled      unlabeled data; stored only, train() in this class does not read it
    # @param valEvaluator   evaluator called before training and at every monitor window
    # @param testEvaluator  stored only, train() in this class does not read it
    def __init__(self,
            trainingParam:dict,
            optimizer,
            dataset:IntentDataset,
            unlabeled:IntentDataset,
            valEvaluator: EvaluatorBase,
            testEvaluator:EvaluatorBase):
        super(TransferTrainer, self).__init__(trainingParam["wandb"], trainingParam["wandbProj"], trainingParam["wandbConfig"], trainingParam["wandbRunName"])
        self.epoch       = trainingParam['epoch']
        self.batch_size  = trainingParam['batch']
        self.validation  = trainingParam['validation']
        self.patience    = trainingParam['patience']
        self.tensorboard = trainingParam['tensorboard']
        # mlm / lambda_mlm are stored but not read by train() in this class
        self.mlm         = trainingParam['mlm']
        self.lambda_mlm  = trainingParam['lambda mlm']
        self.regression  = trainingParam['regression']

        self.dataset       = dataset
        self.unlabeled     = unlabeled
        self.optimizer     = optimizer
        self.valEvaluator  = valEvaluator
        self.testEvaluator = testEvaluator

        # validation / logging happens on every batch whose index is a multiple of batchMonitor
        self.batchMonitor = trainingParam["batchMonitor"]

        self.beforeBatchNorm = trainingParam['beforeBatchNorm']
        logger.info("In trainer, beforeBatchNorm %s"%(self.beforeBatchNorm))

        if self.tensorboard:
            self.writer = SummaryWriter()

    ##
    # @brief  run the forward pass on one batch and compute its training loss
    # @param model  called as model(X, returnEmbedding=True, beforeBatchNorm=...); must expose .device
    # @param batch  one dataloader item, unpacked as (Y, ids, types, masks)
    # @return (Y, logits, lossTOT, lossCE, lossCov); lossCE and lossCov are None in regression mode
    def _forwardLoss(self, model, batch):
        Y, ids, types, masks = batch
        X = {'input_ids':ids.to(model.device),
                'token_type_ids':types.to(model.device),
                'attention_mask':masks.to(model.device)}

        logits, embeddings = model(X, returnEmbedding=True, beforeBatchNorm=self.beforeBatchNorm)
        target = Y.to(model.device)
        if self.regression:
            return Y, logits, model.loss_mse(logits, target), None, None

        lossTOT, lossCE, lossCov = model.loss_ce_covariance_featureSynthesize(logits, target, embeddings)
        return Y, logits, lossTOT, lossCE, lossCov

    ##
    # @brief  back-propagate the loss, clip the gradient norm to 1.0 and step the optimizer
    def _optimize(self, model, lossTOT):
        self.optimizer.zero_grad()
        lossTOT.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        self.optimizer.step()

    ##
    # @brief  score the predictions of one training batch
    # @return r2_score of sigmoid(logits) in regression mode, otherwise accuracy_score of argmax(logits, 1)
    def _batchScore(self, logits, Y):
        target = Y.cpu()
        # .cpu() is a no-op for tensors already on the CPU, so no device check is needed
        preds = logits.detach().cpu()
        if self.regression:
            return r2_score(target, torch.sigmoid(preds).numpy())
        return accuracy_score(target, np.argmax(preds.numpy(), 1))

    ##
    # @brief  supervised training over the whole labeled dataset with periodic validation and early stopping
    #
    # Every self.batchMonitor batches the model is validated. When the validation accuracy is at least
    # as good as the best one seen so far, a deep copy of the weights is cached in
    # self.bestModelStateDict. Otherwise a counter is increased, and training stops once the counter
    # reaches self.patience. If self.validation is false, the accuracy is forced to -1 so that the
    # latest weights are always cached and early stopping never triggers.
    #
    # @param model      network to train; must provide device, state_dict(), parameters(), train(), eval(),
    #                   loss_mse() and loss_ce_covariance_featureSynthesize()
    # @param tokenizer  forwarded to makeTrainExamples() and to the evaluator
    # @param mode       only 'multi-class' is supported
    # @exception ValueError  if mode is not 'multi-class'
    def train(self, model, tokenizer, mode='multi-class'):
        # fail fast: training examples can only be built for this mode
        if mode != 'multi-class':
            raise ValueError("TransferTrainer.train() only supports mode='multi-class', got %r" % (mode,))

        self.bestModelStateDict = copy.deepcopy(model.state_dict())
        durationOverallTrain = 0.0
        valBestAcc = -1
        accumulateStep = 0

        # evaluate before training
        valAcc, valPre, valRec, valFsc = self.valEvaluator.evaluate(model, tokenizer, mode)
        logger.info('---- Before training ----')
        logger.info("ValAcc %f, Val pre %f, Val rec %f , Val Fsc %f", valAcc, valPre, valRec, valFsc)

        labTensorData = makeTrainExamples(self.dataset.getTokList(), tokenizer, self.dataset.getLabID(), mode=mode)
        dataloader = DataLoader(labTensorData, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)
        batchNum = len(dataloader)

        run = None
        earlyStopFlag = False
        try:
            if self.wandb:
                run = wandb.init(project=self.wandbProjName, reinit=True)
                wandb.config.update(self.wandbConfig)
                wandb.run.name=(self.runName)

            for epoch in range(self.epoch):  # an epoch is one full pass over the dataloader
                timeMonitorWindowStart = time.time()
                for batchID, batch in enumerate(dataloader):
                    # the monitor step below switches to eval mode, so re-enable train mode every batch
                    model.train()
                    Y, logits, lossTOT, lossCE, lossCov = self._forwardLoss(model, batch)
                    self._optimize(model, lossTOT)

                    if (batchID % self.batchMonitor) != 0:
                        continue

                    # self.batchMonitor number batch training done, collect data
                    model.eval()
                    valAcc, valPre, valRec, valFsc = self.valEvaluator.evaluate(model, tokenizer, mode)
                    # scored only here: it needs a device-to-host copy, too costly to do on every batch
                    trainScore = self._batchScore(logits, Y)

                    # statistics
                    monitorWindowDurationTrain = self.round(time.time() - timeMonitorWindowStart)

                    # display current epoch's info
                    logger.info("---- epoch: %d/%d, batch: %d/%d, monitor window time %f ----", epoch, self.epoch, batchID, batchNum, monitorWindowDurationTrain)
                    logger.info("TrainLoss %f", lossTOT.item())
                    logger.info("Train %s %f", "r2" if self.regression else "acc", trainScore)
                    logger.info("valAcc %f, valPre %f, valRec %f , valFsc %f", valAcc, valPre, valRec, valFsc)
                    if self.wandb:
                        metrics = {'trainLoss': lossTOT.item(), 'valAcc': valAcc}
                        # the CE / covariance terms do not exist in regression mode
                        if lossCE is not None:
                            metrics['lossCE'] = lossCE
                            metrics['lossCov'] = lossCov
                        wandb.log(metrics)

                    # time
                    timeMonitorWindowStart = time.time()
                    durationOverallTrain += monitorWindowDurationTrain

                    # early stop
                    if not self.validation:
                        valAcc = -1
                    if (valAcc >= valBestAcc):   # better validation result
                        print("[INFO] Find a better model. Val acc: %f -> %f"%(valBestAcc, valAcc))
                        valBestAcc = valAcc
                        accumulateStep = 0

                        # cache current model, used for evaluation later
                        self.bestModelStateDict = copy.deepcopy(model.state_dict())
                    else:
                        accumulateStep += 1
                        if accumulateStep > self.patience/2:
                            print('[INFO] accumulateStep: ', accumulateStep)
                            # '>=' rather than '==' so a zero or non-integer patience still stops
                            if accumulateStep >= self.patience:  # early stop
                                logger.info('Early stop.')
                                logger.debug("Overall training time %f", durationOverallTrain)
                                logger.debug("best_val_acc: %f", valBestAcc)
                                earlyStopFlag = True
                                break

                if earlyStopFlag:
                    break
        finally:
            # close the wandb run even when training raises
            if run is not None:
                run.finish()

        if not earlyStopFlag:
            logger.debug('All %d epochs are finished', self.epoch)
        logger.debug("Overall training time %f", durationOverallTrain)
        logger.info("best_val_acc: %f", valBestAcc)
