import dataclasses
import sys
from pathlib import Path

import jsonlines
from tqdm import tqdm

from baseline.utils import convert_prediction_to_labels
from contrastive_conditioning.coverage_detector import CoverageDetector, CoverageResult
from data_generation.utils import WordLevelQETokenizer
from translation_models.mbart_models import MbartScoringModel

# Command-line interface: <language_pair> <dataset_name>
#   language_pair: hyphen-separated source/target codes, e.g. "en-de" or "zh-en"
#   dataset_name:  base name of the dataset files under data/synthetic
if len(sys.argv) < 3:
    # Exit with a usage message instead of an opaque IndexError.
    sys.exit(f"Usage: {sys.argv[0]} <language_pair> <dataset_name>")

language_pair = sys.argv[1]

src_lang = language_pair.split("-")[0]
tgt_lang = language_pair.split("-")[1]

dataset_name = sys.argv[2]

# Label embedded in every output file name.
evaluator_name = "mbart-large-50-one-to-many"

# Directory two levels above this script; data/ and predictions/ live below it.
project_root = Path(__file__).parent.parent
data_dir = project_root / "data" / "synthetic"
predictions_dir = project_root / "predictions"

# zh-en reads a different source file variant (".src.cleaned.truncated").
if language_pair == "zh-en":
    src_path = data_dir / f"{dataset_name}.src.cleaned.truncated"
else:
    src_path = data_dir / f"{dataset_name}.src"
tgt_path = data_dir / f"{dataset_name}.mt"

# Explicit checks rather than `assert`, which is skipped when Python runs with -O.
for required_path in (src_path, tgt_path):
    if not required_path.exists():
        raise FileNotFoundError(f"Input file not found: {required_path}")

out_jsonl_path = predictions_dir / f"{dataset_name}.{evaluator_name}.jsonl"
out_source_tags_path = predictions_dir / f"{dataset_name}.{evaluator_name}.source_tags"
out_target_tags_path = predictions_dir / f"{dataset_name}.{evaluator_name}.tags"

# Per-language-pair setup. Each entry holds:
#   - forward evaluator spec:  (checkpoint, src_lang code, tgt_lang code)
#   - backward evaluator spec: (checkpoint, src_lang code, tgt_lang code), i.e. the reverse direction
#   - batch size handed to the coverage detector
_EVALUATOR_SETUPS = {
    "en-de": (
        ("facebook/mbart-large-50-one-to-many-mmt", "en_XX", "de_DE"),
        ("facebook/mbart-large-50-many-to-one-mmt", "de_DE", "en_XX"),
        2,
    ),
    "zh-en": (
        ("facebook/mbart-large-50-many-to-one-mmt", "zh_CN", "en_XX"),
        ("facebook/mbart-large-50-one-to-many-mmt", "en_XX", "zh_CN"),
        1,
    ),
}

# Only the language pairs listed above are supported.
if language_pair not in _EVALUATOR_SETUPS:
    raise NotImplementedError

_forward_spec, _backward_spec, _detector_batch_size = _EVALUATOR_SETUPS[language_pair]

# The two scoring models are placed on different device indices (0 and 1).
forward_evaluator = MbartScoringModel(
    _forward_spec[0],
    src_lang=_forward_spec[1],
    tgt_lang=_forward_spec[2],
    device=0,
)
backward_evaluator = MbartScoringModel(
    _backward_spec[0],
    src_lang=_backward_spec[1],
    tgt_lang=_backward_spec[2],
    device=1,
)

detector = CoverageDetector(
    src_language=src_lang,
    tgt_language=tgt_lang,
    forward_evaluator=forward_evaluator,
    backward_evaluator=backward_evaluator,
    batch_size=_detector_batch_size,
)

# One word-level tokenizer per side; the prediction loop uses them to detokenize.
src_tokenizer, tgt_tokenizer = WordLevelQETokenizer(src_lang), WordLevelQETokenizer(tgt_lang)

# Main prediction pass. For every aligned (source, translation) line pair:
#   1. split the pre-tokenized line on whitespace and detokenize it,
#   2. run the coverage detector on the detokenized pair,
#   3. convert the detector result into source-side and target-side tag sequences,
#   4. write the tags to the two tag files and a full record to the JSONL file.

# Make sure the output directories exist before opening files for writing.
for _out_path in (out_source_tags_path, out_target_tags_path, out_jsonl_path):
    _out_path.parent.mkdir(parents=True, exist_ok=True)

# Explicit UTF-8 so reading non-ASCII text does not depend on the platform's
# locale default encoding.
with open(src_path, encoding="utf-8") as f_src, open(tgt_path, encoding="utf-8") as f_tgt, \
        open(out_source_tags_path, "w", encoding="utf-8") as f_out_src, \
        open(out_target_tags_path, "w", encoding="utf-8") as f_out_tgt, \
        jsonlines.open(out_jsonl_path, "w") as f_out_jsonl:
    for line_number, src in enumerate(tqdm(f_src), start=1):
        # Read the matching translation line by hand (instead of zip) so that a
        # length mismatch is reported rather than silently truncating the output.
        tgt = f_tgt.readline()
        if not tgt:
            raise ValueError(
                f"{tgt_path} ends before {src_path}: no translation for source line {line_number}"
            )
        src_tokens = src.strip().split()
        tgt_tokens = tgt.strip().split()
        src_detokenized = src_tokenizer.detokenize(src_tokens)
        tgt_detokenized = tgt_tokenizer.detokenize(tgt_tokens)
        result = detector.detect_errors(
            src=src_detokenized,
            translation=tgt_detokenized,
        )
        # Token counts come from the original whitespace tokenization, not the
        # detokenized text.
        source_tags, target_tags = convert_prediction_to_labels(
            src_len=len(src_tokens),
            tgt_len=len(tgt_tokens),
            prediction=result,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )
        f_out_src.write(" ".join(source_tags) + "\n")
        f_out_tgt.write(" ".join(target_tags) + "\n")
        f_out_jsonl.write({
            "src": src.strip(),
            "tgt": tgt.strip(),
            "src_detokenized": src_detokenized,
            "tgt_detokenized": tgt_detokenized,
            "prediction": dataclasses.asdict(result),
            "source_tags": source_tags,
            "target_tags": target_tags,
        })
    # Any line left in the translation file means the source file was shorter.
    if f_tgt.readline():
        raise ValueError(f"{src_path} ends before {tgt_path}: the files are not line-aligned")


# Second pass: rebuild the tag files from the predictions stored in the JSONL
# file, going through CoverageResult.from_dict instead of the in-memory result.
#
# NOTE: with_suffix() replaces only the trailing ".jsonl", so given how
# out_jsonl_path is named these paths are the same files as out_source_tags_path
# and out_target_tags_path; this pass overwrites the tag files written above.
src_out_path = out_jsonl_path.with_suffix(".source_tags")
tgt_out_path = out_jsonl_path.with_suffix(".tags")

with jsonlines.open(out_jsonl_path) as f_in, \
        open(src_out_path, "w", encoding="utf-8") as f_src, \
        open(tgt_out_path, "w", encoding="utf-8") as f_tgt:
    for line in f_in:
        prediction = CoverageResult.from_dict(line["prediction"])
        # "src"/"tgt" hold the stripped, whitespace-tokenized input lines, so
        # splitting them recovers the token counts used in the first pass.
        src_labels, tgt_labels = convert_prediction_to_labels(
            src_len=len(line["src"].split()),
            tgt_len=len(line["tgt"].split()),
            prediction=prediction,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )
        f_src.write(" ".join(src_labels) + "\n")
        f_tgt.write(" ".join(tgt_labels) + "\n")
