import numpy as np
import torch
from torch.utils.data import BatchSampler, RandomSampler
from base import BaseTrainer
from utils import MetricTracker, get_samples
from model.loss import *
from model.metric import *
import pickle


class StepOneTrainer(BaseTrainer):
    """Trainer for step one: prototype-based training of the encoder on seen classes.

    Each training batch encodes every seen-class sample (the prototypes) together
    with a batch of ``train_seen`` queries and optimizes ``loss_fn`` on them.
    Validation scores ``test_seen`` and ``test_unseen`` queries against the
    prototypes of both seen and unseen classes.
    """

    def __init__(self, model, optimizer, config, dataset, loss_fn, lr_scheduler=None):
        """
        Args:
            model: Model called as ``model(x, lengths, 'encode')``; exposes ``tao_cos``.
            optimizer: Optimizer stepped once per training batch.
            config: Parsed config; reads ``trainer.s1_train_batch_size``,
                ``trainer.valid_batch_size``, ``arch_step1.args.encoder_type`` and ``loss``.
            dataset: Dataset exposing ``train_seen``, ``test_seen``, ``test_unseen``,
                ``seen_class`` and ``unseen_class`` sample sequences.
            loss_fn: Callable returning ``(loss, output)`` for ``(protos, querys, y[, tao])``.
            lr_scheduler: Optional scheduler, stepped once per epoch.
        """
        super().__init__(model, optimizer, config, dataset)
        self.loss_fn = loss_fn
        # RandomSampler only uses the length of its data source, so an index range
        # is enough; there is no need to materialize every training label.
        self.data_sampler = BatchSampler(RandomSampler(range(len(self.dataset.train_seen))),
                                         batch_size=config['trainer']['s1_train_batch_size'],
                                         drop_last=False)
        self.len_epoch = len(self.data_sampler)

        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(len(self.data_sampler)))

        self.train_metrics = MetricTracker('loss',
                                           'seen_accuracy',
                                           'seen_precision',
                                           'seen_recall',
                                           'seen_macro_f1',
                                           writer=self.writer)
        self.valid_batch_size = config['trainer']['valid_batch_size']

    def _train_epoch(self, epoch):
        """Training logic for an epoch

        Args:
            epoch: Integer, current training epoch.

        Returns:
            A log that contains average loss and metric in this epoch,
            plus the validation results prefixed with ``val_``.
        """
        self.model.train()
        self.train_metrics.reset()

        encoder_type = self.config['arch_step1']['args']['encoder_type']
        # The prototype inputs do not depend on the batch, so they are built once
        # per epoch (assumes get_samples is deterministic). They are still encoded
        # every batch because the encoder weights change after each optimizer step.
        class_protos_x, class_protos_len, _ = get_samples(self.dataset.seen_class, self.device, encoder_type)

        for batch_idx, idxs in enumerate(self.data_sampler):
            query_samples = [self.dataset.train_seen[i] for i in idxs]
            querys_x, querys_len, querys_y = get_samples(query_samples, self.device, encoder_type)

            self.optimizer.zero_grad()

            protos = self.model(class_protos_x, class_protos_len, 'encode')
            querys = self.model(querys_x, querys_len, 'encode')
            loss, output = self.loss_fn(protos, querys, querys_y, self.model.tao_cos)

            loss.backward()
            self.optimizer.step()

            # train metrics
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            y = querys_y.cpu()
            y_preds = output.argmax(dim=1).cpu()
            self.train_metrics.update('seen_accuracy', accuracy_fn(y, y_preds))
            self.train_metrics.update('seen_precision', precision_fn(y, y_preds))
            self.train_metrics.update('seen_recall', recall_fn(y, y_preds))
            self.train_metrics.update('seen_macro_f1', macro_f1_fn(y, y_preds))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} Batch id: {} Loss: {:.6f}'.format(
                    epoch,
                    batch_idx,
                    loss.item()))
        log = self.train_metrics.result()

        # validation
        self.logger.info('=' * 5 + ' valid')
        val_log = self._valid_epoch(epoch)
        log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        if self.config['loss'] == 'CosT':
            self.logger.info('tao_cos : {}'.format(self.model.tao_cos))

        return log

    @torch.no_grad()
    def _predict_split(self, split, protos, epoch, label_offset=0):
        """Score every query of one evaluation split against ``protos``.

        Args:
            split: Indexable sequence of query samples (e.g. ``self.dataset.test_seen``).
            protos: Encoded class prototypes.
            epoch: Integer, current training epoch (used for the writer step).
            label_offset: Integer added to every label of the split before scoring.

        Returns:
            Tuple ``(labels, outputs)`` of lists holding one CPU tensor per batch.
        """
        encoder_type = self.config['arch_step1']['args']['encoder_type']
        sampler = BatchSampler(RandomSampler(range(len(split))),
                               batch_size=self.valid_batch_size,
                               drop_last=False)
        labels, outputs = [], []
        for batch_idx, idxs in enumerate(sampler):
            querys_x, querys_len, querys_y = get_samples([split[i] for i in idxs], self.device, encoder_type)
            if label_offset:
                querys_y = querys_y + label_offset

            querys = self.model(querys_x, querys_len, 'encode')
            # NOTE(review): training passes self.model.tao_cos as a 4th argument to
            # loss_fn; here it is omitted -- confirm the intended temperature.
            _, output = self.loss_fn(protos, querys, querys_y)

            self.writer.set_step((epoch - 1) * len(sampler) + batch_idx, 'valid')
            labels.append(querys_y.cpu())
            outputs.append(output.cpu())
        return labels, outputs

    @torch.no_grad()
    def _valid_epoch(self, epoch):
        """Validate after training an epoch

        Queries from ``test_seen`` and ``test_unseen`` are classified against the
        prototypes of all seen + unseen classes, and ``all_metric`` is computed on
        the concatenated outputs (seen queries first). When ``total_acc_hm`` is at
        least ``self.mnt_best`` (and monitoring is on), the prototypes are pickled
        to ``./data/<data_name>/protos_<encoder_type>.pkl``.

        Args:
            epoch: Integer, current training epoch.

        Returns:
            A log that contains information about validation
        """
        self.model.eval()
        encoder_type = self.config['arch_step1']['args']['encoder_type']

        # Shift mapping the label of unseen_class[0] to n_seen_class, i.e. right
        # after the seen labels, matching the seen-then-unseen prototype order
        # below (assumes unseen labels are contiguous from unseen_class[0]['y'] -- confirm).
        start = self.dataset.n_seen_class - int(self.dataset.unseen_class[0]['y'])

        # prototypes of seen classes followed by unseen classes
        class_samples = self.dataset.seen_class + self.dataset.unseen_class
        class_protos_x, class_protos_len, _ = get_samples(class_samples, self.device, encoder_type)
        protos = self.model(class_protos_x, class_protos_len, 'encode')

        seen_ys, seen_ps = self._predict_split(self.dataset.test_seen, protos, epoch)
        unseen_ys, unseen_ps = self._predict_split(self.dataset.test_unseen, protos, epoch,
                                                   label_offset=start)

        # Concatenate once instead of re-copying the growing tensors every batch;
        # the empty seed tensors keep the dtypes and the empty-split result unchanged.
        ys = torch.cat([torch.LongTensor([])] + seen_ys + unseen_ys, 0)
        ps = torch.cat([torch.tensor([])] + seen_ps + unseen_ps, 0)

        res, _ = all_metric(ys,
                            ps,
                            len(self.dataset.test_seen),
                            self.dataset.n_seen_class,
                            self.dataset.n_unseen_class)
        print('\t'.join('{:.2f}'.format(round(value * 100, 2)) for value in res.values()))

        # self.mnt_best comes from BaseTrainer (presumably the best monitored value
        # of previous epochs -- confirm); a tie also refreshes the saved prototypes.
        if self.mnt_mode != 'off' and res['total_acc_hm'] >= self.mnt_best:
            path = './data/{}/protos_{}.pkl'.format(self.dataset.cropus.data_name, encoder_type)
            with open(path, 'wb') as f:
                pickle.dump(protos.detach().cpu(), f)
                print('Save proto pickle file at {} epoch'.format(epoch))
        return res


class StepTwoTrainer(BaseTrainer):
    """Trainer for step two: episodic training of the model's 'transfer' and 'adapt' stages.

    Each item of ``data_sampler`` is ``(novel_class_idxs, novel_query, memory_query)``:
    positions of the seen classes treated as novel in this episode, plus indices
    into ``dataset.train_seen`` for the novel and memory queries.
    """

    def __init__(self, model, optimizer, config, dataset, loss_fn, data_sampler, lr_scheduler=None):
        """
        Args:
            model: Model with 'encode', 'transfer' and 'adapt' call modes; exposes
                ``tao_cos`` and ``tao_attn``.
            optimizer: Optimizer stepped once per training batch.
            config: Parsed config; reads ``trainer.valid_batch_size``,
                ``arch_step2.args.encoder_type`` and ``loss``.
            dataset: Dataset exposing ``train_seen``, ``test_seen``, ``test_unseen``,
                ``seen_class`` and ``unseen_class`` sample sequences.
            loss_fn: Callable returning ``(loss, output)`` for ``(protos, querys, y[, tao])``.
            data_sampler: Iterable of episodes ``(novel_class_idxs, novel_query, memory_query)``.
            lr_scheduler: Optional scheduler, stepped once per epoch.
        """
        super().__init__(model, optimizer, config, dataset)
        self.loss_fn = loss_fn
        self.data_sampler = data_sampler
        self.len_epoch = len(self.data_sampler)

        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(len(data_sampler)))

        self.train_metrics = MetricTracker('loss',
                                           'seen_accuracy',
                                           'seen_precision',
                                           'seen_recall',
                                           'seen_macro_f1',
                                           'unseen_accuracy',
                                           'unseen_precision',
                                           'unseen_recall',
                                           'unseen_macro_f1',
                                           'total_seen_recall',
                                           'total_unseen_recall',
                                           'total_acc_hm',
                                           'total_f1_hm',
                                           writer=self.writer)
        self.valid_batch_size = config['trainer']['valid_batch_size']

    def _train_epoch(self, epoch):
        """Training logic for an epoch

        Args:
            epoch: Integer, current training epoch.

        Returns:
            A log that contains average loss and metric in this epoch,
            plus the validation results prefixed with ``val_``.
        """
        self.model.train()
        self.train_metrics.reset()

        encoder_type = self.config['arch_step2']['args']['encoder_type']
        n_seen_class = self.dataset.n_seen_class
        # The prototype inputs do not depend on the episode, so they are built once
        # per epoch (assumes get_samples is deterministic). They are still encoded
        # every batch because the weights change after each optimizer step.
        class_protos_x, class_protos_len, _ = get_samples(self.dataset.seen_class, self.device, encoder_type)

        for batch_idx, (novel_class_idxs, novel_query, memory_query) in enumerate(self.data_sampler):

            self.optimizer.zero_grad()
            seen_protos = self.model(class_protos_x, class_protos_len, 'encode')

            # as_tensor avoids a copy (and a warning) if the sampler already yields a tensor
            novel_class_idxs = torch.as_tensor(novel_class_idxs)
            # Seen classes not picked as novel in this episode act as memory classes.
            is_novel = (torch.arange(n_seen_class)[..., None] == novel_class_idxs).any(-1)
            memory_class_idxs = (~is_novel).nonzero(as_tuple=False).squeeze(-1)
            novel_protos = seen_protos[novel_class_idxs]

            protos, memory_protos, novel_protos, after_memory_protos, after_novel_protos, v, loss_r = \
                self.model(novel_protos, memory_class_idxs, novel_class_idxs, 'transfer')
            adapt_args = (memory_protos, novel_protos, after_memory_protos, after_novel_protos, v)

            novel_query_samples = [self.dataset.train_seen[i] for i in novel_query]
            novel_querys_x, novel_querys_len, novel_querys_y = get_samples(novel_query_samples, self.device, encoder_type)
            memory_query_samples = [self.dataset.train_seen[i] for i in memory_query]
            memory_querys_x, memory_querys_len, memory_querys_y = get_samples(memory_query_samples, self.device, encoder_type)

            # re_index[c] is the position of seen class c in [memory classes..., novel classes...]
            # (the inverse permutation), so labels follow the memory-then-novel order; this is
            # assumed to match the prototype order returned by 'transfer' -- confirm.
            re_index = torch.sort(torch.cat([memory_class_idxs, novel_class_idxs], 0).to(self.device))[1]
            novel_querys_y = re_index[novel_querys_y]
            memory_querys_y = re_index[memory_querys_y]

            novel_querys = self.model(novel_querys_x, novel_querys_len, *adapt_args, 'adapt')
            memory_querys = self.model(memory_querys_x, memory_querys_len, *adapt_args, 'adapt')

            loss_n, output_n = self.loss_fn(protos, novel_querys, novel_querys_y, self.model.tao_cos)
            loss_m, output_m = self.loss_fn(protos, memory_querys, memory_querys_y, self.model.tao_cos)
            loss = loss_n + loss_m + loss_r

            loss.backward()
            self.optimizer.step()

            # train metrics
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())

            # After re-indexing, predictions below this bound are memory classes.
            n_memory_class = n_seen_class - len(novel_class_idxs)

            y = memory_querys_y.cpu()
            y_pred = output_m.argmax(dim=1).cpu()
            self.train_metrics.update('seen_accuracy', accuracy_fn(y, y_pred))
            self.train_metrics.update('seen_precision', precision_fn(y, y_pred))
            self.train_metrics.update('seen_recall', recall_fn(y, y_pred))
            self.train_metrics.update('seen_macro_f1', macro_f1_fn(y, y_pred))
            self.train_metrics.update('total_seen_recall', recall_fn(torch.ones_like(y),
                                                                     y_pred < n_memory_class,
                                                                     average='binary'))

            y = novel_querys_y.cpu()
            y_pred = output_n.argmax(dim=1).cpu()
            self.train_metrics.update('unseen_accuracy', accuracy_fn(y, y_pred))
            self.train_metrics.update('unseen_precision', precision_fn(y, y_pred))
            self.train_metrics.update('unseen_recall', recall_fn(y, y_pred))
            self.train_metrics.update('unseen_macro_f1', macro_f1_fn(y, y_pred))
            self.train_metrics.update('total_unseen_recall', recall_fn(torch.ones_like(y),
                                                                       y_pred >= n_memory_class,
                                                                       average='binary'))
            # harmonic means of the running averages so far in this epoch
            self.train_metrics.update('total_acc_hm', HM_fn(self.train_metrics.avg('seen_accuracy'),
                                                            self.train_metrics.avg('unseen_accuracy')))
            self.train_metrics.update('total_f1_hm', HM_fn(self.train_metrics.avg('seen_macro_f1'),
                                                           self.train_metrics.avg('unseen_macro_f1')))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} Batch id: {} Loss: {:.6f}'.format(
                    epoch,
                    batch_idx,
                    loss.item()))

        log = self.train_metrics.result()

        # validation
        self.logger.info('=' * 5 + ' valid')
        val_log = self._valid_epoch(epoch)
        log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        if self.config['loss'] == 'CosT':
            self.logger.info('tao_cos : {}'.format(self.model.tao_cos))
        self.logger.info('tao_attn: {}'.format(self.model.tao_attn))

        return log

    @torch.no_grad()
    def _predict_split(self, split, protos, adapt_args, label_offset=0):
        """Adapt and score every query of one evaluation split against ``protos``.

        Args:
            split: Indexable sequence of query samples (e.g. ``self.dataset.test_seen``).
            protos: Prototypes returned by the model's 'transfer' step.
            adapt_args: Tuple ``(memory_protos, novel_protos, after_memory_protos,
                after_novel_protos, v)`` forwarded to the model's 'adapt' step.
            label_offset: Integer added to every label of the split before scoring.

        Returns:
            Tuple ``(labels, outputs)`` of lists holding one CPU tensor per batch.
        """
        encoder_type = self.config['arch_step2']['args']['encoder_type']
        sampler = BatchSampler(RandomSampler(range(len(split))),
                               batch_size=self.valid_batch_size,
                               drop_last=False)
        labels, outputs = [], []
        for idxs in sampler:
            querys_x, querys_len, querys_y = get_samples([split[i] for i in idxs], self.device, encoder_type)
            if label_offset:
                querys_y = querys_y + label_offset

            querys = self.model(querys_x, querys_len, *adapt_args, 'adapt')
            # NOTE(review): training passes self.model.tao_cos as a 4th argument to
            # loss_fn; here it is omitted -- confirm the intended temperature.
            _, output = self.loss_fn(protos, querys, querys_y)

            labels.append(querys_y.cpu())
            outputs.append(output.cpu())
        return labels, outputs

    @torch.no_grad()
    def _valid_epoch(self, epoch):
        """Validate after training an epoch

        Unseen-class prototypes are encoded and passed through 'transfer'; queries
        from ``test_seen`` and ``test_unseen`` are then adapted, scored, and
        ``all_metric`` is computed on the concatenated outputs (seen queries first).

        Args:
            epoch: Integer, current training epoch.

        Returns:
            A log that contains information about validation
        """
        self.model.eval()
        encoder_type = self.config['arch_step2']['args']['encoder_type']

        # Shift mapping the label of unseen_class[0] to n_seen_class, i.e. right
        # after the seen labels (assumes unseen labels are contiguous from
        # unseen_class[0]['y'] -- confirm).
        start = self.dataset.n_seen_class - int(self.dataset.unseen_class[0]['y'])

        # Only the labels of the seen (memory) classes are used.
        _, _, memory_protos_y = get_samples(self.dataset.seen_class, self.device, encoder_type)
        novel_protos_x, novel_protos_len, novel_protos_y = get_samples(self.dataset.unseen_class, self.device, encoder_type)

        novel_protos = self.model(novel_protos_x, novel_protos_len, 'encode')

        # NOTE(review): in training, 'transfer' receives positions among the seen
        # classes; here novel_protos_y holds the labels from get_samples, not shifted
        # by `start` like the query labels below -- confirm this is what it expects.
        protos, memory_protos, novel_protos, after_memory_protos, after_novel_protos, v, _ = \
            self.model(novel_protos, memory_protos_y, novel_protos_y, 'transfer')
        adapt_args = (memory_protos, novel_protos, after_memory_protos, after_novel_protos, v)

        seen_ys, seen_ps = self._predict_split(self.dataset.test_seen, protos, adapt_args)
        unseen_ys, unseen_ps = self._predict_split(self.dataset.test_unseen, protos, adapt_args,
                                                   label_offset=start)

        # Concatenate once instead of re-copying the growing tensors every batch;
        # the empty seed tensors keep the dtypes and the empty-split result unchanged.
        ys = torch.cat([torch.LongTensor([])] + seen_ys + unseen_ys, 0)
        ps = torch.cat([torch.tensor([])] + seen_ps + unseen_ps, 0)

        res, _ = all_metric(ys,
                            ps,
                            len(self.dataset.test_seen),
                            self.dataset.n_seen_class,
                            self.dataset.n_unseen_class)
        print('\t'.join('{:.2f}'.format(round(value * 100, 2)) for value in res.values()))
        return res