# -*- coding: utf-8 -*-

import itertools
import os
from abc import ABCMeta, abstractmethod
from typing import Any, List, Union

import torch
import torch.nn as nn

from ..common.dataclass_options import OptionsBase, argfield
from ..common.logger import LOGGER
from ..common.utils import AverageMeter, DotDict, ProgressReporter
from .optim import OptimizerOptions
from .optim.scheduler import AdvancedLearningOptions, AdvancedScheduler
from .utils import set_random_seed

# Path component passed to ModelBase.get_path() to build the output prefix for
# non-'predict' runs of ModelBase.evaluate_entry (e.g. dev evaluation during
# training). It resolves under the session's 'outputs' path, or under
# options.output_prefix when there is no session.
TMP_TOKEN = '__TMP__'


def to_device(state_dict, device):
    """Move every tensor stored in ``state_dict`` to ``device``, in place.

    Nested dicts (e.g. the per-parameter dicts inside an optimizer's
    ``state``) are walked recursively; any other value is left untouched.
    Returns ``None``.
    """
    for entry_name, entry in state_dict.items():
        if isinstance(entry, dict):
            to_device(entry, device)
        elif torch.is_tensor(entry):
            # Only the value is replaced, so mutating during iteration is safe.
            state_dict[entry_name] = entry.to(device)


def get_device(gpu: Union[str, bool]):
    """Translate the ``gpu`` option into a :class:`torch.device`.

    * ``'auto'``  -> ``cuda:0`` if CUDA is available, else ``cpu``;
    * ``'none'``  -> ``cpu``;
    * any other string is used verbatim as the device spec (e.g. ``'cuda:1'``);
    * a non-string value selects ``cuda:0`` when truthy and ``cpu`` otherwise.
    """
    if not isinstance(gpu, str):
        want_cuda = gpu
    elif gpu == 'auto':
        want_cuda = torch.cuda.is_available()
    elif gpu == 'none':
        want_cuda = False
    else:
        # explicit device specification
        return torch.device(gpu)
    return torch.device('cuda:0' if want_cuda else 'cpu')


class RestoreToBestSignal(Exception):
    """Exception type that is not raised anywhere in this module.

    NOTE(review): the name suggests training code or hooks raise it to request
    a restore to the best checkpoint -- confirm with its callers.
    """


class Metrics:
    """Best value seen so far for a single named metric.

    ``mode`` is ``'max'`` when larger values are better and ``'min'`` when
    smaller values are better.
    """
    __slots__ = ('name', 'mode', 'value')

    def __init__(self, name, mode, init_value=None):
        """
        Args:
            name: display name used by ``__str__``.
            mode: ``'max'`` or ``'min'``.
            init_value: starting "best" value. Defaults to ``0`` for ``'max'``
                (the historical default) and ``+inf`` for ``'min'``. With a
                ``0`` start in ``'min'`` mode, no positive value could ever
                count as an improvement.

        Raises:
            ValueError: if ``mode`` is neither ``'max'`` nor ``'min'``. Any
                other mode would make every comparison silently ``False``.
        """
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        if init_value is None:
            init_value = float('inf') if mode == 'min' else 0
        self.name = name
        self.mode = mode
        self.value = init_value

    def __str__(self):
        return f'[{self.name} {self.value:.4f}]'

    def is_far_worse(self, value):
        """Return True if ``value`` is an order of magnitude worse than the best.

        For ``'max'``: best > 10 * value. For ``'min'``: best < value / 10.
        NOTE(review): the ratio test is only meaningful for positive values.
        """
        return (self.mode == 'max' and self.value > value * 10) \
            or (self.mode == 'min' and self.value < value / 10)

    def is_better(self, value):
        """Return True if ``value`` strictly improves on the stored best."""
        return (self.mode == 'max' and self.value < value) \
            or (self.mode == 'min' and self.value > value)

    def update(self, value):
        """Store ``value`` if it is better; return whether it was stored."""
        if self.is_better(value):
            self.value = value
            return True
        return False


class BaseHyperParams(OptionsBase):
    # Hyper-parameters shared by every model built on ModelBase.

    # Seed passed to set_random_seed() at the start of ModelBase.train_entry.
    seed: int = 1996

    # Optimizer settings; ModelBase calls `.create(trainable_params)` on it.
    optimizer: OptimizerOptions
    # Handed to AdvancedScheduler; the training loop also reads its
    # `update_frequency` (gradient accumulation) and `clip_grad_norm`.
    advanced_learning: AdvancedLearningOptions


class BaseOptions(OptionsBase):
    # Options shared by every model built on ModelBase.

    # NOTE(review): not read in this file; presumably used by the session.
    base_path: str
    # Every N session steps ModelBase evaluates on dev_path and calls
    # session.try_save(); must be a multiple of log_frequency.
    checkpoint_frequency: int = 500
    # Progress-report period in steps; must be a multiple of
    # hyper_params.advanced_learning.update_frequency.
    log_frequency: int = 10
    # Step budget: checked by the training loop and also handed to the
    # scheduler and the progress reporter.
    max_steps: int = 100000
    # NOTE(review): not read in this file; presumably used by the session.
    num_old_checkpoints: int = 0

    # Device selection, resolved by get_device(): 'auto', 'none', a device
    # string such as 'cuda:1', or a boolean.
    gpu: Union[bool, str] = argfield('auto', active_time='both')
    use_debugger: bool = argfield(True, active_time='both')

    train_paths: List[str] = argfield(help='Paths of training set')
    # Help text previously said "training set", which was wrong.
    dev_path: str = argfield(help='Path of development set')
    test_paths: List[str] = argfield(help='Paths of test set', active_time='predict')
    output_prefix: str = argfield(active_time='predict')

    hyper_params: BaseHyperParams


class ModelBase(metaclass=ABCMeta):
    """Abstract base for models that are trained/evaluated through a session.

    Subclasses implement :meth:`initialize` and :meth:`iter_batches`, and
    normally :meth:`split_outputs`, :meth:`evaluate` and :meth:`predict`.
    :meth:`setup` calls ``initialize`` and then immediately moves
    ``self.network`` (and, when ``options.training`` is set, the state of
    ``self.optimizer``) to the device. So ``initialize`` must create those,
    plus ``self.scheduler`` when training (see
    :meth:`get_optimizer_and_scheduler`).
    """
    network: nn.Module
    statistics: Any

    Options = BaseOptions
    HyperParams = BaseHyperParams

    # Direction ('max'/'min') and display name of the model-selection metric.
    METRICS_MODE = 'max'
    METRICS_NAME = '???'

    def __init__(self, options: Options, training_session=None):
        """
        Args:
            options: parsed options (``BaseOptions`` or a subclass).
            training_session: optional session. This class uses its
                ``logger``, ``get_path``, ``global_step``, ``step`` and
                ``try_save``.
        """
        update_frequency = options.hyper_params.advanced_learning.update_frequency
        # Checkpoints are only considered right after an optimizer update
        # (see _train_entry_internal), so the periods must nest:
        # update_frequency | log_frequency | checkpoint_frequency.
        assert options.log_frequency % update_frequency == 0, \
            'log_frequency must be a multiple of update_frequency'
        assert options.checkpoint_frequency % options.log_frequency == 0, \
            'checkpoint_frequency must be a multiple of log_frequency'

        self.options = options
        self.hyper_params = options.hyper_params

        self.plugins = None
        self.logger = training_session.logger if training_session else LOGGER

        self.device = get_device(options.gpu)

        # training stuff
        self.session = training_session
        self.metrics = Metrics(self.METRICS_NAME, self.METRICS_MODE)

        # hook name -> list of callables; populated by setup() when training
        self._train_hooks = {}

    def remove_hook(self, name, fn):
        """Unregister ``fn`` from hook ``name`` (raises if either is unknown)."""
        self._train_hooks[name].remove(fn)

    def add_hook(self, name, fn):
        """Register ``fn`` under hook ``name``.

        The hook lists only exist after :meth:`setup` ran in training mode.
        Before that, this raises ``KeyError``.
        """
        self._train_hooks[name].append(fn)

    def run_hooks(self, name, *args, **kwargs):
        """Call every hook registered under ``name``; return all results."""
        return [hook(*args, **kwargs) for hook in self._train_hooks[name]]

    def get_path(self, *name):
        """Join the path components in ``name`` onto the output location.

        With a session, the path is built by ``session.get_path('outputs', ...)``.
        Otherwise it is placed under ``options.output_prefix``.
        """
        if self.session is not None:
            return self.session.get_path('outputs', *name)
        prefix = self.options.output_prefix
        return os.path.join(os.path.dirname(prefix), os.path.basename(prefix), *name)

    @classmethod
    def make_release(cls, saved_state):
        """Return ``saved_state`` without its 'scheduler'/'optimizer' entries."""
        return {key: value for key, value in saved_state.items()
                if key not in ('scheduler', 'optimizer')}

    def state_dict(self):
        """Collect network, statistics, scheduler and optimizer state."""
        return {
            'network': self.network.state_dict(),
            'statistics': self.statistics.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

    def load_state_dict(self, saved_state):
        """Restore network weights, plus scheduler/optimizer when training.

        NOTE(review): 'statistics' is saved by state_dict() but not restored
        here -- presumably subclasses handle it in initialize(); confirm.
        """
        self.network.load_state_dict(saved_state['network'])
        if self.options.training:
            self.scheduler.load_state_dict(saved_state['scheduler'])
            self.optimizer.load_state_dict(saved_state['optimizer'])

    def move_to_device(self, device=None):
        """Move the network (and optimizer state when training) to ``device``.

        Defaults to ``self.device``; returns the device actually used.
        """
        if device is None:
            device = self.device

        self.network.to(device)
        if self.options.training:
            to_device(self.optimizer.state, device)

        return device

    def setup_hooks(self, obj):
        """Register ``obj.<name>_hook`` for every known hook name it defines."""
        for name in self._train_hooks:
            hook = getattr(obj, f'{name}_hook', None)
            if hook:
                self.add_hook(name, hook)
                self.logger.info('register %s hook: %s', name, obj.__class__)

    def setup(self, saved_state):
        """Initialize the model and move it to the device.

        When training, this also creates the hook lists and registers hooks
        found on every network submodule, the scheduler and the model itself.
        """
        self.initialize(saved_state)
        self.move_to_device()

        if self.options.training:
            self._train_hooks.update(before_step=[], before_epoch=[], before_eval=[],
                                     after_step=[], after_epoch=[], after_eval=[])

            # nn.Module.apply visits every submodule (and the network itself)
            self.network.apply(self.setup_hooks)
            self.setup_hooks(self.scheduler)
            self.setup_hooks(self)

    @abstractmethod
    def initialize(self, saved_state):
        """Build the network (plus optimizer/scheduler when training).

        NOTE(review): whether ``saved_state`` may be None for a fresh run is
        decided by callers not visible here.
        """
        pass

    @abstractmethod
    def iter_batches(self, path, mode):
        """Yield ``(batch_samples, inputs)`` pairs for the data at ``path``.

        ``inputs`` must be a mapping, and each sample must expose
        ``original_index``. evaluate_entry() uses that index to restore
        dataset order.
        """
        pass

    def iter_train_batches(self):
        """Yield training batches from every path in ``options.train_paths``.

        ``chain.from_iterable`` is lazy: ``iter_batches`` is only called for
        a path once the previous path is exhausted.
        """
        yield from itertools.chain.from_iterable(
            self.iter_batches(path, mode='train') for path in self.options.train_paths)

    def get_trainable_params(self):
        """Return the network parameters that require gradients."""
        return [param for param in self.network.parameters() if param.requires_grad]

    def get_optimizer_and_scheduler(self, trainable_params):
        """Create the optimizer for ``trainable_params`` and its scheduler."""
        hyper_params = self.hyper_params
        optimizer = hyper_params.optimizer.create(trainable_params)
        scheduler = AdvancedScheduler(hyper_params.advanced_learning, optimizer,
                                      max_steps=self.options.max_steps,
                                      mode=self.metrics.mode)
        return optimizer, scheduler

    def compute_stats(self, average_meter, batch_samples, inputs, outputs):
        """Return the meter's running averages plus the current learning rate
        of every optimizer param group (keys ``lr_group<i>``)."""
        stats = average_meter.avgs()
        for index, param_group in enumerate(self.optimizer.param_groups):
            stats[f'lr_group{index}'] = float(param_group['lr'])
        return stats

    def run_batch(self, batch_samples, inputs):
        """Move tensor inputs to ``self.device`` and run the network."""
        inputs = DotDict({name: (value.to(self.device) if torch.is_tensor(value) else value)
                          for name, value in inputs.items()})

        return self.network(batch_samples, inputs)

    def evaluate(self, data_path, samples, outputs, output_prefix):
        """Score ``outputs`` against ``samples``.

        The training loop unpacks the result as ``(metric_value, output_files)``.
        """
        raise NotImplementedError()

    def predict(self, data_path, samples, outputs, output_prefix):
        """Write predictions for ``samples``/``outputs`` under ``output_prefix``."""
        raise NotImplementedError()

    def split_outputs(self, batch_samples, outputs):
        """Split a batch's network outputs into one entry per sample, in order."""
        raise NotImplementedError()

    def evaluate_entry(self, path=None, mode=None):
        """Run the network over ``path`` and pass the results on.

        With no ``path``, ``options.dev_path`` is used in ``'dev'`` mode.
        Otherwise ``mode`` is required. Samples and outputs are put back in
        dataset order via ``sample.original_index``. They are then passed to
        :meth:`predict` (mode ``'predict'``) or :meth:`evaluate` (any other
        mode), whose result is returned.
        """
        if path is None:
            path = self.options.dev_path
            mode = 'dev'
        else:
            assert mode is not None

        total_outputs = []
        total_samples = []

        self.network.eval()
        with ProgressReporter(step=1, prompt='finished samples') as progress, torch.no_grad():
            for batch_samples, inputs in self.iter_batches(path, mode):
                outputs = self.run_batch(batch_samples, inputs)

                total_samples.extend(batch_samples)
                total_outputs.extend(self.split_outputs(batch_samples, outputs))

                progress.tick(len(batch_samples))

        # zip() below would silently drop the surplus and leave None holes
        assert len(total_outputs) == len(total_samples), \
            'split_outputs must return exactly one output per sample'

        outputs = [None] * len(total_outputs)
        samples = [None] * len(total_samples)
        for sample, output in zip(total_samples, total_outputs):
            index = sample.original_index
            assert samples[index] is None and outputs[index] is None
            samples[index] = sample
            outputs[index] = output

        if mode == 'predict':
            output_prefix = self.options.output_prefix
            fn = self.predict
        else:
            output_prefix = self.get_path(TMP_TOKEN)
            fn = self.evaluate

        return fn(path, samples, outputs, output_prefix=output_prefix)

    def normalize_loss(self, batch_samples, inputs, loss, update_frequency):
        """Divide ``loss`` by ``update_frequency`` when accumulating gradients."""
        if update_frequency > 1:
            return loss / update_frequency
        return loss

    def _train_entry_internal(self, epoch, meter, progress):
        """Train for (at most) one pass over the training data.

        Each successful batch is reported through ``session.step``. The
        optimizer steps only when ``session.global_step`` is a multiple of
        ``update_frequency`` (gradient accumulation). After such a step, if
        the step is a multiple of ``checkpoint_frequency``, the dev set is
        evaluated and ``session.try_save`` is called.

        Returns:
            True if an ``after_eval`` hook asked to stop training; False when
            the data ran out or ``max_steps`` was reached.
        """
        session = self.session
        logger = self.logger

        max_steps = self.options.max_steps
        checkpoint_frequency = self.options.checkpoint_frequency
        update_frequency = self.hyper_params.advanced_learning.update_frequency
        clip_grad_norm = self.hyper_params.advanced_learning.clip_grad_norm

        meter.clear()
        for batch_samples, inputs in self.iter_train_batches():
            # Honour the step budget mid-epoch too. train_entry only checks
            # between epochs, so the last epoch could otherwise overshoot
            # max_steps (which the scheduler also relies on).
            if session.global_step >= max_steps:
                break

            self.run_hooks('before_step', session.global_step, max_steps, epoch, logger=logger)

            try:
                outputs = self.run_batch(batch_samples, inputs)

                loss = self.normalize_loss(batch_samples, inputs, outputs.loss, update_frequency)

                loss.backward()
                outputs.loss = loss.item()
            except RuntimeError as err:
                if 'CUDA out of memory' not in str(err):
                    raise
                # Skip this batch. NOTE(review): if the OOM happened inside
                # backward(), partially accumulated gradients are kept.
                logger.exception('If this appears frequently, '
                                 'a smaller batch size is appreciated.')
                continue

            meter.add('loss', outputs.loss * len(batch_samples), len(batch_samples))

            session.step(self.compute_stats(meter, batch_samples, inputs, outputs))
            progress.tick()

            # keep accumulating gradients until the next update boundary
            if session.global_step % update_frequency != 0:
                continue

            if clip_grad_norm > 0:
                total_norm = nn.utils.clip_grad_norm_(self.network.parameters(), clip_grad_norm)
                # store a plain float, like 'loss', rather than a (device) tensor
                meter.add('total_norm', float(total_norm))

            self.optimizer.step()
            self.optimizer.zero_grad()

            step = session.global_step
            self.run_hooks('after_step', step, max_steps, epoch, logger=logger)

            if step != 0 and step % checkpoint_frequency == 0:
                self.run_hooks('before_eval', step, max_steps, epoch, logger=logger)

                metric_value, output_files = self.evaluate_entry()
                session.try_save(metric_value, output_files)

                # evaluate_entry() switched the network to eval mode
                self.network.train()

                # every after_eval hook runs; any truthy result stops training
                if any(self.run_hooks('after_eval', step, max_steps, epoch,
                                      metric_value=metric_value,
                                      logger=logger)):
                    return True

        return False

    def train_entry(self):
        """Train epoch by epoch until ``max_steps`` or a hook asks to stop.

        The epoch counter resumes at ``scheduler.last_epoch + 1``. Each epoch
        is bracketed by ``before_epoch``/``after_epoch`` hooks.

        Raises:
            RuntimeError: if a whole epoch completes without advancing
                ``session.global_step`` (e.g. empty training data, or every
                batch hits OOM). Without this check the loop would spin forever.
        """
        session = self.session
        options = self.options

        max_steps = options.max_steps
        epoch = self.scheduler.last_epoch + 1

        set_random_seed(self.hyper_params.seed, logger=self.logger)

        meter = AverageMeter()

        def message_fn(_):
            # progress line: running averages of everything in the meter
            return ' | '.join(f'{name}: {value:.4f}' for name, value in meter.avgs().items())

        self.network.train()
        with ProgressReporter(stop=max_steps,
                              step=options.log_frequency,
                              start=session.global_step,
                              message_fn=message_fn,
                              print_time=True, newline=True) as progress:
            while session.global_step < max_steps:
                self.run_hooks('before_epoch', session.global_step, max_steps, epoch,
                               logger=self.logger)

                start_step = session.global_step
                # NOTE: relies on a private ProgressReporter attribute
                progress._prompt = f'Epoch {epoch}'
                if self._train_entry_internal(epoch, meter, progress):
                    break

                self.run_hooks('after_epoch', session.global_step, max_steps, epoch,
                               logger=self.logger)

                if session.global_step == start_step:
                    raise RuntimeError(f'epoch {epoch} made no training step; '
                                       f'is the training data empty?')
                epoch += 1
