# -*- coding: utf-8 -*-

import sys
from argparse import ArgumentParser
from io import BytesIO

import torch

from ..common.dataclass_options import SingleOptionsParser
from ..common.debug_console import DebugConsoleWrapper
from ..common.logger import open_file
from ..packer.magic_pack import pack_to_executable
from .utils import import_class, print_version


class PredictSession:
    """Run (or pack into an executable) a trained model for prediction.

    The saved state is read either from ``model_path`` or, when the model is
    embedded in a packed executable, from the in-memory ``model_bytes``.
    The loaded state must provide an ``'options'`` entry (passed to
    ``entry_class.Options().load_state_dict``) and an ``'object'`` entry
    (passed to ``entry_object.setup`` / ``entry_class.make_release``).

    NOTE(review): ``torch.load`` unpickles arbitrary Python objects, so the
    model file/bytes must come from a trusted source.

    Args:
        entry_class: class providing ``Options``, ``setup``,
            ``evaluate_entry`` and ``make_release``.
        model_path: path of a file written by ``torch.save``.
        model_bytes: serialized model kept in memory; when not ``None`` it
            takes precedence over ``model_path``.
    """

    def __init__(self, entry_class, model_path=None, model_bytes=None):
        self.model_bytes = model_bytes
        self.model_path = model_path
        self.entry_class = entry_class

    def load_saved_state(self):
        """Deserialize the saved state onto the CPU.

        When the state is loaded from ``model_bytes``, the byte cache is
        dropped afterwards (``self.model_bytes`` becomes ``None``) so the raw
        bytes do not stay alive next to the deserialized objects.
        """
        if self.model_bytes is None:
            return torch.load(self.model_path, map_location=torch.device('cpu'))
        saved_state = torch.load(BytesIO(self.model_bytes), map_location=torch.device('cpu'))

        self.model_bytes = None  # clear cache
        return saved_state

    def _run(self, options, saved_state):
        """Build the entry object, print the options and predict every test path."""
        entry_object = self.entry_class(options)
        entry_object.setup(saved_state['object'])

        print(options.pretty_format())

        for path in options.test_paths:
            entry_object.evaluate_entry(path, mode='predict')

    def run(self, argv=None, abbrevs=None, use_debugger=True):
        """Parse prediction options from ``argv`` and run prediction.

        Args:
            argv: command-line arguments for the options parser.
            abbrevs: option abbreviations forwarded to ``SingleOptionsParser.setup``.
            use_debugger: when false, never wrap the run in the debug console;
                when true, the console is used only if ``options.use_debugger``
                is also set.
        """
        print_version()

        saved_state = self.load_saved_state()

        options = self.entry_class.Options()
        options.load_state_dict(saved_state['options'])

        parser = SingleOptionsParser()
        # it's important to use `options` as default value for the options
        # used both in training and prediction
        parser.setup(default_instance=options, abbrevs=abbrevs, training=False)
        options.merge_options(parser.parse_args(argv))

        if options.use_debugger and use_debugger:
            DebugConsoleWrapper()(self._run, options, saved_state)
        else:
            self._run(options, saved_state)

    def pack(self, output_path, entry_point='main'):
        """Pack the model into an executable at ``output_path``.

        When the model comes from ``model_path``, it is first stripped via
        ``entry_class.make_release`` and re-serialized. Cached
        ``model_bytes`` are packed as-is.

        Returns:
            Whatever ``pack_to_executable`` returns.
        """
        model_bytes = self.model_bytes
        if model_bytes is None:
            # Close the handle once the state is loaded (it used to be leaked).
            with open_file(self.model_path, 'rb') as model_file:
                saved_state = torch.load(model_file, map_location=torch.device('cpu'))
            # clip redundant attributes for release
            saved_state['object'] = self.entry_class.make_release(saved_state['object'])

            output = BytesIO()
            torch.save(saved_state, output)
            model_bytes = output.getvalue()

        return pack_to_executable(model_bytes,
                                  output_path=output_path,
                                  entry_class=self.entry_class,
                                  entry_point=entry_point)

    @classmethod
    def entry_point(cls, entry_class,
                    model_bytes=None, argv=None, abbrevs=None, entry_point='main'):
        """Command-line entry: run prediction, or pack the model with ``--pack``.

        Args:
            entry_class: passed to ``import_class`` to resolve the entry class.
            model_bytes: embedded model (packed executable); when given,
                ``--model``/``--pack`` are not parsed and all of ``argv`` is
                forwarded to ``run``.
            argv: command-line arguments (``None`` lets argparse use ``sys.argv``).
            abbrevs: option abbreviations forwarded to ``run``.
            entry_point: entry-point name forwarded to ``pack``.

        Exits with status 1 when no model bytes are embedded and ``--model``
        is missing.
        """
        entry_class = import_class(entry_class)

        if model_bytes is not None:
            # Embedded model: there is no file to read and nothing to pack.
            model_path = '<packed_file>'
            pack_path = None
        else:
            parser = ArgumentParser(add_help=False)
            parser.add_argument('--model', default=None, help='path to saved model.')
            parser.add_argument('--pack', default=None,
                                help='pack given model file instead of running it.')

            # Unrecognized arguments are left in argv for the options parser in `run`.
            options, argv = parser.parse_known_args(argv)
            model_path = options.model
            pack_path = options.pack

            if model_path is None:
                # Errors (and the usage that explains them) belong on stderr.
                parser.print_help(sys.stderr)
                print('error: the following arguments are required: --model',
                      file=sys.stderr)
                sys.exit(1)

        session = cls(entry_class, model_path, model_bytes)
        if pack_path:
            session.pack(pack_path, entry_point)
        else:
            session.run(argv, abbrevs=abbrevs)
