# -*- coding: utf-8 -*-

import sys

from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu

from framework.common.logger import open_wrapper


class BLEU:
    """Accumulate hypothesis/reference pairs and score them with corpus BLEU.

    Samples are collected one at a time through :meth:`add`; :meth:`get`
    computes the corpus-level score over everything collected so far and
    :meth:`save` dumps the collected samples to text files.
    """

    def __init__(self, postprocess_fn=None):
        """Create an empty accumulator.

        Args:
            postprocess_fn: Optional callable applied to every hypothesis
                (after string hypotheses have been split into tokens) before
                it is stored. References are never post-processed.
        """
        # The three lists are parallel: index i describes the i-th sample.
        self.hypotheses = []
        self.references = []
        self.sentence_ids = []
        self.smoothing_function = SmoothingFunction()

        self.postprocess_fn = postprocess_fn

    def add(self, hypothesis, reference, sentence_id):
        """Record one sample.

        Args:
            hypothesis: A token list, or a string that is split on
                whitespace.
            reference: One of: a string (a single reference, split on
                whitespace), a flat list of string tokens (a single tokenized
                reference), or a list of tokenized references.
            sentence_id: Identifier stored alongside the sample; used for
                printing and for sorting in :meth:`save`.
        """
        if isinstance(hypothesis, str):
            hypothesis = hypothesis.strip().split()

        if isinstance(reference, str):
            # A string is always exactly one reference, even when it is empty
            # (an empty string must become ``[[]]``, not "no references").
            reference = [reference.strip().split()]
        elif reference and isinstance(reference[0], str):
            # Flat token list: wrap it so that every stored entry is a list
            # of references.
            reference = [reference]

        if self.postprocess_fn is not None:
            hypothesis = self.postprocess_fn(hypothesis)

        self.hypotheses.append(hypothesis)
        self.references.append(reference)
        self.sentence_ids.append(sentence_id)

    def print_sample(self, index=-1, stream=None):
        """Print one stored sample in ``###`` / ``<<<`` / ``>>>`` format.

        Args:
            index: Position of the sample; defaults to the most recently
                added one.
            stream: File-like object to write to. ``None`` (the default)
                means whatever ``sys.stdout`` is at call time.

        Raises:
            IndexError: If no sample exists at ``index``.
        """
        print('###', self.sentence_ids[index], file=stream)
        for ref_sentence in self.references[index]:
            print('<<<', *ref_sentence, file=stream)
        print('>>>', *self.hypotheses[index], file=stream)

    def get(self):
        """Return the corpus BLEU of all stored samples (NLTK smoothing method 3)."""
        return corpus_bleu(self.references, self.hypotheses,
                           smoothing_function=self.smoothing_function.method3)

    def __iter__(self):
        """Iterate over ``(sentence_id, hypothesis, references)`` triples."""
        return zip(self.sentence_ids, self.hypotheses, self.references)

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.sentence_ids)

    def save(self, prefix, sort_by_id=False):
        """Write the stored samples to three files named ``prefix + suffix``.

        * ``.hyp``: one line per sample: the id followed by the hypothesis
          tokens.
        * ``.ref``: one line per reference: the id followed by the reference
          tokens.
        * ``.ref+hyp``: the same ``###`` / ``<<<`` / ``>>>`` layout that
          :meth:`print_sample` produces.

        Args:
            prefix: String prepended to each suffix to form the file names.
            sort_by_id: When true, samples are written in ascending
                ``sentence_id`` order instead of insertion order.

        Returns:
            The list ``[prefix + '.hyp', prefix + '.ref', prefix + '.ref+hyp']``.
        """
        # NOTE(review): open_wrapper is a project helper; presumably it
        # returns an open()-like callable that first maps its argument
        # through the given function -- confirm against framework.common.logger.
        _open = open_wrapper(lambda x: prefix + x)
        suffixes = ('.hyp', '.ref', '.ref+hyp')
        hyp_suffix, ref_suffix, out_suffix = suffixes
        with _open(hyp_suffix, 'w') as fp_hyp, \
                _open(ref_suffix, 'w') as fp_ref, \
                _open(out_suffix, 'w') as fp_out:
            results = list(self)
            if sort_by_id:
                results.sort(key=lambda x: x[0])

            for sentence_id, hyp, refs in results:
                print(sentence_id, *hyp, file=fp_hyp)
                for ref in refs:
                    print(sentence_id, *ref, file=fp_ref)

                print('###', sentence_id, file=fp_out)
                for ref in refs:
                    print('<<<', *ref, file=fp_out)
                print('>>>', *hyp, file=fp_out)

        return [prefix + suffix for suffix in suffixes]

    def split_outputs(self, batch_samples, outputs):
        """No-op placeholder; does nothing and returns ``None``."""
        pass

    def write_outputs(self, output_path, _, system_samples):
        """No-op placeholder; does nothing and returns ``None``."""
        pass
