# -*- coding: utf-8 -*-

import io
import sys
import tempfile

import pyshrg
import SHRG.utils.eds_modified as eds
from framework.common.dataclass_options import SingleOptionsParser
from framework.common.logger import LOGGER, smart_open_file
from framework.common.utils import ProgressReporter
from nn_generator.evaluate.utils import Evaluator
from pyshrg_utils.parser import PYSHRG_ABBREVS, PySHRGOptions, pyshrg_init
from SHRG.graph_io import TEXT_WRITRES
from SHRG.shrg_extract import ExtractionOptions


class WeightedEvaluator(Evaluator):
    """Evaluator that parses each loaded graph with the pyshrg manager.

    Every graph exposed by ``pyshrg.get_manager()`` is parsed using the
    manager's context 0; graphs that parse without error are passed on to
    ``self.generate`` (not defined here; presumably inherited from
    ``Evaluator``).
    """

    def __init__(self, return_output=False, verbose=False):
        # ``return_output`` and ``verbose`` are accepted but not used; the
        # base class is always initialised with three ``None`` arguments.
        super().__init__(None, None, None)

    def __call__(self, predict_cfg=True, gold_trees=None):
        # The caller's arguments are ignored: the base implementation is
        # always invoked with ``predict_cfg=False`` and ``gold_trees=None``.
        return super().__call__(predict_cfg=False, gold_trees=None)

    def forward(self):
        """Parse all graphs in order and generate from the successful ones.

        A graph whose parse returns anything other than
        ``pyshrg.ParserError.kNone`` is logged as a warning, counted as an
        error and skipped.
        """
        manager = pyshrg.get_manager()
        failures = 0

        def _status(_):
            # Reads the enclosing counter at call time, so the message
            # always shows the current number of failures.
            return f'error: {failures}'

        reporter = ProgressReporter(stop=manager.graph_size,
                                    message_fn=_status)

        context = manager.get_context(0)
        for index, graph in reporter(enumerate(manager.iter_graphs())):
            status = context.parse(index)
            if status != pyshrg.ParserError.kNone:
                LOGGER.warning('>>> failed: %s %s', status, graph.sentence_id)
                failures += 1
            else:
                self.generate(context, graph)


class MainOptions(PySHRGOptions):
    # Command-line options of this script, on top of those of PySHRGOptions.

    # Path to a model or a grammar. ``main`` runs the neural-network
    # pipeline when this ends with '.pt' and the weighted pipeline otherwise.
    model_or_grammar_path: str
    # Path passed to ``ExtractionOptions.from_file`` in ``main``; the loaded
    # config supplies ``modify_label`` for graph preparation.
    config_path: str

    # Output location. With the default '-', ``save_results`` prints the
    # derivation trees to stdout and does not save metrics; otherwise trees
    # go to ``output_path + '.trees'`` and metrics are saved to ``output_path``.
    output_path: str = '-'
    # Input graphs, opened through ``smart_open_file`` in ``load_graphs``.
    # NOTE(review): '-' presumably means stdin -- confirm in smart_open_file.
    input_path: str = '-'


def parse_args(argv=None):
    """Parse command-line arguments into a ``MainOptions`` instance.

    The defaults are a fresh ``MainOptions`` with ``num_contexts``,
    ``verbose`` and ``parser_type`` overridden below. The option
    abbreviations are a copy of ``PYSHRG_ABBREVS`` extended with this
    script's own entries.

    Args:
        argv: Argument list forwarded to the parser's ``parse_args``.

    Returns:
        Whatever ``SingleOptionsParser.parse_args`` returns (annotated as
        ``MainOptions``).
    """
    options_parser = SingleOptionsParser()

    defaults = MainOptions()
    defaults.num_contexts = 1
    defaults.verbose = False
    defaults.parser_type = 'tree_index_v2/terminal_first'

    # Work on a copy so the shared PYSHRG_ABBREVS table is left untouched.
    short_names = PYSHRG_ABBREVS.copy()
    short_names.update({
        'output_path': '-o',
        'input_path': '-i',
        'config_path': '-C',
        # NOTE(review): mapping an option to its bare name presumably makes
        # it a positional argument -- confirm in SingleOptionsParser.
        'model_or_grammar_path': 'model_or_grammar_path',
    })

    options_parser.setup(default_instance=defaults, abbrevs=short_names)

    return options_parser.parse_args(argv)  # type: MainOptions


def load_graphs(path):
    """Yield the EDS graphs stored at *path*.

    Lines whose first non-blank character is ``#`` are dropped before the
    remaining text is handed to ``eds.loads``. Each yielded graph has its
    ``sentence`` attribute set to the placeholder ``'<unset>'``.

    Args:
        path: Location opened for reading through ``smart_open_file``.

    Yields:
        The graphs produced by ``eds.loads``, in order.
    """
    with smart_open_file(path, 'r') as fp:
        kept_lines = [line for line in fp.readlines()
                      if not line.strip().startswith('#')]
        # NOTE(review): if readlines() keeps line terminators (as built-in
        # text files do), joining with '\n' leaves a blank line between
        # consecutive lines -- presumably harmless for eds.loads; confirm.
        text = '\n'.join(kept_lines)
        for graph in eds.loads(text):
            graph.sentence = '<unset>'
            yield graph


def prepare_graphs(input_path, modify_label):
    """Serialise the input graphs into a temporary file and return it.

    The file starts with a line holding the number of graphs, followed by
    each graph written by the 'eds' text writer under the id ``#<index>``.

    Args:
        input_path: Source of the graphs, read via ``load_graphs``.
        modify_label: Forwarded to the writer as ``modify_label``.

    Returns:
        The open ``tempfile.NamedTemporaryFile`` object. Use its ``name``
        attribute as the path and keep the object alive while the path is
        needed, since the file is created with the default settings and so
        is removed once it is closed.
    """
    write_graph = TEXT_WRITRES.normalize('eds')

    all_graphs = list(load_graphs(input_path))

    buffer = io.StringIO()
    # Header line: the number of graphs that follow.
    print(len(all_graphs), file=buffer)
    for position, graph in enumerate(all_graphs):
        write_graph(buffer, f'#{position}', graph, modify_label=modify_label)
    serialized = buffer.getvalue()

    input_file = tempfile.NamedTemporaryFile()
    # The temporary file is re-opened by name in text mode for writing.
    # NOTE(review): re-opening a still-open NamedTemporaryFile by name is
    # platform dependent according to the tempfile documentation.
    with open(input_file.name, 'w') as fp:
        fp.write(serialized)

    return input_file


def save_results(evaluator, output_path):
    """Write the evaluator's derivation trees and, for files, its metrics.

    Args:
        evaluator: Object exposing ``derivation_strings`` (an iterable of
            ``(sentence_id, tree)`` pairs) and, when writing to a file,
            ``metrics.save``.
        output_path: ``'-'`` prints the trees to stdout and saves nothing
            else. Any other value writes the trees to
            ``output_path + '.trees'`` and then calls
            ``evaluator.metrics.save(output_path, no_score=True)``.
    """
    def _dump(pairs, stream):
        # One line for the sentence id, one line for its tree.
        for sentence_id, tree in pairs:
            print(sentence_id, tree, sep='\n', file=stream)

    if output_path == '-':
        _dump(evaluator.derivation_strings, sys.stdout)
        return

    with smart_open_file(output_path + '.trees', 'w') as fp:
        _dump(evaluator.derivation_strings, fp)

    evaluator.metrics.save(output_path, no_score=True)


def run_neural_network(options, input_path, output_path):
    """Run generation with a trained neural model.

    ``SHRGGenerator.evaluate_entry`` is replaced with a local function that
    loads the prepared graphs into the pyshrg manager, runs the model's
    evaluator without gradient tracking and saves the results. A
    ``PredictSession`` built from ``options.model_or_grammar_path`` is then
    run in 'single' evaluator mode.

    Args:
        options: Parsed options; ``model_or_grammar_path`` and
            ``parser_type`` are read here.
        input_path: Path of the prepared graph file.
        output_path: Output prefix, also handed to ``save_results``.

    Raises:
        RuntimeError: If the pyshrg manager reports that loading the graphs
            failed (raised from the patched evaluation entry).
    """
    # Imported lazily so the weighted pipeline does not need torch.
    import torch
    from framework.torch_extra.predict_session import PredictSession
    from nn_generator.feature_based.model import SHRGGenerator

    def evaluate(self, path, mode):
        # ``path`` and ``mode`` belong to the signature being replaced but
        # are unused: the graphs always come from ``input_path``.
        #
        # The load must not sit inside an ``assert``: assert statements are
        # removed under ``python -O``, which would skip the load entirely.
        if not pyshrg.get_manager().load_graphs(input_path):
            raise RuntimeError(f'failed to load graphs from {input_path!r}')

        self.network.eval()
        with torch.no_grad():
            self.evaluator(predict_cfg=(self.options.train_mode != 'hrg'),
                           gold_trees=None)

        save_results(self.evaluator, output_path)

    # Monkey-patch the evaluation entry point used by the session.
    SHRGGenerator.evaluate_entry = evaluate
    session = PredictSession(SHRGGenerator, options.model_or_grammar_path)
    session.run(['--test-paths', input_path,
                 '--output-prefix', output_path,
                 '--evaluator-mode', 'single',
                 '--pyshrg.parser-type', options.parser_type],
                use_debugger=False)


def run_weighted(options, input_path, output_path):
    """Run generation with a weighted grammar.

    Initialises pyshrg from ``options.model_or_grammar_path``, loads the
    prepared graphs, runs a ``WeightedEvaluator`` and saves its results.

    Args:
        options: Parsed options, forwarded to ``pyshrg_init``.
        input_path: Path of the prepared graph file.
        output_path: Destination handed to ``save_results``.

    Raises:
        RuntimeError: If the manager reports that loading the graphs failed.
    """
    manager = pyshrg_init(options.model_or_grammar_path, options=options)
    # The load must not sit inside an ``assert``: assert statements are
    # removed under ``python -O``, which would skip the load entirely.
    if not manager.load_graphs(input_path):
        raise RuntimeError(f'failed to load graphs from {input_path!r}')

    evaluator = WeightedEvaluator()
    evaluator()
    save_results(evaluator, output_path)


def main(argv=None):
    """Entry point: parse options, prepare the graphs and run a pipeline.

    The neural-network pipeline is used when the model/grammar path ends
    with '.pt'; otherwise the weighted pipeline is used.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.
    """
    options = parse_args(argv)

    if options.model_or_grammar_path.endswith('.pt'):
        run_fn = run_neural_network
    else:
        run_fn = run_weighted

    config = ExtractionOptions.from_file(options.config_path)
    # ``prepare_graphs`` returns an open NamedTemporaryFile that is removed
    # when closed. Holding it in a ``with`` block keeps the path valid for
    # the whole run and closes it deterministically, even if the run raises.
    with prepare_graphs(options.input_path, config.modify_label) as input_file:
        run_fn(options, input_file.name, options.output_path)


# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()
