from dataclasses import dataclass, asdict

from datasets import load_metric
from sacrebleu import sacrebleu, corpus_bleu, corpus_chrf


def load_tatoeba_test(src, tgt):
    """Load a Tatoeba-style test set for the ``src``-``tgt`` language pair.

    Reads ``data/test/{src}-{tgt}.txt`` relative to the current working
    directory. Each non-blank line is stripped of surrounding whitespace and
    must contain exactly four tab-separated fields. The first two fields are
    ignored; the third is the source sentence and the fourth its reference
    translation. Blank lines (e.g. a trailing empty line at EOF) are skipped.

    Args:
        src: Source language code used to build the file name.
        tgt: Target language code used to build the file name.

    Returns:
        A ``(sentences, references)`` tuple of two equal-length lists of str.

    Raises:
        ValueError: If a non-blank line does not have exactly four fields.
    """
    path = f'data/test/{src}-{tgt}.txt'
    sentences = []
    references = []
    # Explicit UTF-8 (assumes the test files are UTF-8 encoded): the platform
    # default encoding, e.g. cp1252 on Windows, cannot decode most non-Latin text.
    with open(path, encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines instead of failing the 4-way unpack.
                continue
            fields = line.split('\t')
            if len(fields) != 4:
                raise ValueError(
                    f'{path}:{line_no}: expected 4 tab-separated fields, '
                    f'got {len(fields)}'
                )
            _, _, src_sent, tgt_sent = fields
            sentences.append(src_sent)
            references.append(tgt_sent)
    return sentences, references


def load_raw_test_file(src, tgt):
    with open(f'data/test/raw/{src}-{tgt}.txt') as f:
        it = map(lambda e: e.strip(), f)

        sentences = []
        references = []
        predictions = []

        while True:
            try:
                sentences.append(next(it))
                references.append(next(it))
                predictions.append(next(it))
                assert len(next(it)) == 0
            except StopIteration:
                break

        assert len(sentences) == len(references) == len(predictions)
        return sentences, references, predictions


# BERTScore metric handle from HuggingFace `datasets`, created eagerly when
# this module is imported. evaluate() in this file does not use it: its
# BERTScore computation is commented out and it reports bert=-1 instead.
# NOTE(review): loading here may fetch the metric script on import, and
# `datasets.load_metric` is deprecated in newer `datasets` releases (removed in
# 3.x) -- consider lazy loading via the HF `evaluate` library. Confirm nothing
# else imports `bert_score` before changing or removing it.
bert_score = load_metric('bertscore')


@dataclass
class EvaluationResult:
    """Corpus-level scores returned by ``evaluate``.

    Iterating an instance yields ``(metric_name, score)`` pairs in field
    order (bleu, chrf, bert). Both ``str()`` and ``repr()`` render as
    ``Evaluation[BLEU=...; CHRF=...; BERT=...]``.
    """

    bleu: float  # BLEU score
    chrf: float  # chrF score
    bert: float  # BERTScore F1; evaluate() currently always stores -1 (disabled)

    def as_dict(self):
        """Return the metrics as a plain ``{field_name: score}`` dict."""
        return asdict(self)

    def __iter__(self):
        pairs = self.as_dict().items()
        return iter(pairs)

    def metrics_str(self):
        """Render every metric as ``NAME=score`` (rounded to 3 places), '; '-separated."""
        parts = [f"{name.upper()}={round(score, 3)}" for name, score in self]
        return '; '.join(parts)

    def __str__(self):
        return f"Evaluation[{self.metrics_str()}]"

    def __repr__(self):
        # Intentionally the same human-readable text as str().
        return str(self)


def evaluate(predictions, references, lang=None) -> EvaluationResult:
    """Score ``predictions`` against ``references`` at corpus level.

    Args:
        predictions: Hypothesis sentences, one per segment.
        references: Reference sentences aligned with ``predictions``; passed
            to sacrebleu as a single reference stream.
        lang: Language code of the texts. ``'zh'`` and ``'ja'`` select the
            ``'zh'`` and ``'ja-mecab'`` BLEU tokenizers; anything else
            (including ``None``) uses ``sacrebleu.DEFAULT_TOKENIZER``.

    Returns:
        An ``EvaluationResult`` holding BLEU and chrF. BERTScore is currently
        disabled, so its ``bert`` field is always ``-1``.
    """
    language_tokenizers = {'zh': 'zh', 'ja': 'ja-mecab'}
    if lang in language_tokenizers:
        tokenize = language_tokenizers[lang]
    else:
        # Looked up only when needed, exactly like the original fallback.
        tokenize = sacrebleu.DEFAULT_TOKENIZER

    # Each call receives one reference stream wrapped in a list.
    bleu_result = corpus_bleu(predictions, [references], tokenize=tokenize, force=True)
    chrf_result = corpus_chrf(predictions, [references])

    # BERTScore is disabled; -1 is the placeholder. To re-enable:
    #   bs_out = bert_score.compute(predictions=predictions, references=references,
    #                               lang=lang, device='cpu')
    #   bert_f1 = sum(bs_out['f1']) / len(predictions)
    bert_f1 = -1

    return EvaluationResult(bleu_result.score, chrf_result.score, bert_f1)
