# -*- coding: utf-8 -*-

import os
import shutil
from argparse import ArgumentParser
from typing import Optional

import torch

from ..common.dataclass_options import SingleOptionsParser
from ..common.debug_console import DebugConsoleWrapper
from ..common.logger import get_logger, open_file
from ..common.utils import query_yes_no, set_current_process_name
from .model_base import TMP_TOKEN, ModelBase, RestoreToBestSignal
from .utils import import_class, print_version, summary_model


class TrainSession:
    """Drive a training run: working paths, logging, summaries and checkpoints.

    A session is either created fresh from an options object or restored from an
    existing working directory (given as a path string). Checkpoints are written
    to ``<working_dir>/checkpoints`` as ``ckpt.<step>.pt`` plus ``best.pt``.
    """

    def __init__(self, options_or_path, entry_class: ModelBase,
                 logger=None,
                 pack_model_fn=None,
                 session_name: Optional[str] = None,
                 no_confirm=False):
        """Create or restore a training session.

        Args:
            options_or_path: an ``entry_class.Options`` instance (start a new
                session) or the directory of an existing session (restore; the
                options are then read from ``<path>/config``).
            entry_class: model class exposing ``Options``; it is instantiated as
                ``entry_class(options, self)``.
            logger: logger to use. By default one writing ``train.log`` in the
                working directory is created (append mode when restoring).
            pack_model_fn: optional callable ``(best_model_path, entry_class)``
                invoked each time a new best checkpoint is saved.
            session_name: name for the logger and process title; defaults to the
                basename of the working directory.
            no_confirm: when restoring, only log option differences instead of
                asking the user whether to continue.
        """
        SessionOptions = entry_class.Options

        restore = isinstance(options_or_path, str)
        if restore:
            assert os.path.isdir(options_or_path), \
                'When restoring session, the path to model is required to be a directory.'

            options = SessionOptions.from_file(os.path.join(options_or_path, 'config'))
            options.base_path = options_or_path
        else:
            options = options_or_path
            assert isinstance(options, SessionOptions)

        assert options.training, 'options should be in training mode'

        self.options = options

        # set useful paths
        self.working_dir = options.base_path
        self.checkpoint_dir = self.get_path('checkpoints')
        self.summaries_dir = self.get_path('summaries')

        if session_name is None:
            session_name = os.path.basename(self.working_dir)
        self.session_name = session_name

        os.makedirs(self.checkpoint_dir, exist_ok=True)
        os.makedirs(self.get_path('outputs'), exist_ok=True)

        # must exist before _restore/_init, which (re)create the writer
        self.writer = None
        if logger is None:
            logger = get_logger(name=self.session_name, files=self.get_path('train.log'),
                                mode=('a' if restore else 'w'))
        self.logger = logger

        self.entry_object = entry_class(options, self)
        self.metrics = self.entry_object.metrics
        self.pack_model_fn = pack_model_fn
        self.no_confirm = no_confirm

        if restore:
            self._restore(no_confirm=no_confirm)
        else:
            self._init()

        summary_model(self.entry_object, logger=logger)

    def _create_writer(self, purge_step=None):
        """Return a tensorboard ``SummaryWriter`` on ``summaries_dir``, or None if unavailable."""
        try:
            try:
                from torch.utils.tensorboard import SummaryWriter
            except Exception:
                # Broad catch on purpose: the torch import may fail with errors
                # other than ImportError; fall back to tensorboardX in that case.
                from tensorboardX import SummaryWriter
            return SummaryWriter(log_dir=self.summaries_dir, purge_step=purge_step)
        except Exception:
            self.logger.info('No tensorboard support. Need tensorboardX or torch>=1.1.0')
            return None

    def _reset_writer(self, purge_step=None):
        """Close the current summary writer (if any) and open a fresh one.

        Without closing first, every restore/re-init would leak an open event file.
        """
        if self.writer is not None:
            self.writer.close()
            self.writer = None
        self.writer = self._create_writer(purge_step)

    def _restore(self, checkpoint_path=None, no_confirm=False):
        """Load session state from a checkpoint file.

        Args:
            checkpoint_path: checkpoint to load. Defaults to the first entry of
                ``list_checkpoints(include_best=True)``, i.e. the newest
                ``ckpt.<step>.pt``, or ``best.pt`` when no step checkpoint exists.
            no_confirm: if True, differences between the saved and current options
                are only logged. Otherwise the user is asked whether to continue.

        Raises:
            AssertionError: if no checkpoint can be found.
            RuntimeError: if the user declines to continue with differing options.
        """
        logger = self.logger

        if checkpoint_path is None:
            checkpoints = self.list_checkpoints(include_best=True)
            assert checkpoints, 'can not find any checkpoint files'
            checkpoint_path = checkpoints[0]

        logger.info('Session restores from %s', checkpoint_path)

        saved_state = torch.load(checkpoint_path, map_location=lambda storage, _: storage)

        self.global_step = saved_state['global_step']

        # Always report option changes; only ask for confirmation when allowed to.
        old_options = self.options.__class__()
        old_options.load_state_dict(saved_state['options'])
        diffs = self.options.diff_options(old_options)
        if diffs:
            logger.warning('****************Warning*****************')
            for key, new_value, old_value in diffs:
                logger.warning('In saved model %s = %s', key, old_value)
                logger.warning('But now it is %s', new_value)
            logger.warning('****************************************')
            if not no_confirm and query_yes_no('Options may be incorrect, continue?') == 'no':
                raise RuntimeError('Options are not compatible !')

        # Open the writer only once restoring is confirmed: purge_step hides
        # previously logged events at or beyond the restored step.
        self._reset_writer(self.global_step)

        self.entry_object.setup(saved_state['object'])

        metrics = self.metrics
        if metrics is not None:
            metrics.value = saved_state[metrics.name]
            logger.info('Start from step: %d (%s)', self.global_step, metrics)
        else:
            logger.info('Start from step: %d', self.global_step)

    def _init(self):
        """Start a fresh session: step 0, new writer, write ``config``, set up the entry object."""
        assert not self.list_checkpoints(), 'checkpoint file(s) already exist !!!'

        self.global_step = 0
        self._reset_writer()

        self.options.to_file(os.path.join(self.working_dir, 'config'))

        self.entry_object.setup(None)

    def list_checkpoints(self, include_best=False):
        """Return checkpoint paths, highest step first.

        Only files named ``ckpt.<int>.pt`` in ``checkpoint_dir`` are considered.
        When ``include_best`` is True and ``best.pt`` exists, it is appended last.
        """
        if not os.path.exists(self.checkpoint_dir):
            return []

        names = []
        for filename in os.listdir(self.checkpoint_dir):
            parts = filename.split('.')
            if len(parts) != 3 or parts[0] != 'ckpt' or parts[2] != 'pt':
                continue
            try:
                step = int(parts[1])
            except ValueError:
                continue
            names.append((step, os.path.join(self.checkpoint_dir, filename)))

        checkpoints = [filename for _, filename in sorted(names, reverse=True)]
        if include_best:
            best_path = os.path.join(self.checkpoint_dir, 'best.pt')
            if os.path.exists(best_path):
                checkpoints.append(best_path)

        return checkpoints

    def run(self):
        """Run ``entry_object.train_entry()`` until it returns.

        On ``RestoreToBestSignal`` the session reloads the last entry of
        ``list_checkpoints(include_best=True)`` (``best.pt`` if present, otherwise
        the lowest-step checkpoint) and retries. With no checkpoint at all it
        re-initializes from scratch.
        """
        set_current_process_name(self.session_name)
        print_version(self.logger)

        def _run():
            while True:
                try:
                    return self.entry_object.train_entry()
                except RestoreToBestSignal:
                    self.logger.critical('Receive RestoreToBestSignal signal !!!')
                    checkpoints = self.list_checkpoints(include_best=True)
                    if checkpoints:
                        # automatic restore mid-training: never block on a prompt
                        self._restore(checkpoints[-1], no_confirm=True)
                    else:
                        self._init()

        if self.options.use_debugger:
            DebugConsoleWrapper()(_run)
        else:
            _run()

    def close(self):
        """Close the summary writer. Safe to call more than once."""
        if self.writer is not None:
            self.writer.close()
            self.writer = None

    def step(self, stats, prefix=''):
        """Advance ``global_step`` by one and log ``stats`` every ``log_frequency`` steps."""
        self.global_step += 1
        step = self.global_step
        if step % self.options.log_frequency == 0:
            self.add_summaries(stats.keys(), stats.values(), prefix=prefix)

    def save(self, path):
        """Write step, options, entry-object state and (if tracked) the metric value to ``path``."""
        state = {'global_step': self.global_step,
                 'options': self.options.state_dict(),
                 'object': self.entry_object.state_dict()}

        metrics = self.metrics
        if metrics is not None:
            state[metrics.name] = metrics.value

        with open_file(path, 'wb') as fp:
            torch.save(state, fp)

    def try_pack(self):
        """Placeholder hook; currently does nothing."""
        pass

    def try_save(self, metrics_value=None, output_files=None):
        """Save checkpoints for the current step and rename associated output files.

        Args:
            metrics_value: current value of the tracked metric, or None. When given
                and ``self.metrics`` exists, it is logged and summarized. If
                ``metrics.update`` reports an improvement, ``best.pt`` is written
                and ``pack_model_fn`` (if any) is called on it.
            output_files: optional paths containing ``TMP_TOKEN``. Each is copied
                to a ``best`` variant on improvement, then moved to a variant
                tagged with the step (and the metric value when available).
        """
        def _copy_files(name, delete):
            action = shutil.move if delete else shutil.copy2
            # per-file handling so one failure does not skip the remaining files
            for filename in output_files:
                try:
                    action(filename, filename.replace(TMP_TOKEN, name))
                except Exception:
                    self.logger.exception('failed to move/copy files: %s', filename)

        options = self.options
        metrics = self.metrics

        step = self.global_step

        best_model_path = None
        model_paths = []

        num_old_checkpoints = options.num_old_checkpoints
        if num_old_checkpoints > 0:
            model_paths.append(os.path.join(self.checkpoint_dir, f'ckpt.{step}.pt'))

        if metrics is not None and metrics_value is not None:
            self.logger.info('best = %f | current = %f', metrics.value, metrics_value)
            self.add_summary(metrics.name, metrics_value)

            if metrics.update(metrics_value):
                best_model_path = os.path.join(self.checkpoint_dir, 'best.pt')
                model_paths.append(best_model_path)

        for path in model_paths:
            self.save(path)

        if best_model_path is not None:
            if output_files:
                _copy_files('best', delete=False)
            if self.pack_model_fn is not None:
                self.pack_model_fn(best_model_path, self.entry_object.__class__)

        if output_files:
            # format(None, '.4f') raises TypeError, so only include the metric
            # value in the tag when one was provided
            if metrics_value is None:
                tag = str(step)
            else:
                tag = f'{step}_{metrics_value:.4f}'
            _copy_files(tag, delete=True)

        if num_old_checkpoints > 0:
            # keep only the newest `num_old_checkpoints` step checkpoints
            for old_checkpoint in self.list_checkpoints()[num_old_checkpoints:]:
                try:
                    os.remove(old_checkpoint)
                except Exception:
                    self.logger.exception('can not remove old checkpoint: %s', old_checkpoint)

    def add_summaries(self, tags, values, prefix, global_step=None):
        """Write several scalars at ``global_step`` (default: current step).

        Spaces in tags become underscores. With a non-empty ``prefix`` the values
        are grouped through ``add_scalars``. This is a no-op without a writer.
        """
        if not self.writer:
            return

        if global_step is None:
            global_step = self.global_step
        tags = [tag.replace(' ', '_') for tag in tags]
        if prefix != '':
            self.writer.add_scalars(prefix, dict(zip(tags, values)), global_step)
        else:
            for tag, value in zip(tags, values):
                self.writer.add_scalar(tag, value, global_step=global_step)

    def add_summary(self, tag, value, global_step=None):
        """Write one scalar at ``global_step`` (default: current step); no-op without a writer."""
        if not self.writer:
            return

        tag = tag.replace(' ', '_')
        if global_step is None:
            global_step = self.global_step
        self.writer.add_scalar(tag, value, global_step)

    def get_path(self, *name):
        """Join the given path components onto the working directory."""
        return os.path.join(self.working_dir, *name)

    @classmethod
    def from_command_line(cls, entry_class,
                          argv=None, default_instance=None, abbrevs=None, **kwargs):
        """Build a session from command-line arguments.

        ``--restore PATH`` restores the session stored in PATH (other option
        arguments are then not parsed). Otherwise the options of ``entry_class``
        are parsed from ``argv``. ``--no-confirm`` skips the interactive check on
        option differences when restoring.
        """
        entry_class = import_class(entry_class)

        parser = SingleOptionsParser()
        parser.setup(entry_class.Options,
                     abbrevs=abbrevs, default_instance=default_instance)

        pre_parser = ArgumentParser()
        pre_parser.add_argument('--restore', type=str, default=None,
                                help='restore training')
        pre_parser.add_argument('--no-confirm', action='store_true', default=False,
                                help='no confirm during restoring session')
        pre_args, argv = pre_parser.parse_known_args(argv)

        # honour a caller-supplied no_confirm=True as well as the CLI flag
        kwargs['no_confirm'] = bool(kwargs.get('no_confirm', False) or pre_args.no_confirm)
        if pre_args.restore is not None:
            options_or_path = pre_args.restore
        else:
            options_or_path = parser.parse_args(argv)

        return cls(options_or_path, entry_class, **kwargs)
