# -*- coding: utf-8 -*-

import code
import os
import sys
import time
import traceback
from pprint import pprint

from .logger import LOGGER
from .timeout import TimeoutError, timeout
from .utils import Singleton

# Sentinel meaning "the wrapped function has not produced a result (yet)".
# DebugConsoleWrapper._handle_exception returns it to request another run,
# DebugConsoleWrapper.__call__ keeps looping while its result is NoReturn and
# returns it after Ctrl-C. It is distinct from None, which the wrapped
# function may legitimately return and which also signals "user exited".
NoReturn = Singleton.create('NoReturn')


def _frame_at(exc_info, frame_level):
    """Return the frame object selected by ``frame_level`` from ``exc_info``'s traceback.

    The outermost traceback entry (the debug console's own call frame) is
    dropped first, so index 0 is the next frame outward-in and -1 is the
    innermost frame where the exception was raised. Negative indices count
    from the innermost end, as with list indexing.
    """
    walker = traceback.walk_tb(exc_info[2])
    next(walker)  # drop the debug console's own frame
    frames = [frame for frame, _lineno in walker]
    return frames[frame_level]


def _locals_at(exc_info, frame_level):
    """Return the ``f_locals`` mapping of the frame chosen by ``_frame_at``."""
    selected_frame = _frame_at(exc_info, frame_level)
    return selected_frame.f_locals


class DebugConsoleWrapper:
    """Run a callable and open an interactive debug menu when it raises.

    ``wrapper(fn, *args, **kwargs)`` calls ``fn``. If it raises an
    ``Exception``, the traceback is printed and the user is prompted for
    commands. Each command is a ``run_<name>`` method. ``exit``, Ctrl-D or a
    prompt timeout leaves the menu. A command that returns a truthy value
    (``reload``) makes ``fn`` run again. ``raise`` propagates the original
    exception to the caller.
    """

    # Cache for _get_help_text(). It is a class-level default so the
    # attribute exists even on instances whose __init__ was bypassed.
    _help_text = None

    def __init__(self, project_root=None):
        # Directory whose modules `reload` may re-import. None means "derive
        # it from the __main__ module" (see _get_project_root).
        self.project_root = project_root

        # Time of the latest fn() run. `reload` only re-imports modules whose
        # file changed after it.
        self._start_time = None

    def _get_project_root(self):
        """Return the absolute directory treated as the project root.

        Uses ``self.project_root`` if set, otherwise the directory of the
        ``__main__`` module's file. Falls back to the current directory when
        ``__main__`` has no ``__file__`` (e.g. an interactive interpreter).
        """
        project_root = self.project_root
        if project_root is None:
            main_file = getattr(sys.modules.get('__main__'), '__file__', None)
            if main_file is not None:
                project_root = os.path.dirname(main_file)

        if project_root is None:
            project_root = os.curdir

        return os.path.abspath(project_root)

    def _parse_frame_level(self, args, message, return_args=False):
        """Parse an optional leading integer frame level from ``args``.

        ``args`` is the raw text typed after the command name. Returns -1
        (innermost frame) when it is empty, or the integer value of its first
        token. If the first token is not an integer, prints the usage
        ``message`` and returns None. With ``return_args`` the list of
        whitespace-split tokens is returned as well, as ``(level, tokens)``.
        """
        tokens = args.strip().split()
        value = -1
        if tokens:
            try:
                value = int(tokens[0])
            except ValueError:
                print(message)
                value = None

        if return_args:
            return value, tokens
        return value

    def _frame_locals(self, exc_info, frame_level):
        """Return the locals of the requested traceback frame, or None if out of range."""
        try:
            return _locals_at(exc_info, frame_level)
        except IndexError:
            print(f'frame level {frame_level} is out of range')
            return None

    def _get_help_text(self):
        """Build (and cache) the command menu shown by ``help``.

        The menu has one line per ``run_<name>`` method, taken from the first
        line of its docstring (or just the command name when it has none).
        It ends with the built-in ``exit`` command.
        """
        if self._help_text is not None:
            return self._help_text

        # Walk the MRO base-first so subclasses can add or override commands
        # while definition order is kept for the menu.
        commands = {}
        for klass in reversed(type(self).__mro__):
            for attr_name, attr in vars(klass).items():
                if attr_name.startswith('run_') and callable(attr):
                    commands[attr_name[len('run_'):]] = attr

        lines = []
        for command_name, func in commands.items():
            doc_lines = (func.__doc__ or '').strip().splitlines()
            lines.append('+ ' + (doc_lines[0] if doc_lines else command_name))
        lines.append('+ exit - exit program')

        self._help_text = '\n'.join(lines)
        return self._help_text

    @timeout(900)
    def read_command(self):
        """Prompt for one command line and return ``(command, argument_text)``.

        The ``timeout(900)`` decorator presumably aborts the prompt after 900
        seconds with the ``TimeoutError`` that ``_handle_exception`` catches.
        """
        input_cmd = input('>>> ').strip()
        choice, _, args = input_cmd.partition(' ')
        return choice, args

    def run_reload(self, *_):
        '''reload - re-import changed project modules and run again'''
        try:
            import pydevd_reload
        except ModuleNotFoundError:
            LOGGER.error('"pydevd_reload" is not installed !!')
            return None

        import gc
        import pathlib

        LOGGER.info('Reloading modules...')

        project_root = self._get_project_root()
        # Trailing separator so "/src/app" does not also match "/src/app2".
        root_prefix = os.path.join(project_root, '')
        # Iterate over a snapshot: reloading can import modules and thereby
        # change sys.modules' size mid-iteration.
        for module_name, module in list(sys.modules.items()):
            if module_name == '__main__':
                continue

            module_file = getattr(module, '__file__', None)
            if module_file is None or not module_file.startswith(root_prefix):
                continue

            try:
                modify_time = pathlib.Path(module_file).stat().st_mtime
            except OSError:
                continue  # file vanished or is unreadable: nothing to reload

            if modify_time > self._start_time:
                if pydevd_reload.xreload(module):
                    LOGGER.info(module_name + ' updated')

        gc.collect()
        return True  # truthy: ask __call__ to run fn again

    def run_raise(self, exc_info=None, *_):
        '''raise - reraise exception and exit'''
        if exc_info is not None and exc_info[1] is not None:
            # _handle_exception recognises this exact object and lets it
            # propagate instead of logging it as an internal error.
            raise exc_info[1]
        raise

    def run_console(self, exc_info, args):
        '''console [int:frame_level] - keep exception stack and start a console'''
        frame_level = self._parse_frame_level(args, 'console [int:frame_level]')
        if frame_level is None:
            return
        frame_locals = self._frame_locals(exc_info, frame_level)
        if frame_locals is None:
            return

        try:
            try:
                from IPython.terminal.embed import InteractiveShellEmbed
            except ModuleNotFoundError:
                # code.interact() exec()s source with this mapping as globals,
                # which must be a real dict. On Python 3.13+ (PEP 667) a
                # function frame's f_locals is a proxy, so copy it.
                if not isinstance(frame_locals, dict):
                    frame_locals = dict(frame_locals)
                code.interact(local=frame_locals)
            else:
                InteractiveShellEmbed().mainloop(local_ns=frame_locals)
        except SystemExit:
            return  # exit() in the console: back to the command prompt

    def run_pdb(self, exc_info, args):
        '''pdb - start a post-mortem debugger on the exception (ipdb if installed)'''
        try:
            import ipdb as debugger
        except ModuleNotFoundError:
            import pdb as debugger

        try:
            debugger.post_mortem(exc_info[2])
        except EOFError:
            LOGGER.info('Exit pdb.')

    def run_print(self, exc_info, args):
        '''print [int:frame_level] [key] - print local variables in exception stack'''
        usage = 'print [int:frame_level] [key]'
        frame_level, tokens = self._parse_frame_level(args, usage, return_args=True)
        if frame_level is None:
            return
        if len(tokens) > 2:
            print(usage)
            return

        frame_locals = self._frame_locals(exc_info, frame_level)
        if frame_locals is None:
            return

        if len(tokens) < 2:
            # No key given: list the names available in that frame.
            print(frame_locals.keys())
            return

        key = tokens[1]
        if key not in frame_locals:
            print(f'{key!r} is not a local variable of frame {frame_level}')
            return
        print(frame_locals[key])

    def run_extract_tb(self, exc_info, _):
        '''extract_tb - print traceback frame summary'''
        pprint(traceback.extract_tb(exc_info[-1]))

    def run_help(self, *_):
        '''help - show this message'''
        LOGGER.info('PID: %s', os.getpid())
        LOGGER.info('Exception occurred. What do you want to do ?\n%s',
                    self._get_help_text())

    def _handle_exception(self):
        """Run the interactive command loop for the exception being handled.

        Must be called from inside an ``except`` block, since the exception
        is read from ``sys.exc_info()``. Returns ``NoReturn`` when a command
        asks for the wrapped function to run again. Returns None when the
        user exits (``exit`` or Ctrl-D) or the prompt times out. The ``raise``
        command makes the original exception propagate out of this method.
        """
        # Snapshot once and hand the same triple to every command, so commands
        # that raise or handle their own errors still inspect the original.
        exc_info = sys.exc_info()
        first_time = True
        while True:
            try:
                if first_time:
                    LOGGER.info('Wait for 1 seconds...')
                    time.sleep(1)
                    first_time = False
                choice, args = self.read_command()
            except KeyboardInterrupt:
                if not first_time:
                    print('\nUse "Ctrl-D" or "exit" to exit debugger')
                continue
            except EOFError:
                return None
            except TimeoutError:
                LOGGER.critical('Timeout, exit now...')
                return None

            if choice == 'exit':
                return None

            command = getattr(self, 'run_' + choice, None)
            if command is None:
                command = self.run_help  # unknown command: show the menu

            # Commands run outside the prompt's try, so an original exception
            # that happens to be EOFError/TimeoutError is not mistaken for
            # "user exited" when `raise` re-raises it.
            try:
                if command(exc_info, args):
                    return NoReturn
            except KeyboardInterrupt:
                print('\nUse "Ctrl-D" or "exit" to exit debugger')
            except Exception as error:
                if error is exc_info[1]:
                    raise  # the `raise` command: let the original error escape
                LOGGER.exception('??? internal error')

    def __call__(self, fn, *args, **kwargs):
        """Call ``fn(*args, **kwargs)``, opening the debug menu if it raises.

        ``fn`` is re-run for as long as the menu asks for it (e.g. after
        ``reload``). Returns ``fn``'s result, None if the user left the menu,
        or ``NoReturn`` if interrupted with Ctrl-C.
        """
        ret = NoReturn
        while ret is NoReturn:
            # Modules modified after this moment are picked up by `reload`.
            self._start_time = time.time()
            try:
                ret = fn(*args, **kwargs)
            except Exception:
                traceback.print_exc()
                # Make sure the traceback is visible before the prompt.
                sys.stderr.flush()

                ret = self._handle_exception()
            except KeyboardInterrupt:
                print('\nInterrupted')
                break
        return ret
