# -*- coding: utf-8 -*-

import json
import pickle
from abc import abstractclassmethod, abstractmethod
from ast import literal_eval
from typing import IO, Any, Dict, Union

import dataclasses

from ..common.logger import LOGGER, smart_open_file
from ..common.utils import DataClassMeta, Singleton

# Sentinel value for ``DataSample.update_attrs``: passing ``key=DELETE``
# removes ``key`` from the sample's extra ``attrs`` dict.
DELETE = Singleton.create('DELETE')
# Sentinel that ``DataSample.from_string`` may return to make
# ``DataSample.from_standard_file`` ignore the current line.
SKIP = Singleton.create('SKIP')


class DataFeatures(metaclass=DataClassMeta):
    """Base class for features built from one original data sample.

    Subclasses must implement ``create`` and ``pack_to_batch``.  Optional
    plugin objects can post-process single samples
    (``postprocess_sample``) and whole batches (``postprocess_batch``).
    """

    # Index of the original sample this object was built from; the default
    # -1 presumably means "not set" — TODO confirm with callers.
    original_index: int = -1
    # The original object this object was built from.
    original_object: Any = None
    # Free-form extra attributes.
    attrs: Dict[str, Any] = dataclasses.field(default_factory=dict)

    @classmethod
    def run_plugins_for_sample(cls, sample, plugins, **kwargs):
        """Call ``plugin.postprocess_sample(sample, **kwargs)`` on each plugin.

        :param sample: the object handed to every plugin.
        :param plugins: iterable of plugin objects; ``None`` or empty is a
            no-op.
        """
        if not plugins:
            return
        for plugin in plugins:
            plugin.postprocess_sample(sample, **kwargs)

    @classmethod
    def run_plugins_for_batch(cls, batch_samples, feed_dict, plugins,
                              **kwargs):
        """Call ``plugin.postprocess_batch(batch_samples, feed_dict, **kwargs)``.

        Plugins are invoked in iteration order.  ``plugins`` may be ``None``
        or empty, in which case nothing happens.
        """
        if not plugins:
            return
        for plugin in plugins:
            plugin.postprocess_batch(batch_samples, feed_dict, **kwargs)

    # ``@classmethod`` stacked on ``@abstractmethod`` replaces the
    # ``abc.abstractclassmethod`` decorator, deprecated since Python 3.3.
    @classmethod
    @abstractmethod
    def create(cls, original_index, original_object, plugins, **kwargs):
        """Build an instance for one original sample; subclasses implement."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def pack_to_batch(cls, batch_samples, plugins):
        """Combine several instances into one batch; subclasses implement."""
        raise NotImplementedError


class DataSample(metaclass=DataClassMeta):
    """Base class for one data record that can be read from / written to files.

    The class is handled as a dataclass (``dataclasses.fields(cls)`` is used
    below).  Values that are not declared fields live in the ``attrs`` dict
    and can be read as if they were attributes, via ``__getattr__``.

    File formats are dispatched by name: ``from_file(..., input_format=X)``
    calls ``from_X_file`` with a text-mode handle or, if that method does
    not exist, ``from_X_binary_file`` with a binary-mode handle.
    ``to_file`` dispatches the same way with ``to_X_file`` /
    ``to_X_binary_file``.
    """

    # Prefix of "key: value" attribute lines in the standard text format.
    # (The spelling is kept because subclasses may override this name.)
    COMMONT_PREFIX = '# '

    # Extra attributes that are not declared fields.
    attrs: Dict[str, Any] = dataclasses.field(default_factory=dict)

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails.  'attrs' is
        # excluded so that an instance without an ``attrs`` entry (e.g. one
        # whose __dict__ is not populated yet) raises AttributeError instead
        # of recursing forever.
        if name == 'attrs' or name not in self.attrs:
            raise AttributeError(name)
        return self.attrs[name]

    def update_attrs(self, **new_attrs):
        """Update this sample in place from keyword arguments.

        For each ``key=value``:

        * ``value is DELETE``: remove ``key`` from the extra ``attrs`` dict.
          This is ignored if the key is absent; instance attributes are never
          deleted this way.
        * ``key`` is already an instance attribute (anything in
          ``vars(self)``, including ``attrs`` itself): overwrite it.
        * otherwise: store the value in the extra ``attrs`` dict.
        """
        attrs = vars(self)
        extra_attrs = self.attrs
        for key, value in new_attrs.items():
            if value is DELETE:
                if key in extra_attrs:
                    del extra_attrs[key]
            elif key in attrs:
                attrs[key] = value
            else:
                extra_attrs[key] = value

    def copy_attrs(self, **new_attrs):
        """Return a copy of this sample with ``new_attrs`` applied.

        The copy is shallow: attribute values are shared with ``self``, but
        the ``attrs`` dict itself is duplicated, so updates to the copy's
        extra attributes do not leak back into ``self``.  The copy is built
        as ``self.__class__(**vars(self))``, which assumes every instance
        attribute is also a constructor argument.
        """
        kwargs = vars(self).copy()
        kwargs['attrs'] = kwargs['attrs'].copy()

        sample = self.__class__(**kwargs)
        sample.update_attrs(**new_attrs)
        return sample

    @classmethod
    def from_pickle_binary_file(cls, fp):
        """Load whatever object was pickled into the binary file ``fp``."""
        # NOTE(review): pickle.load can execute arbitrary code; only use this
        # on files from a trusted source.
        return pickle.load(fp)

    @classmethod
    def to_pickle_binary_file(cls, fp, objects):
        """Pickle ``objects`` into the binary file ``fp``."""
        pickle.dump(objects, fp)

    @classmethod
    def from_json_file(cls, fp):
        """Read a JSON array of objects; each object becomes ``cls(**item)``."""
        items = []
        for item in json.load(fp):
            items.append(cls(**item))
        return items

    @classmethod
    def to_json_file(cls, fp, objects):
        """Write ``objects`` as a JSON array of ``{field_name: value}`` dicts.

        Only declared dataclass fields are written.  The ``attrs`` field is
        one of them, so extra attributes are included.  All values must be
        JSON-serializable.
        """
        fields = dataclasses.fields(cls)
        objects = [{field.name: getattr(item, field.name) for field in fields}
                   for item in objects]
        return json.dump(objects, fp, indent=2)

    @classmethod
    def from_standard_file(cls, fp, **kwargs):
        """Parse samples from the line-oriented "standard" text format.

        A line starting with ``COMMONT_PREFIX`` holds ``key: value``.  The
        value is parsed with ``ast.literal_eval``, and these pairs accumulate
        until the next sample line.  Every other line is passed unchanged
        (as yielded by ``fp``) to
        ``from_string(index, line, attrs, **kwargs)``, where ``index`` is the
        0-based line number.  If ``from_string`` returns ``SKIP``, the line
        is ignored.  Otherwise the accumulated pairs are merged into the new
        object's ``attrs`` and the accumulator is reset.

        :raises NotImplementedError: re-raised as is, without logging.
        :raises Exception: any other error is logged together with the
            0-based index of the line being processed, then re-raised.
        """
        objects = []

        attrs = {}
        # Pre-bound so the error handler below can always report a position,
        # even when reading the very first line from ``fp`` fails.  Without
        # this it would raise UnboundLocalError and hide the real error.
        index = 0
        try:
            for index, line in enumerate(fp):
                if line.startswith(cls.COMMONT_PREFIX):
                    line = line[len(cls.COMMONT_PREFIX):]
                    key, value = map(str.strip, line.split(':', 1))
                    attrs[key] = literal_eval(value)
                else:
                    obj = cls.from_string(index, line, attrs, **kwargs)
                    if obj is SKIP:
                        # NOTE(review): ``attrs`` is not reset here, so pairs
                        # collected before a skipped line are attached to the
                        # next sample -- confirm this is intended.
                        continue

                    obj.attrs.update(attrs)
                    objects.append(obj)
                    attrs = {}
        except NotImplementedError:
            raise
        except Exception:
            LOGGER.exception('Wrong format at line %s', index)
            raise

        return objects

    @classmethod
    def to_standard_file(cls, fp, objects, **kwargs):
        """Write ``objects`` in the "standard" text format.

        First writes ``file_header()``.  Then, for each object, writes one
        ``COMMONT_PREFIX + "key: repr(value)"`` line per extra attribute
        (keys sorted), followed by ``obj.to_string(**kwargs)`` and a newline.
        Values only read back with ``from_standard_file`` if their ``repr``
        is accepted by ``ast.literal_eval``.
        """
        fp.write(cls.file_header())
        for obj in objects:
            for key in sorted(obj.attrs.keys()):
                fp.write(f'{cls.COMMONT_PREFIX}{key}: {repr(obj.attrs[key])}\n')
            fp.write(obj.to_string(**kwargs) + '\n')

    @classmethod
    def from_string(cls, index, string, attrs, **kwargs):
        """Parse one non-attribute line into a sample, or return ``SKIP``.

        Subclasses implement this.  ``attrs`` holds the attribute lines read
        since the previous sample.
        """
        raise NotImplementedError

    def to_string(self, **kwargs):
        """Serialize this sample as its line in the standard format.

        Subclasses implement this.
        """
        raise NotImplementedError

    @classmethod
    def file_header(cls):
        """Text that ``to_standard_file`` writes first; empty by default."""
        return ''

    @classmethod
    def from_file(cls, file_name_or_handle: Union[str, IO[str]],
                  input_format='standard', **kwargs):
        """Read samples using ``from_<input_format>_file``.

        Falls back to ``from_<input_format>_binary_file`` with a binary-mode
        handle when the text variant does not exist.  If neither exists,
        ``AttributeError`` is raised.  Extra ``kwargs`` are forwarded to the
        reader.
        """
        mode = 'r'
        fn = getattr(cls, f'from_{input_format}_file', None)
        if fn is None:
            mode = 'rb'
            fn = getattr(cls, f'from_{input_format}_binary_file')

        with smart_open_file(file_name_or_handle, mode) as fp:
            return fn(fp, **kwargs)

    @classmethod
    def to_file(cls, file_name_or_handle: Union[str, IO[str]], objects,
                output_format='standard', **kwargs):
        """Write ``objects`` using ``to_<output_format>_file``.

        Falls back to ``to_<output_format>_binary_file`` with a binary-mode
        handle when the text variant does not exist.  If neither exists,
        ``AttributeError`` is raised.
        """
        mode = 'w'
        fn = getattr(cls, f'to_{output_format}_file', None)
        if fn is None:
            mode = 'wb'
            fn = getattr(cls, f'to_{output_format}_binary_file')

        with smart_open_file(file_name_or_handle, mode) as fp:
            fn(fp, objects, **kwargs)

    @classmethod
    def internal_evaluate(cls, gold_items, system_items, log_file):
        """Compare in-memory gold and system samples; subclasses implement."""
        raise NotImplementedError

    @classmethod
    def external_evaluate(cls, gold_file, system_file, log_file):
        """Load both files in the standard format and run ``internal_evaluate``."""
        return cls.internal_evaluate(cls.from_file(gold_file),
                                     cls.from_file(system_file), log_file)


def iter_batches(iterable, batch_size):
    """Split *iterable* into consecutive lists of at most *batch_size* items.

    Yields ``(start, batch_number, batch)`` triples:

    * ``start`` is the position of the batch's first item within
      *iterable*;
    * ``batch_number`` counts batches starting from 0;
    * ``batch`` is a new list, which this generator never mutates after
      yielding it.

    The final batch may be shorter than *batch_size*.  An empty batch is
    never yielded.
    """
    offset, number, current = 0, 0, []
    for element in iterable:
        current.append(element)
        if len(current) >= batch_size:
            yield offset, number, current
            # The right-hand side is evaluated before rebinding, so
            # len(current) still refers to the batch just yielded.
            offset, number, current = offset + len(current), number + 1, []

    if current:
        yield offset, number, current


def iter_sub_batches(iterable, batch_size, batch_sub_size):
    """Group items into lists bounded by both item count and total length.

    Items are collected in order.  The current group is yielded as soon as
    it holds ``batch_size`` items, or as soon as the summed ``len()`` of its
    items reaches ``batch_sub_size``, whichever comes first.  Because the
    length check happens after an item is added, a group can exceed
    ``batch_sub_size``.  The final, possibly smaller, group is yielded at
    the end.  Empty groups are never yielded.

    :param iterable: items supporting ``len()`` (presumably sentences given
        as token sequences).
    :param batch_size: maximum number of items per group.
    :param batch_sub_size: length budget; a group is closed once the total
        length of its items reaches it.
    """
    batch = []
    sub_size = 0
    # A plain for-loop replaces the former manual next()/StopIteration loop.
    # The item count is simply len(batch).
    for item in iterable:
        batch.append(item)
        sub_size += len(item)
        if len(batch) >= batch_size or sub_size >= batch_sub_size:
            yield batch
            batch = []
            sub_size = 0

    if batch:
        yield batch
