# -*- coding: utf-8 -*-

import copy
import glob
import inspect
import json
import os
import pickle
import sys
from argparse import ArgumentParser, Namespace
from enum import Enum
from pprint import pformat
from typing import NewType, Union

import dataclasses
from dataclasses import MISSING, Field
from typeguard import check_type, get_type_name, isclass

from .logger import LOGGER, open_file
from .utils import DataClassMeta, DotDict, Singleton

# Suffix appended to a branch name to form the options field of a BranchSelect.
BRANCH_SUFFIX = '_options'
# Name-mangled attribute (OptionsBase.__training) holding the training flag; also the
# key under which a top-level state_dict() records that flag.
TRAINING_KEY = '_OptionsBase__training'
# Key under which argfield() stores its parser / active-time metadata in Field.metadata.
META_KEY = '__ARGFIELD__'

# Sentinels: REQUIRED is argfield()'s default and marks a field that must be set;
# IGNORED is returned while checking required fields that sit in an inactive
# BranchSelect branch.
REQUIRED = Singleton.create('Required')
IGNORED = Singleton.create('Ignored')
# Marker type for file-path fields; it is parsed as a plain str on the command line
# (no existence check is performed in this module).
ExistFile = NewType("ExistFile", str)


def issubclass_safe(cls, base):
    """Return whether ``cls`` is a class deriving from ``base``.

    Unlike ``issubclass`` this never raises for non-class arguments (e.g. typing
    constructs or instances); they simply yield False.
    """
    if not isclass(cls):
        return False
    return issubclass(cls, base)


def get_field_def(cls, key):
    """Look up the dataclass ``Field`` called ``key`` on ``cls``.

    ``cls`` may be a dataclass type or instance; None is returned when it is not a
    dataclass or has no such field.
    """
    try:
        known_fields = getattr(cls, dataclasses._FIELDS)
    except AttributeError:
        return None
    return known_fields.get(key)


def field_active_time(field_def):
    """Return the ``active_time`` stored by ``argfield`` on ``field_def``, or None."""
    argfield_meta = field_def.metadata.get(META_KEY, {})
    return argfield_meta.get('active_time')


def field_choices(field_def):
    """Return the ``choices`` stored by ``argfield`` on ``field_def``, or None."""
    argfield_meta = field_def.metadata.get(META_KEY, {})
    return argfield_meta.get('choices')


class FieldActiveTime(Enum):
    """Mode(s) in which an options field is active."""
    Train = 'train'
    Predict = 'predict'
    Both = 'both'

    def merge(self, other):
        """Combine two active times.

        ``None`` leaves ``self`` unchanged, equal values stay as they are, and any
        two different values combine to ``Both``.
        """
        if other is None or other is self:
            return self
        return FieldActiveTime.Both


# Fallback parser metadata, used by _add_argument_to_parser for fields whose metadata
# has no META_KEY entry (i.e. fields not created by argfield()). It has the same keys
# argfield() stores, but the active time defaults to Both (argfield() uses 'train').
DEFAULT_META = DotDict({
    'parser_names': None,
    'choices': None,
    'help': None,
    'metavar': None,
    'nargs': None,
    'active_time': FieldActiveTime.Both,
    'predict_default': MISSING
})


def _normalize_namesapce(namespace):
    if namespace and not namespace.endswith('.'):
        return namespace + '.'
    return namespace


def pretty_format(obj, indent=0):
    """Pretty-print an options-like object.

    ``OptionsBase`` instances use their own ``pretty_format``. Other dataclasses
    (after logging a warning) and ``argparse.Namespace`` objects are converted with
    ``vars``; the result, or any other object, goes through ``pprint.pformat``.
    """
    is_dc = dataclasses.is_dataclass(obj)
    if is_dc and isinstance(obj, OptionsBase):
        return obj.pretty_format(indent)

    if is_dc:
        LOGGER.warning(f'{get_type_name(obj.__class__)} should inherent Options')
        obj = vars(obj)
    elif isinstance(obj, Namespace):
        obj = vars(obj)
    return pformat(obj, indent)


class OptionsMeta(DataClassMeta):
    """Metaclass that turns the annotated attributes of a class into ``argfield``s.

    Class keyword arguments:
      * ``extra_fields`` -- mapping ``name -> type`` of additional annotations.
      * ``active_time`` -- a ``FieldActiveTime`` (or its string value) merged into
        the active time of every field declared in this class body.

    After creation, ``cls.ACTIVE_TIME`` is the merge of the active times of the
    declared fields, or ``FieldActiveTime.Train`` when there are none. It is further
    merged with the ``ACTIVE_TIME`` of every ``OptionsBase`` base class.
    """
    def __new__(cls, name, bases, attrs, **kwargs):
        annotations = attrs.get('__annotations__')
        extra_fields = kwargs.pop('extra_fields', None)
        if extra_fields is not None:
            if annotations is None:
                # The class body has no annotations. Store a copy in ``attrs`` so the
                # dataclass machinery sees the extra fields. Otherwise the argfield()
                # defaults created below would be fields without annotations.
                annotations = attrs['__annotations__'] = dict(extra_fields)
            else:
                annotations.update(extra_fields)

        current_active_time = None
        active_time = kwargs.pop('active_time', None)
        if active_time is not None:
            active_time = FieldActiveTime(active_time)
        if annotations is not None:
            for attr_name, attr_type in annotations.items():
                field = attrs.get(attr_name, MISSING)

                is_options = issubclass_safe(attr_type, OptionsBase)
                if field is MISSING:
                    attrs[attr_name] = field = argfield()
                elif not isinstance(field, dataclasses.Field):
                    # a bare default value: wrap it into an argfield
                    attrs[attr_name] = field = argfield(field)

                field.type = attr_type  # set type explicitly
                if is_options:
                    # nested options default to a fresh instance of their own class
                    if field.default is REQUIRED and field.default_factory is MISSING:
                        field.default = MISSING
                        field.default_factory = attr_type

                # NOTE(review): assumes the field was created by argfield(); a plain
                # dataclasses.field() has no META_KEY entry and raises KeyError here.
                meta = field.metadata[META_KEY]
                meta.active_time = meta.active_time.merge(active_time)
                if is_options:
                    meta.active_time = attr_type.ACTIVE_TIME.merge(meta.active_time)

                current_active_time = meta.active_time.merge(current_active_time)

        cls = super().__new__(cls, name, bases, attrs, **kwargs)
        cls.ACTIVE_TIME = current_active_time or FieldActiveTime.Train

        for base in bases:
            if issubclass(base, OptionsBase):
                cls.ACTIVE_TIME = cls.ACTIVE_TIME.merge(base.ACTIVE_TIME)
        return cls


class OptionsBase(metaclass=OptionsMeta):
    """Base class of option dataclasses.

    Every field carries an active time (see ``FieldActiveTime``). Only the fields
    active in the current mode take part in ``iter_items``, ``state_dict``,
    ``pretty_format`` and the other traversal helpers. The mode is set by
    ``train()`` / ``eval()`` and read through ``training``. Assignments made from
    outside the instance are validated by ``check_key``.
    """
    @property
    def training(self):
        """True in training mode; instances start in training mode."""
        return getattr(self, TRAINING_KEY, True)

    def train(self):
        """Switch this instance and its active nested options to training mode."""
        self.__training = True
        self._apply_items(options_fn=lambda _, v: v.train())

    def eval(self):
        """Switch this instance and its active nested options to prediction mode."""
        self.__training = False
        self._apply_items(options_fn=lambda _, v: v.eval())

    def state_dict(self, prefix='', allow_unset=False):
        """Return a plain-dict snapshot of the fields active in the current mode.

        Nested options become nested dicts, and other dataclasses are stored as
        ``vars(value)``. The top-level dict (``prefix == ''``) also records the
        training flag under ``TRAINING_KEY``.

        :param prefix: dotted path of this instance; used in error messages.
        :param allow_unset: when False, a field still holding ``REQUIRED`` raises
            ``AssertionError``. The flag is forwarded to nested options.
        """
        destination = {}
        if prefix == '':
            destination[TRAINING_KEY] = self.training

        def _set(k, v):
            destination[k] = v

        def _check(k, v):
            assert allow_unset or v is not REQUIRED, \
                f'field "{prefix}{k}" of a option object should be set'

        self._apply_items(lambda k, v: _check(k, v) or _set(k, vars(v)),
                          lambda k, v: _check(k, v) or _set(
                              k, v.state_dict(f'{prefix}{k}.', allow_unset)),
                          lambda k, v: _check(k, v) or _set(k, v))

        return destination

    def _load_from_state_dict(self, state_dict, training):
        """Restore the active fields from ``state_dict``, popping the keys it uses.

        ``state_dict`` must be a dict this call may consume. Nested dicts are
        copied before being consumed, so the caller's nested data stays intact.
        """
        self.__training = training

        def _set(k, _):
            setattr(self, k, copy.deepcopy(state_dict.pop(k)))

        self._apply_items(_set,
                          (lambda k, v: v._load_from_state_dict(dict(state_dict.pop(k)),
                                                                training)),
                          _set)

    def load_state_dict(self, state_dict):
        """Load a dict produced by ``state_dict()``; the argument is not modified.

        :raises KeyError: if the training flag or an active field is missing.
        :raises AssertionError: if top-level keys are left over after loading.
        """
        state_dict = state_dict.copy()

        self._load_from_state_dict(state_dict, state_dict.pop(TRAINING_KEY))

        assert not state_dict, f'{state_dict.keys()} are redundant'

    def diff_options(self, other):
        """List differences with ``other`` as ``(dotted_key, self_value, other_value)``.

        Nested options are compared recursively. Keys present only on one side use
        ``dataclasses.MISSING`` for the absent value.
        """
        assert isinstance(other, self.__class__), 'Can not diff intances of different class'
        d1 = dict(self.iter_items())
        d2 = dict(other.iter_items())

        diffs = []
        for key, v1 in d1.items():
            v2 = d2.pop(key, dataclasses.MISSING)
            if isinstance(v1, OptionsBase) and isinstance(v2, OptionsBase):
                diffs.extend((f'{key}.' + diff[0], *diff[1:]) for diff in v1.diff_options(v2))
            elif v1 != v2:
                diffs.append((key, v1, v2))

        diffs.extend((key, dataclasses.MISSING, v2) for key, v2 in d2.items())
        return diffs

    def merge_options(self, other):
        """Copy ``other``'s active field values onto this instance (recursing into options)."""
        assert isinstance(other, self.__class__), 'Can not merge instances of different class'
        for key, other_value in other.iter_items():
            if isinstance(other_value, OptionsBase):
                getattr(self, key).merge_options(other_value)
            else:
                setattr(self, key, other_value)

    def iter_fields(self, training=True):
        """Yield the dataclass fields active in the given mode.

        ``Train`` fields are skipped when ``training`` is False and ``Predict``
        fields when it is True. ``Both`` fields, and fields without an argfield
        active time, are always yielded.
        """
        for field_def in dataclasses.fields(self):
            active_time = field_active_time(field_def)
            if active_time is FieldActiveTime.Train and not training:
                continue
            if active_time is FieldActiveTime.Predict and training:
                continue
            yield field_def

    def iter_items(self):
        """Yield ``(name, value)`` for every field active in the current mode."""
        for field_def in self.iter_fields(self.training):
            key = field_def.name
            yield key, getattr(self, key)

    def _apply_items(self, dataclass_fn=None, options_fn=None, default_fn=None):
        """Dispatch each active field to a callback chosen by the kind of its value.

        ``options_fn(key, value)`` handles ``OptionsBase`` values.
        ``dataclass_fn`` handles other dataclasses, after logging a warning.
        ``default_fn`` handles everything else. A missing callback skips that kind.
        """
        for field_def in self.iter_fields(self.training):
            key = field_def.name
            value = getattr(self, key)
            if dataclasses.is_dataclass(value):
                if isinstance(value, OptionsBase):
                    if options_fn is not None:
                        options_fn(key, value)
                else:
                    LOGGER.warning(f'{value.__class__.__qualname__} should inherent Options')
                    if dataclass_fn is not None:
                        dataclass_fn(key, value)
            else:
                if default_fn is not None:
                    default_fn(key, value)

    def has_attr(self, key):
        """Return True if ``key`` is a field that is active in the current mode."""
        field_def = get_field_def(self, key)
        if field_def is not None:
            active_time = field_active_time(field_def)
            if (active_time is FieldActiveTime.Train and not self.training) or \
               (active_time is FieldActiveTime.Predict and self.training):
                return False
            return True
        return False

    def __setattr__(self, key, value):
        """Validate and propagate assignments made from outside the instance.

        An assignment counts as external when the calling frame's local ``self`` is
        not this instance, so the dataclass ``__init__`` and this class's methods
        bypass the checks. External assignments to names not starting with ``__``
        must target an existing field and pass ``check_key``. An externally
        assigned nested options object is switched to this instance's mode.
        """
        # ignore setattr from self
        current_frame = inspect.currentframe()
        outside_self = current_frame.f_back.f_locals.get('self') is not self
        if outside_self and not key.startswith('__'):
            field_def = get_field_def(self, key)
            if field_def is None:
                raise KeyError(f'{self.__class__.__qualname__} has no attribute \"{key}\"')
            self.check_key(field_def, value)

        ret = super().__setattr__(key, value)
        if outside_self and isinstance(value, OptionsBase):
            if self.training:
                value.train()
            else:
                value.eval()

        return ret

    def pretty_format(self, indent=0, top_level=True):
        """Return a multi-line ``ClassName(field=value, ...)`` representation.

        Plain fields come first, followed by dataclass / nested options fields,
        which are indented two more spaces per level. Plain fields whose name
        starts with '_' are shown only at the top level.
        """
        fields1 = []
        fields2 = []

        self._apply_items(
            lambda k, v: fields2.append((k, pformat(vars(v)))),
            lambda k, v: fields2.append((k, v.pretty_format(indent + 2, False))),
            lambda k, v: (top_level or not k.startswith('_')) and fields1.append((k, repr(v))))

        fields1 += fields2
        indent_str = ' ' * indent
        ret = f'{self.__class__.__qualname__}(\n'
        ret += ',\n'.join('{}  {}={}'.format(indent_str, *field) for field in fields1)
        ret += f'\n{indent_str})'
        return ret

    @classmethod
    def check_key(cls, field_def, value):
        """Validate ``value`` before it is assigned to the field ``field_def``.

        The value is type-checked against the field annotation with typeguard's
        ``check_type``. If the field declares ``choices``, the value must be one of
        them, or, for multi-valued fields, an iterable other than a string whose
        elements all are. ``None`` and ``REQUIRED`` skip the choices check.

        :raises KeyError: if the value is not allowed by ``choices``.
        """
        key = field_def.name
        metadata = field_def.metadata.get(META_KEY, {})
        choices = metadata.get('choices', None)
        if field_def.type is not None:
            check_type(key, value, field_def.type)

        if choices is None or value is None or value is REQUIRED:
            return

        try:
            if value in choices:
                return
        except TypeError:  # e.g. an unhashable list tested against a set of choices
            pass

        # multi-valued fields (nargs='*', List[...]) are valid when every item is a choice
        if hasattr(value, '__iter__') and not isinstance(value, (str, bytes)) and \
                all(x in choices for x in value):
            return

        choices = '{' + ','.join(map(str, choices)) + '}'
        raise KeyError(f'Invalid value "{value}" for {cls.__qualname__}.{key}. '
                       f'Must chosen from {choices}')

    def to_file(self, path, logger=LOGGER):
        """Save ``state_dict()`` to ``path`` as JSON.

        If the state is not JSON-serializable, it is pickled to ``path`` instead and
        a readable ``pretty_format()`` dump is written to ``path + '.txt'``.
        """
        options = self.state_dict()
        try:
            with open_file(path, 'w') as fp:
                json.dump(options, fp, indent=2)
        except (TypeError, ValueError):  # not JSON-serializable (or circular)
            logger.warning('Can not save config to json format, use pickle instead.')
            with open_file(path, 'wb') as fp:
                pickle.dump(options, fp)
            with open_file(path + '.txt', 'w') as fp:
                fp.write(self.pretty_format())

    @classmethod
    def from_file(cls, path):
        """Create an instance from a file written by ``to_file`` (JSON or pickle).

        WARNING: the pickle fallback can execute arbitrary code; only load
        trusted files.
        """
        try:
            with open_file(path, 'r') as fp:
                saved_state = json.load(fp)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # a binary pickle read in text mode may fail decoding before JSON parsing
            with open_file(path, 'rb') as fp:
                saved_state = pickle.load(fp)

        options = cls()
        options.load_state_dict(saved_state)
        return options


def _remove_none_value(kwargs):
    for k in [k for k, v in kwargs.items() if v is None]:
        del kwargs[k]


def _normalize_action(name, default):
    if default is None:
        raise Exception('bool option should has a default value')

    if default is True:
        action = 'store_false'
        if name.startswith('disable_'):
            name = f'enable_{name[8:]}'
        elif name.startswith('enable_'):
            name = f'disable_{name[7:]}'
        else:
            name = f'disable_{name}'
    else:
        action = 'store_true'

    return name, action


def _normalize_type(field_type, nargs, choices):
    """Map a field annotation to the ``type`` / ``nargs`` pair for ``add_argument``.

    :param field_type: the field annotation.
    :param nargs: ``nargs`` from the field metadata (may be None).
    :param choices: ``choices`` from the field metadata (may be None).
    :returns: ``(arg_type, nargs)``. ``arg_type`` is None when no single converter
        can be derived (e.g. ``Union[int, float]``). ``nargs`` is set to ``'*'``
        for list-like annotations only when the caller gave none.
    """
    arg_type = None
    if field_type is ExistFile:
        arg_type = str
    elif field_type in (list, set, frozenset):
        # bare containers: collect raw strings
        arg_type = str
        nargs = nargs or '*'
    elif getattr(field_type, '__origin__', None) is Union and \
            field_type.__args__[1] is type(None):
        # type annotation like Optional[int], i.e. Union[int, None].
        # Compare by identity: isinstance(None, X) raises TypeError when X is a
        # subscripted generic, e.g. for Union[int, List[int]].
        arg_type = field_type.__args__[0]
    elif hasattr(field_type, '__args__'):
        if len(field_type.__args__) == 1:
            # type annotation like List[int] or Set[str]
            arg_type = field_type.__args__[0]
            nargs = nargs or '*'
        elif str in field_type.__args__:
            # type annotation like Union[str, int]: keep the raw string
            arg_type = str
        # other multi-argument annotations (e.g. Union[int, float]) get no converter
    elif choices is not None:
        # infer the converter from the first choice
        arg_type = type(list(choices)[0])
    else:
        arg_type = field_type

    return arg_type, nargs


def _normalize_parser_names(name, parser_names):
    if parser_names is None:
        return [f'--{name}']

    if isinstance(parser_names, str):
        parser_names = [parser_names]

    if all(_.startswith('-') for _ in parser_names):
        if all(len(_) == 2 for _ in parser_names):  # abbrev
            parser_names.append(f'--{name}')
    else:
        assert len(parser_names) == 1, 'positional argument should have a unique name'

    for index, name in enumerate(parser_names):
        if not name.startswith('-'):  # name of positional arguments should be valid identity
            parser_names[index] = name.replace('-', '_')

    return parser_names


def _add_argument_to_parser(parser, field_def, default_instance, abbrevs, namespace, group):
    """Register one field of ``default_instance`` as an argument in ``group``.

    Nested ``OptionsBase`` fields are expanded recursively, via
    ``_add_arguments_to_parser``, into their own argument group under a dotted
    namespace. Option-style arguments use the dotted full field name as ``dest``.

    :param parser: top-level parser; nested groups are added to it.
    :param field_def: dataclass ``Field`` being registered.
    :param default_instance: options instance whose current values become defaults.
    :param abbrevs: mapping from field name to option strings, or to a nested
        mapping for nested options fields.
    :param namespace: dotted prefix of ``default_instance`` ('' at top level).
    :param group: argparse argument group receiving the argument.
    """
    namespace = _normalize_namesapce(namespace)
    # fields not created by argfield() fall back to DEFAULT_META
    meta = field_def.metadata.get(META_KEY, DEFAULT_META)

    field_name = field_def.name
    field_type = field_def.type
    full_name = f'{namespace}{field_name}'
    kwargs = {}

    assert not isinstance(field_type, str), f'forward declaration is not allow. {full_name}'

    # default value: the instance's current value, or predict_default in eval mode
    default = getattr(default_instance, field_name, field_def.default)
    if not default_instance.training and meta.predict_default is not MISSING:
        default = meta.predict_default

    if default is REQUIRED or default is MISSING:
        # nested options without a factory use their own class as factory
        # (NOTE: this mutates the shared field definition)
        if field_def.default_factory is MISSING and issubclass_safe(field_type, OptionsBase):
            field_def.default_factory = field_type
        try:
            default = field_def.default_factory()
        except Exception:
            # best effort: with no (working) factory the value stays unset
            pass

    required = default is REQUIRED or default is MISSING
    if required:
        default = None

    # this field's abbreviations: option strings for a leaf field, or a nested
    # mapping for nested options
    abbrevs = abbrevs.get(field_name) or abbrevs.get(field_name.replace('-', '_')) or {}
    if issubclass_safe(field_type, OptionsBase):
        _add_arguments_to_parser(parser, default, abbrevs=abbrevs, namespace=full_name)
        return

    # help text with argparse's %-style placeholders
    help_str = meta.help if meta.help is not None else ''
    if default is not None:
        help_str += ' (default: %(default)s)'
    if meta.choices is not None:
        help_str += ' {%(choices)s}'
    if field_type is not None:
        help_str += f' <type: %(type)s>'

    action = None
    # dest is fixed before a bool flag is renamed, so it keeps the field's name
    kwargs['dest'] = full_name
    if issubclass_safe(field_type, bool):
        field_name, action = _normalize_action(field_name, default)
        full_name = f'{namespace}{field_name}'

    field_type, nargs = _normalize_type(field_type, meta.nargs, meta.choices)

    parser_names = abbrevs or meta.get('parser_names')
    parser_names = _normalize_parser_names(full_name.replace('_', '-'), parser_names)

    kwargs.update({
        'default': default,
        'nargs': nargs,
        'action': action,
        'help': help_str,
    })

    if parser_names[0][0] == '-':
        kwargs['required'] = required
    else:
        kwargs.pop('dest')  # avoid to supply dest twice

    if not action:
        # metavar/type/choices are only set when no store_true/store_false action is used
        kwargs['metavar'] = meta.metavar or field_name.rsplit('.', 1)[-1]  # inner most name
        kwargs['type'] = field_type
        kwargs['choices'] = meta.choices

    # drop unset keywords so argparse applies its own defaults
    _remove_none_value(kwargs)
    group.add_argument(*parser_names, **kwargs)


def _add_arguments_to_parser(parser, default_instance, abbrevs, namespace):
    """Register the active fields of ``default_instance`` in a new argument group.

    The group is titled ``<namespace>[<ClassName>]``. Each field is added with a
    dotted ``namespace`` prefix; nested options recurse back here through
    ``_add_argument_to_parser``.
    """
    class_name = default_instance.__class__.__qualname__
    group = parser.add_argument_group(title=f'{namespace}[{class_name}]')
    prefix = _normalize_namesapce(namespace)

    # Call OptionsBase.iter_fields explicitly: BranchSelect overrides iter_fields to
    # hide inactive branches, but the parser needs the options of every branch.
    candidate_fields = OptionsBase.iter_fields(default_instance, default_instance.training)
    for field_def in candidate_fields:
        _add_argument_to_parser(parser, field_def, default_instance,
                                abbrevs=abbrevs, namespace=prefix, group=group)


def argfield(default=REQUIRED, *, default_factory=MISSING,
             # argparse arguments
             help=None, choices=None, metavar=None, nargs=None, names=None,
             # when the field is active
             active_time='train', predict_default=MISSING,
             # the same arguments as dataclass.field
             init=True, repr=True, hash=None, compare=True, metadata=None):
    """Create a dataclass field that carries command-line parser metadata.

    :param default: default value. ``REQUIRED`` (the default) marks a field that
        must be set. When left as ``REQUIRED`` and ``default_factory`` is given, the
        factory is used instead. Passing both an explicit default and a factory
        raises ``ValueError`` (from ``dataclasses.field``).
    :param default_factory: zero-argument callable producing the default.
    :param help: help text passed to ``add_argument``.
    :param choices: allowed values passed to ``add_argument``.
    :param metavar: metavar passed to ``add_argument``.
    :param nargs: nargs passed to ``add_argument``.
    :param names: option string(s) to use instead of ``--<field-name>``.
    :param active_time: 'train', 'predict' or 'both' (see ``FieldActiveTime``).
    :param predict_default: default used when the parser is set up in prediction mode.
    :param init, repr, hash, compare, metadata: as for ``dataclasses.field``. The
        given ``metadata`` is copied, not modified.
    :returns: a ``dataclasses.Field`` whose ``metadata[META_KEY]`` holds the parser
        settings.
    """
    if default is REQUIRED and default_factory is not MISSING:
        default = MISSING

    metadata = dict(metadata) if metadata is not None else {}
    metadata[META_KEY] = DotDict({
        'parser_names': names,
        'choices': choices,
        'help': help,
        'metavar': metavar,
        'nargs': nargs,
        'active_time': FieldActiveTime(active_time),
        'predict_default': predict_default
    })

    # Build through dataclasses.field(): calling Field(...) positionally breaks on
    # Python >= 3.10, whose Field.__init__ also requires a ``kw_only`` argument.
    return dataclasses.field(default=default, default_factory=default_factory, init=init,
                             repr=repr, hash=hash, compare=compare, metadata=metadata)


def read_config(options, config_path):
    """Execute a Python config file and return the public names it defines.

    ``options`` is the local namespace of the executed code, so it supplies the
    starting values and is updated in place. Config code can use:

    * ``import_config(path)`` -- execute another config file. The path is relative
      to the including file's directory, and each file runs at most once.
    * ``_glob`` -- ``glob.glob``.
    * ``_glob_one(pattern)`` -- the single match of a glob pattern; an
      ``AssertionError`` is raised otherwise.

    :returns: dict of the entries of ``options`` whose key does not start with '_'.

    WARNING: config files are run with ``exec``; only load trusted files.
    """
    visited_paths = set()

    def _load(path):
        # canonical path, so 'a.py', './a.py' and symlinks are deduplicated
        path = os.path.realpath(path)
        if path in visited_paths:
            return

        visited_paths.add(path)
        cur_dir = os.path.dirname(path)

        def _include(base_path):
            _load(os.path.join(cur_dir, base_path))

        def _glob_one(name):
            paths = glob.glob(name)
            assert len(paths) == 1, f'find multiple or no result(s) of "{name}": {paths}'
            return paths[0]

        with open(path) as fp:
            code = fp.read()

        exec(compile(code, path, 'exec'),
             {'import_config': _include, '_glob': glob.glob, '_glob_one': _glob_one},
             options)

    _load(os.path.abspath(config_path))
    return {key: value for key, value in options.items() if not key.startswith('_')}


def _merge_namespace_and_instance(namespace, instance, required_fields):
    """Copy parsed command-line values onto ``instance`` and check required fields.

    Each attribute of ``namespace`` is a dotted field path, with dashes turned into
    underscores. Its value is assigned to the matching attribute of ``instance`` or
    of its nested options. Paths that do not resolve to an existing attribute are
    collected as extras.

    :param namespace: ``argparse.Namespace`` holding the parsed values.
    :param instance: root options instance, updated in place.
    :param required_fields: dotted names of the arguments marked as required.
    :returns: dict of unresolved ``name -> value`` pairs.
    :raises AssertionError: if a required field still holds ``REQUIRED``. Fields
        inside an inactive ``BranchSelect`` branch are exempt.
    """
    def _get_value(sub_names, default=None, ignored=None):
        # Walk the dotted path from ``instance``. Stop and return ``default`` as soon
        # as an attribute is missing (or equals ``default``). When ``ignored`` is
        # given, return it instead once the path enters an inactive BranchSelect branch.
        value = instance
        for sub_name in sub_names:

            if ignored is not None and \
               isinstance(value, BranchSelect) and value.is_inactive(sub_name):
                # inactive branch in BranchSelect should be ignored in strict mode
                return ignored

            value = getattr(value, sub_name, default)
            if value is default:
                break

        return value

    def _get_scope(name):
        # split 'a.b.c' into (object at a.b, 'c')
        sub_names = name.split('.')
        return _get_value(sub_names[:-1]), sub_names[-1]

    extras = {}
    for name, value in vars(namespace).items():
        name = name.replace('-', '_')

        scope, sub_name = _get_scope(name)

        if hasattr(scope, sub_name):
            setattr(scope, sub_name, value)
        else:
            extras[name] = value

    # a required field is missing when its path still resolves to REQUIRED
    required_fields = [name for name in required_fields
                       if _get_value(name.split('.'), REQUIRED, IGNORED) is REQUIRED]
    assert not required_fields, f'the following arguments are required: {required_fields}'

    return extras


class SingleOptionsParser(ArgumentParser):
    """``ArgumentParser`` that parses the command line into an ``OptionsBase`` instance.

    Call ``setup()`` to register the fields of an options class, then
    ``parse_known_args()``, which returns the populated options instance in place
    of a ``Namespace``. An argument starting with ``config_prefix`` (``@@path`` or
    ``@@ path``) names a Python config file. It is applied via ``read_config``
    before the remaining command-line arguments. Values that do not map to a field
    are exposed as ``self.extra_options``.
    """
    def __init__(self, *args, config_prefix='@@', is_subparser=False, **kwargs):
        super().__init__(*args, **kwargs)

        self.is_subparser = is_subparser
        self.config_path = None
        self.config_prefix = config_prefix
        self.training = True

        self._reset()

    def _reset(self):
        """Clear the per-parse state so the parser can be reused."""
        self.instance = None
        self.extra_options = None
        self.config_path = None

    def setup(self, options_class=None, default_instance=None, abbrevs=None, training=True):
        """Register the fields of ``options_class`` as command-line arguments.

        :param options_class: options class to parse into; defaults to the class of
            ``default_instance``.
        :param default_instance: instance whose values become the defaults; defaults
            to ``options_class()``. It is switched to the requested mode.
        :param abbrevs: optional mapping of field name -> option strings, with
            nested mappings for nested options.
        :param training: parse in training (True) or prediction (False) mode.
        """
        if options_class is None:
            options_class = default_instance.__class__
        assert issubclass_safe(options_class, OptionsBase)

        if default_instance is None:
            default_instance = options_class()

        self.training = training
        if training:
            default_instance.train()
        else:
            default_instance.eval()

        # Snapshot the defaults after switching modes, so the snapshot holds the
        # fields that are active while parsing (e.g. predict-only fields in eval mode).
        self.default_state = default_instance.state_dict(allow_unset=True)

        self.options_class = options_class

        _add_arguments_to_parser(self, default_instance, abbrevs=abbrevs or {}, namespace='')

    def parse_known_args(self, argv=None, namespace=None):
        """Parse ``argv`` and return ``(options_instance, remaining_argv)``.

        Exits with status 1 (after logging) when a required field is not set.
        """
        self._reset()

        options, argv = super().parse_known_args(argv, namespace)
        # NOTE: check required_fields manualy
        required_fields = [action.dest for action in self._actions if action.required]

        instance = self.instance
        try:
            # merge options from command line
            self.extra_options.update(
                _merge_namespace_and_instance(options, instance, required_fields))
        except AssertionError as err:
            LOGGER.error('%s', err)
            sys.exit(1)

        self.extra_options = Namespace(**self.extra_options)
        if self.is_subparser:
            instance.__config_path__ = self.config_path
            instance.__options_class__ = self.options_class

        if self.training:
            instance.train()
        else:
            instance.eval()

        return instance, argv

    def _parse_known_args(self, arg_strings, namespace, *args, **kwargs):
        # ``*args, **kwargs`` forward the extra parameters that newer argparse
        # versions pass to this hook (e.g. ``intermixed``).
        instance = self.options_class()
        instance.load_state_dict(self.default_state)

        # keep default values
        self.extra_options = {key: value for key, value in vars(namespace).items()
                              if '.' not in key}

        namespace = Namespace()  # NOTE: do not use default values of parser
        new_arg_strings = []
        index = 0
        while index < len(arg_strings):
            arg_string = arg_strings[index]
            index += 1
            # for regular arguments, just add them back into the list
            if not arg_string or not arg_string.startswith(self.config_prefix):
                new_arg_strings.append(arg_string)
                continue

            assert self.config_path is None, 'There should be only one config file.'
            # replace arguments referencing files with the file content
            config_path = arg_string[len(self.config_prefix):]
            if config_path == '':
                assert index < len(arg_strings), \
                    f'expect a config path after "{self.config_prefix}"'
                config_path = arg_strings[index]
                index += 1

            # merge options from config file
            instance = read_config(vars(instance), config_path)
            instance = self.options_class(**instance)
            self.config_path = config_path

        self.instance = instance

        # Required arguments are checked later by _merge_namespace_and_instance
        # (config files may supply them), so argparse must not enforce them here.
        required_actions = [action for action in self._actions if action.required]
        for action in required_actions:
            action.required = False

        try:
            options, argv = super()._parse_known_args(new_arg_strings, namespace,
                                                      *args, **kwargs)
        finally:
            for action in required_actions:  # restore even if parsing fails
                action.required = True

        return options, argv


class BranchSelectMeta(OptionsMeta):
    """Metaclass of ``BranchSelect``: adds one options field per branch.

    The class body must define ``type`` (the selected branch name) and
    ``branches``, a dict mapping each branch name either to a branch type that
    exposes an ``Options`` class, or to a ``(branch_type, options_class)`` pair.
    For each branch, an ``<name>_options`` field (default ``options_class()``) is
    added unless the class body already defines it. ``branches`` is rewritten in
    place to map names to the bare branch types. ``OptionsBase`` is appended to
    the bases when it is missing.
    """
    def __new__(cls, name, bases, attrs, **kwargs):
        assert 'type' in attrs, 'default value of field "type" is required'
        assert 'branches' in attrs, 'default value of field "branches" is required'

        branches = attrs['branches']
        annotations = attrs.setdefault('__annotations__', {})
        annotations['type'] = str

        for branch_name, branch_item in branches.items():
            if isinstance(branch_item, (tuple, list)):
                branch_type, options_type = branch_item
            else:
                branch_type = branch_item
                options_type = branch_item.Options

            assert issubclass_safe(options_type, OptionsBase), \
                f'"{options_type}" should be a subclass of OptionsBase'

            # Normalize even when the options field is user-defined:
            # BranchSelect.create() calls branches[type] directly.
            # (Replacing values of existing keys while iterating is safe.)
            branches[branch_name] = branch_type

            key = branch_name + BRANCH_SUFFIX
            if key in attrs:
                continue

            attrs[key] = argfield(default_factory=options_type)
            annotations[key] = options_type

        if OptionsBase not in bases:
            bases += (OptionsBase,)

        return super().__new__(cls, name, bases, attrs, **kwargs)


class BranchSelect(metaclass=BranchSelectMeta):
    """Options that select one of several branches by name.

    Subclasses define ``branches`` and a default ``type``. ``BranchSelectMeta``
    adds one ``<name>_options`` field per branch. Only the options field of the
    selected branch (``self.type``) is visible from outside the instance and
    listed by ``iter_fields``.
    """
    # name of the selected branch; active in both training and prediction
    type: str = argfield('', active_time='both')
    # not annotated, so a plain class attribute rather than a dataclass field
    branches = {}

    def is_inactive(self, key):
        """Return True if ``key`` is the options field of a branch other than ``self.type``."""
        return key.endswith(BRANCH_SUFFIX) and \
            key[:-len(BRANCH_SUFFIX)] in self.branches and \
            key != self.type + BRANCH_SUFFIX

    def __getattribute__(self, key):
        """Hide the options of inactive branches from callers outside this instance.

        A caller counts as "inside" when its frame has a local ``self`` that is this
        instance, e.g. this class's own methods. The attribute lookups made here
        therefore do not recurse into the check.

        :raises AttributeError: for an inactive branch's options accessed from outside.
        """
        # ignore getattribute from self
        current_frame = inspect.currentframe()
        outside_self = current_frame.f_back.f_locals.get('self') is not self
        if outside_self and self.is_inactive(key):
            raise AttributeError(f'current branch is {self.type}')

        return super().__getattribute__(key)

    def create(self, *args, **kwargs):
        """Instantiate the selected branch.

        If the branch's options object has a ``create`` method, that method is
        called. Otherwise the result is ``branch_type(branch_options, *args, **kwargs)``.
        """
        branch_type = self.branches[self.type]

        branch_options = getattr(self, self.type + BRANCH_SUFFIX)
        if hasattr(branch_options, 'create'):
            return branch_options.create(*args, **kwargs)

        return branch_type(branch_options, *args, **kwargs)

    def iter_fields(self, training=True):
        """Yield the base class's active fields, minus the options of inactive branches."""
        for field_def in super().iter_fields(training):
            name = field_def.name
            if self.is_inactive(name):
                continue
            yield field_def
