# -*- coding: utf-8 -*-

import functools
import glob
import gzip
import multiprocessing as mp
import os
from abc import ABCMeta, abstractclassmethod

from framework.common.dataclass_options import OptionsBase, argfield
from framework.common.logger import LOGGER
from framework.common.utils import ProgressReporter

from .graph_io import READERS

# TODO: modify shrg_extract.py to use ReaderBase


@functools.lru_cache()
def load_cfg_strings(prefix):
    """Collect the tree strings stored under ``prefix + {'train','dev','test'}``.

    Every file whose name starts with ``'wsj'`` in each of those three
    directories is read as a sequence of line pairs: a ``#<sentence id>``
    line followed by a tree-string line. Reading of a file stops at end of
    file or at the first line that is empty after stripping.

    Args:
        prefix: String prepended to ``'train'``, ``'dev'`` and ``'test'`` to
            form the three directory paths.

    Returns:
        A dict mapping each sentence id (without the leading ``'#'``) to a
        ``(bank_file_name, tree_string)`` tuple. Later entries overwrite
        earlier ones that share an id.

    Note:
        The result is memoized by ``functools.lru_cache``, so repeated calls
        with the same ``prefix`` return the very same dict object.
    """
    result = {}
    directories = [prefix + tag for tag in ('train', 'dev', 'test')]
    for directory in directories:
        wsj_banks = (name for name in os.listdir(directory) if name.startswith('wsj'))
        for bank in wsj_banks:
            with open(os.path.join(directory, bank)) as handle:
                # readline() returns '' only at end of file, which ends the loop.
                for header in iter(handle.readline, ''):
                    header = header.strip()
                    if not header:
                        # A blank line terminates the file just like EOF does.
                        break
                    assert header.startswith('#')
                    tree_string = handle.readline().strip()
                    result[header[1:]] = (bank, tree_string)

    return result


class ReaderBase(metaclass=ABCMeta):
    """Base class that loads gzip-compressed graph files grouped into splits.

    A split is a named collection of directories found by globbing under a
    data path. Every directory contributes one batch (the list of its
    ``*.gz`` files); batches are processed by ``_worker``, optionally in a
    ``multiprocessing`` pool. Subclasses implement ``build_graph`` to turn
    the raw reader output into whatever object should be stored per sample.
    """

    class Options(OptionsBase):
        # Key passed to ``READERS.normalize`` to pick the graph-reading function.
        graph_type: str = argfield(default='eds')

    def __init__(self, options: Options, data_path, split_patterns, logger=LOGGER, **extra_args):
        """Store the configuration and discover the files of every split.

        Args:
            options: Reader options; ``options.graph_type`` selects the reader.
            data_path: Root directory the split patterns are relative to.
            split_patterns: Iterable of ``(split_name, pattern)`` pairs, where
                ``pattern`` is a glob string or a list of glob strings.
            logger: Logger used for split statistics.
            **extra_args: Forwarded unchanged to ``build_graph`` as a dict.
        """
        self.logger = logger
        self.options = options

        self.extra_args = extra_args

        self._data = {}    # split name -> {sample id: build_graph output}
        self._splits = {}  # split name -> list of per-directory file lists

        self.build_splits(data_path, split_patterns)

    def build_splits(self, data_path, split_patterns):
        """Resolve each split's glob patterns into per-directory file lists.

        Args:
            data_path: Root directory joined with every pattern.
            split_patterns: Iterable of ``(split_name, pattern)`` pairs;
                ``pattern`` may be a single string or a list of strings.
        """
        for split, patterns in split_patterns:
            if isinstance(patterns, str):
                patterns = [patterns]
            # Flatten in pattern order without the quadratic ``sum(..., [])``.
            dirs = [path
                    for pattern in patterns
                    for path in glob.glob(os.path.join(data_path, pattern))]
            # One inner list per directory: each becomes one unit of work.
            files = [glob.glob(os.path.join(path, '*.gz')) for path in dirs]

            self.logger.info('SPLIT %s: %d directories, %d files',
                             split, len(dirs), sum(map(len, files)))
            self._splits[split] = files

    def on_error(self, filename, error):
        """Hook called for every file that failed to load; no-op by default.

        Args:
            filename: Path of the file that could not be processed.
            error: The exception instance raised while processing it.
        """
        pass

    def get_split(self, split, num_workers=-1, training=True):
        """Load (or return the cached) samples of one split.

        Args:
            split: Name of a split registered by ``build_splits``.
            num_workers: Number of worker processes; ``1`` runs in-process,
                ``-1`` uses ``max(8, cpu_count())``.
            training: Flag forwarded to ``build_graph``.

        Returns:
            A dict mapping sample id (file base name up to the first ``'.'``)
            to the value returned by ``build_graph``.

        Raises:
            KeyError: If ``split`` was never registered.

        Note:
            Results are cached by split name only, so the ``training`` value
            of the first call for a split determines what later calls get.
        """
        cached = self._data.get(split)
        if cached is not None:
            return cached

        if num_workers == -1:
            # NOTE(review): this never uses fewer than 8 processes, even on
            # machines with fewer cores - confirm ``max`` (not ``min``) is intended.
            num_workers = max(8, mp.cpu_count())

        all_options = [(files, self.options, training, self.extra_args)
                       for files in self._splits[split]]

        def _iter_results():
            if num_workers == 1:
                yield from map(self._worker, all_options)
                return
            # Leaving the ``with`` block terminates the pool, also on error
            # or when the generator is closed early.
            with mp.Pool(num_workers) as pool:
                yield from pool.imap_unordered(self._worker, all_options)

        data = {}
        progress = ProgressReporter(len(all_options), step=1)
        for outputs in progress(_iter_results()):
            for is_ok, filename, output in outputs:
                if not is_ok:
                    # On failure ``output`` carries the exception instance.
                    self.on_error(filename, output)
                    continue

                sample_id = os.path.basename(filename).split('.')[0]
                data[sample_id] = output

        self._data[split] = data
        return data

    @abstractclassmethod
    def build_graph(cls, reader_output, filename, options, training, extra_args=None):
        """Convert the reader output of one file into the stored sample.

        ``_worker`` always calls this with ``extra_args`` as a fifth
        positional argument, so overrides must accept it.

        Args:
            reader_output: Value returned by ``read_graph`` for the file.
            filename: Path of the file being processed.
            options: The reader options.
            training: Flag given to ``get_split``.
            extra_args: Dict of the extra keyword arguments passed to
                ``__init__``.
        """
        pass

    @classmethod
    def read_graph(cls, options, filename, read_fn, content):
        """Split the file content on blank lines and hand it to ``read_fn``.

        Args:
            options: The reader options, forwarded to ``read_fn``.
            filename: Path of the file (unused here; available to overrides).
            read_fn: Reader function obtained from ``READERS``.
            content: Decoded text of the file.

        Returns:
            Whatever ``read_fn`` returns.
        """
        return read_fn(content.strip().split('\n\n'), options)

    @classmethod
    def _worker(cls, args):
        """Process one batch of files; designed to run in a worker process.

        Args:
            args: Tuple ``(files, options, training, extra_args)``.

        Returns:
            A list with one ``(is_ok, filename, result)`` tuple per file, where
            ``result`` is the ``build_graph`` output on success or the caught
            exception on failure.
        """
        files, options, training, extra_args = args
        read_fn = READERS.normalize(options.graph_type)

        outputs = []
        for filename in files:
            try:
                # Opening happens inside the ``try`` so that an unreadable
                # file is reported like any other failure instead of aborting
                # the whole batch.
                with gzip.open(filename, 'rb') as fp:
                    content = fp.read().decode()
                reader_output = cls.read_graph(options, filename, read_fn, content)
                output = cls.build_graph(reader_output, filename, options, training, extra_args)
            except Exception as err:
                LOGGER.exception('%s', filename)
                outputs.append((False, filename, err))
            else:
                outputs.append((True, filename, output))

        return outputs
