import dataclasses
import sys
from pathlib import Path

import jsonlines
from tqdm import tqdm

from contrastive_conditioning.coverage_detector import CoverageDetector
from evaluation.utils import MqmDataset, MqmSample
from translation_models.mbart_models import MbartScoringModel

# Score MQM samples for coverage errors with an mBART-50 forward/backward
# CoverageDetector and write one JSON line per sample.
#
# Usage: python <this script> <language_pair>   (e.g. en-de, zh-en)

_MBART_ONE_TO_MANY = "facebook/mbart-large-50-one-to-many-mmt"
_MBART_MANY_TO_ONE = "facebook/mbart-large-50-many-to-one-mmt"

# Per supported pair: (checkpoint, src_lang code, tgt_lang code) for the
# forward evaluator (pair's source -> target) and the backward evaluator
# (target -> source).
_EVALUATOR_CONFIGS = {
    "en-de": (
        (_MBART_ONE_TO_MANY, "en_XX", "de_DE"),
        (_MBART_MANY_TO_ONE, "de_DE", "en_XX"),
    ),
    "zh-en": (
        (_MBART_MANY_TO_ONE, "zh_CN", "en_XX"),
        (_MBART_ONE_TO_MANY, "en_XX", "zh_CN"),
    ),
}

if len(sys.argv) < 2:
    sys.exit(
        f"Usage: {Path(sys.argv[0]).name} <language_pair> "
        f"(supported: {', '.join(sorted(_EVALUATOR_CONFIGS))})"
    )

language_pair = sys.argv[1]

# Validate before loading the dataset or any model so that a bad argument
# fails fast instead of after expensive setup.
if language_pair not in _EVALUATOR_CONFIGS:
    raise NotImplementedError(
        f"Unsupported language pair {language_pair!r}; "
        f"supported: {', '.join(sorted(_EVALUATOR_CONFIGS))}"
    )

src_lang, tgt_lang = language_pair.split("-")

# NOTE: this name goes into the output filename. It is the same for every
# pair (even though zh-en's forward model is many-to-one); it is kept as-is so
# existing prediction file names stay stable.
evaluator_name = "mbart-large-50-one-to-many"

dataset = MqmDataset(language_pair)

out_path = Path(__file__).parent.parent / "predictions" / (str(dataset) + f".{evaluator_name}.jsonl")
# jsonlines.open(..., "w") does not create missing parent directories.
out_path.parent.mkdir(parents=True, exist_ok=True)

(fwd_model, fwd_src, fwd_tgt), (bwd_model, bwd_src, bwd_tgt) = _EVALUATOR_CONFIGS[language_pair]
# device=0 / device=1 are presumably two separate GPU indices — confirm
# against MbartScoringModel; this setup appears to need two devices.
forward_evaluator = MbartScoringModel(fwd_model, src_lang=fwd_src, tgt_lang=fwd_tgt, device=0)
backward_evaluator = MbartScoringModel(bwd_model, src_lang=bwd_src, tgt_lang=bwd_tgt, device=1)

detector = CoverageDetector(
    src_language=src_lang,
    tgt_language=tgt_lang,
    forward_evaluator=forward_evaluator,
    backward_evaluator=backward_evaluator,
    batch_size=1,
)

sample: MqmSample  # type hint for the loop variable below
with jsonlines.open(out_path, "w") as writer:
    for sample in tqdm(dataset.load_samples(load_original_sequences=True)):
        result = detector.detect_errors(
            src=sample.original_source,
            translation=sample.original_target,
        )
        # Each line pairs the full input sample with the detector's result.
        writer.write({
            "sample": dataclasses.asdict(sample),
            "prediction": dataclasses.asdict(result),
        })
