#%%
import os
import math
import json
import torch
import datetime
import numpy as np
import torch.nn as nn
from time import time
from tqdm import tqdm, trange
import torchtext.data as torchdata
from torch.nn.functional import softmax
from torch.utils.tensorboard import SummaryWriter

from utils import Logger
from utils import log_running_time, ensure_dir, get_desc_from_result, get_log_from_result
from data import metric
#%%
# Wall-clock timestamp (seconds since the epoch, via time.time()) recorded when
# this module is executed. Nothing visible in this file reads it.
# NOTE(review): confirm whether another module imports `start` before removing it.
start = time()
class Trainer():
    """Drive training, dev/test evaluation, prediction and checkpointing.

    All settings come from ``config`` (keys ``task``, ``exp_name``,
    ``data_loader``, ``trainer``, ``logger`` and ``embedder``). The trainer
    writes a TSV log, TensorBoard scalars, a copy of the config, prediction
    files and model checkpoints under directories derived from
    ``config['logger']``.
    """

    def __init__(self, config, data, model, optimizer, lr_scheduler, resume):
        """Set up device, model, optimizer, logging and output directories.

        Args:
            config: nested experiment configuration dict; it is also dumped to
                ``<checkpoint_dir>/config.json``.
            data: data wrapper exposing ``train_iter``/``dev_iter``/``test_iter``,
                ``dev``/``test``, ``TEXT``, ``get_label_map``, ``characterize``
                and ``exact_match``.
            model: ``nn.Module`` whose forward accepts the dict built by
                :meth:`build_data` and returns a sequence whose first item is
                used as the logits.
            optimizer: optimizer already bound to ``model``'s parameters.
            lr_scheduler: scheduler stepped every ``lr_step`` optimizer steps.
            resume: optional checkpoint path; when truthy, model weights are
                loaded from it.
        """
        super(Trainer, self).__init__()
        self.config = config
        # ------ Config ------
        self.task = config['task']
        self.exp_name = config['exp_name']

        # ------ Setup Device ------
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.n_gpu = torch.cuda.device_count()
        # ------ Load Data and Model ------
        self.data = data
        self.train_size = len(data.train_iter)
        self.output_mode = config['data_loader']['mode']
        self.label_map = data.get_label_map(self.task)
        self.model = model.to(self.device)
        if self.n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)
        self.loss = nn.CrossEntropyLoss()
        # ------ Set Optimizer and Scheduler ------
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        # ------ Set Trainer ------
        trainer_config = config['trainer']
        self.trainer_config = trainer_config
        self.evaluate_dev = trainer_config['evaluate_dev']
        if trainer_config['lr_step'] == -1:
            # Default: roughly ten scheduler steps per epoch. Clamp to 1 so a
            # training set with fewer than 10 batches does not produce a step
            # of 0, which would make `global_step % lr_step` divide by zero.
            self.lr_step = max(1, self.train_size // 10)
        else:
            self.lr_step = trainer_config['lr_step']
        self.max_grad_norm = trainer_config["max_grad_norm"]
        self.gradient_accumulation_steps = trainer_config["gradient_accumulation_steps"]
        self.logging_steps = trainer_config['logging_steps']
        self.validate_steps = trainer_config['validate_steps']
        self.validate_thelta = trainer_config['validate_thelta']
        self.save_dev = trainer_config['save_dev']
        # NOTE: save_steps is stored but periodic (non-best) checkpointing is
        # not performed anywhere in this class.
        self.save_steps = trainer_config['save_steps']
        self.save_thelta = trainer_config['save_thelta']
        self.output_result = trainer_config['output_result']

        # ------ Set Logger ------
        logger_config = config['logger']
        self.log_dir = os.path.join(logger_config['log_dir'], self.task, logger_config['exp_dir'])
        ensure_dir(self.log_dir)
        self.log_path = os.path.join(self.log_dir, self.exp_name + '.tsv')
        self.logger = Logger(filename=self.log_path).logger
        self.tb_path = os.path.join(logger_config['tb_dir'], self.task, logger_config['exp_dir'], self.exp_name)
        self.tb_writer = SummaryWriter(self.tb_path)

        # ------ Set Checkpoint ------
        # Per-run subdirectory named by the start time (month-day_hour-minute-second).
        start_time = datetime.datetime.now().strftime('%m%d_%H%M%S')
        self.checkpoint_dir = os.path.join(
            logger_config['save_dir'], self.task, logger_config['exp_dir'], self.exp_name, start_time)
        ensure_dir(self.checkpoint_dir)
        config_save_path = os.path.join(self.checkpoint_dir, 'config.json')
        with open(config_save_path, 'w') as handle:
            json.dump(config, handle, indent=4, sort_keys=False)

        # ------ Set ResultOutputs ------
        self.result_dir = os.path.join(logger_config['result_dir'], self.task, logger_config['exp_dir'])
        ensure_dir(self.result_dir)
        self.result_path = os.path.join(self.result_dir, self.exp_name + '.tsv')
        # ------ Load Resume ------
        if resume:
            self._resume_checkpoint(resume)

    def build_data(self, batch):
        """Turn a batch into the keyword arguments passed to ``self.model``.

        ``batch.s`` and ``batch.t`` are ``(token_ids, lengths)`` pairs. When
        ``config['embedder']['max_len']`` is non-negative, token ids longer
        than it along dim 1 are truncated and their lengths clamped to it.
        Character ids and exact-match features are built through
        ``self.data`` only when enabled in ``config['embedder']['args']``.

        Returns:
            dict with keys ``input_s``, ``input_t``, ``len_s``, ``len_t``,
            ``char_s``, ``char_t``, ``em_s``, ``em_t`` (all on
            ``self.device``); the optional features are ``None`` when disabled.
        """
        embedder_config = self.config['embedder']
        max_len = embedder_config['max_len']
        use_char = embedder_config['args']['use_char']
        use_em = embedder_config['args']['use_em']
        input_s, len_s = batch.s
        input_t, len_t = batch.t
        # NOTE(review): truncation slices dim 1, i.e. it assumes token ids are
        # laid out (batch_size, seq_len) — confirm the Field uses batch_first.
        if max_len >= 0:
            if input_s.size(1) > max_len:
                input_s = input_s[:, :max_len]
                # clamp() returns a new tensor, so batch.s is left untouched.
                len_s = len_s.clamp(max=max_len)
            if input_t.size(1) > max_len:
                input_t = input_t[:, :max_len]
                len_t = len_t.clamp(max=max_len)
        if use_char:
            char_s = torch.LongTensor(self.data.characterize(input_s)).to(self.device)
            char_t = torch.LongTensor(self.data.characterize(input_t)).to(self.device)
        else:
            char_s, char_t = None, None
        if use_em:
            em_s = torch.FloatTensor(self.data.exact_match(input_s, input_t)).to(self.device)
            em_t = torch.FloatTensor(self.data.exact_match(input_t, input_s)).to(self.device)
        else:
            em_s, em_t = None, None
        return {
            'input_s': input_s.to(self.device),
            'input_t': input_t.to(self.device),
            'len_s': len_s.to(self.device),
            'len_t': len_t.to(self.device),
            'char_s': char_s,
            'char_t': char_t,
            'em_s': em_s,
            'em_t': em_t,
        }

    def train(self):
        """Run training for ``trainer_config['num_train_epochs']`` epochs.

        Every ``gradient_accumulation_steps`` batches, the accumulated
        gradients are clipped to ``max_grad_norm`` and applied (one
        "optimizer step"). Per optimizer step:
          * the LR scheduler steps every ``lr_step`` optimizer steps
            (never at step 0);
          * LR and mean loss go to TensorBoard every ``logging_steps`` steps
            (disabled when ``logging_steps <= 0``);
          * once the best dev ``main`` metric reaches ``validate_thelta``,
            ``validate_steps`` is halved (min 1) and the threshold rises by 0.1;
          * when ``evaluate_dev`` is set, the dev set is evaluated every
            ``validate_steps`` steps (disabled when ``validate_steps <= 0``).
        Prints the best dev ``main`` metric at the end.
        """
        num_train_epochs = self.trainer_config["num_train_epochs"]
        global_step = 0
        tr_loss, logging_loss = 0.0, 0.0
        best_dev_metric = {'main': 0}
        print("***** Running training *****")
        print(self.config)
        for epoch in range(int(num_train_epochs)):
            for step, batch in enumerate(self.data.train_iter):
                self.model.train()
                # ----- Build Inputs ------
                inputs = self.build_data(batch)

                # ------ Forward / Backward ------
                outputs = self.model(**inputs)
                loss = self.loss(outputs[0], batch.label.to(self.device))
                if self.gradient_accumulation_steps > 1:
                    loss = loss / self.gradient_accumulation_steps
                loss.backward()
                tr_loss += loss.item()

                if (step + 1) % self.gradient_accumulation_steps != 0:
                    continue

                # ------ Optimizer Step ------
                # Clip once on the fully accumulated gradient; clipping after
                # each micro-batch would rescale partial sums.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()
                # ------ LR Scheduler Process ------
                if global_step != 0 and global_step % self.lr_step == 0:
                    self.lr_scheduler.step()
                self.optimizer.zero_grad()

                # ------ Log Process ------
                if self.logging_steps > 0 and global_step % self.logging_steps == 0:
                    current_lr = self.optimizer.state_dict()['param_groups'][0]['lr']
                    self.tb_writer.add_scalar('lr', current_lr, global_step)
                    self.tb_writer.add_scalar('loss', (tr_loss - logging_loss) / self.logging_steps, global_step)
                    logging_loss = tr_loss
                # ------ Update validation steps ------
                if best_dev_metric['main'] >= self.validate_thelta:
                    if self.validate_steps > 0:
                        # Keep >= 1 so the modulo below never divides by zero.
                        self.validate_steps = max(1, self.validate_steps // 2)
                    self.validate_thelta += 0.1
                # ------ Validation Process ------
                if (self.evaluate_dev and self.validate_steps > 0
                        and global_step % self.validate_steps == 0):
                    self._validate(epoch, global_step, best_dev_metric)
                # Advance unconditionally so scheduling, logging and validation
                # cadence work even when dev evaluation is disabled.
                global_step += 1
        print("Best Dev Metric: {}".format(best_dev_metric['main']))

    def _validate(self, epoch, global_step, best_dev_metric):
        """Evaluate on dev, update ``best_dev_metric`` in place, checkpoint and log."""
        dev_result = self.evaluate(test=False, output_result=self.output_result)
        # ------ Save Best CheckPoints ------
        if dev_result['main'] >= best_dev_metric.get('main', 0):
            for key in best_dev_metric:
                best_dev_metric[key] = dev_result[key]
            if self.save_dev and best_dev_metric['main'] >= self.save_thelta:
                self._save_checkpoint(global_step)
        # ------ Log Evaluate Result ------
        for key, value in dev_result.items():
            self.tb_writer.add_scalar('dev_{}'.format(key), value, global_step)
        epoch_log = '{}\t{}'.format(epoch, global_step)
        self.logger.info(epoch_log + get_log_from_result(dev_result))

    def evaluate(self, test=False, output_result=True):
        """Evaluate the model on the dev (default) or test iterator.

        Args:
            test: use ``self.data.test_iter`` when True, else ``self.data.dev_iter``.
            output_result: when True, write an ``index\\tprediction`` row per
                prediction to ``self.result_path + 'test'`` (or ``+ 'dev'``).

        Returns:
            The dict from ``metric(self.task, preds, labels)`` with an added
            ``eval_loss`` entry (mean of the per-batch losses).

        Raises:
            ValueError: if the evaluation iterator yields no batches.
        """
        eval_dataloader = self.data.test_iter if test else self.data.dev_iter
        eval_loss = 0.0
        nb_eval_steps = 0
        logit_chunks, label_chunks = [], []
        self.model.eval()
        with torch.no_grad():
            for batch in eval_dataloader:
                inputs = self.build_data(batch)
                logits = self.model(**inputs)[0]
                tmp_eval_loss = self.loss(logits, batch.label.to(self.device))
                eval_loss += tmp_eval_loss.mean().item()
                nb_eval_steps += 1
                logit_chunks.append(logits.detach().cpu().numpy())
                label_chunks.append(batch.label.detach().cpu().numpy())
        if nb_eval_steps == 0:
            raise ValueError("evaluation iterator yielded no batches")
        eval_loss = eval_loss / nb_eval_steps
        # Concatenate once rather than np.append per batch (which copies every time).
        preds = np.concatenate(logit_chunks, axis=0)
        out_label_ids = np.concatenate(label_chunks, axis=0)
        if self.output_mode == "classification":
            preds = np.argmax(preds, axis=1)
        elif self.output_mode == "regression":
            preds = np.squeeze(preds)
        else:
            # Any other mode: softmax over the last dim, keep the column for class index 1.
            preds = softmax(torch.tensor(preds), -1)
            preds = preds.numpy()[:, 1]

        eval_results = metric(self.task, preds, out_label_ids)
        if output_result:
            suffix = "test" if test else 'dev'
            with open(self.result_path + suffix, 'w') as out:
                out.write("index" + "\t" + "prediction" + "\n")
                for index, pred in enumerate(preds):
                    out.write('\t'.join([str(index), str(pred)]) + '\n')
        eval_results["eval_loss"] = eval_loss
        return eval_results

    def predict(self, s, t):
        """Run the model on one ``(s, t)`` pair and return its raw outputs.

        Both inputs are processed by ``self.data.TEXT``. NOTE(review):
        presumably raw strings tokenized by that Field — confirm with callers.
        Returns ``None`` if the iterator yields no batch.
        """
        fields = [('s', self.data.TEXT), ('t', self.data.TEXT)]
        example = torchdata.Example.fromlist([s, t], fields)
        test_data = torchdata.Dataset([example], fields)
        test_iter = torchdata.BucketIterator(
            test_data,
            batch_size=1,
            shuffle=False,
            sort=False)
        test_iter.device = self.device
        self.model.eval()
        outputs = None
        with torch.no_grad():
            for batch in test_iter:
                outputs = self.model(**self.build_data(batch))
        return outputs

    def _unwrap_model(self):
        """Return the wrapped module when ``self.model`` is a DataParallel, else the model."""
        if isinstance(self.model, torch.nn.DataParallel):
            return self.model.module
        return self.model

    def _save_checkpoint(self, global_steps, save_best=True):
        """Save weights, step count and config under ``self.checkpoint_dir``.

        Writes ``model_best_dev.pth`` when ``save_best`` is true (default),
        otherwise ``model_{global_steps}.pth``. The unwrapped module is saved,
        so state-dict keys carry no ``module.`` prefix regardless of GPU count.
        """
        model = self._unwrap_model()
        state = {
            'arch': type(model).__name__,
            'global_steps': global_steps,
            'state_dict': model.state_dict(),
            'config': self.config,
        }
        if save_best:
            filename = 'model_best_dev.pth'
        else:
            filename = 'model_{}.pth'.format(global_steps)
        torch.save(state, os.path.join(self.checkpoint_dir, filename))

    def _resume_checkpoint(self, resume_path):
        """Load model weights from ``resume_path`` onto ``self.device``.

        Accepts checkpoints whose keys were saved with a DataParallel
        ``module.`` prefix as well as unprefixed ones.
        """
        checkpoint = torch.load(resume_path, map_location=self.device)
        state_dict = checkpoint['state_dict']
        prefix = 'module.'
        if state_dict and all(key.startswith(prefix) for key in state_dict):
            state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}
        self._unwrap_model().load_state_dict(state_dict)

# %%
