
from pathlib import Path
from typing import List, Dict
import pickle
import csv
import json

import numpy as np

from datasets.base import PartialBinaryHuman, MetricScores


# Root data directory: a "data" folder that sits next to the directory
# containing this module (i.e. one level up from this file's folder).
DATA_ROOT = Path(__file__).parents[1].joinpath("data")
# WMT21 news-domain slice of the mt-metrics-eval data.
WMT21_NEWS = DATA_ROOT.joinpath("mt-metrics-eval", "wmt21.news")
# Directory where the loaders in this module pickle their results.
CACHE = DATA_ROOT.joinpath("cache")


def best_ref(lang_pair="en-de"):
    """Return the name of the reference translation to use for ``lang_pair``.

    Reference-based metric files whose name ends with this reference are the
    only reference-based ones ``load_metric`` keeps. Only ``"en-de"`` is
    supported.

    Raises:
        ValueError: if ``lang_pair`` is anything other than ``"en-de"``.
    """
    if lang_pair != "en-de":
        raise ValueError(f"you provided unsupported language pair: {lang_pair}")
    return "ref-C"


def load_human(lang_pair="en-de") -> Dict[str, PartialBinaryHuman]:
    """Load WMT21 segment-level MQM scores and binarize them per MT system.

    Reads ``human-scores/{lang_pair}.mqm.seg.score`` under ``WMT21_NEWS``.
    The file is tab-separated ``<mt-system>\\t<score>`` rows, where the
    literal score ``None`` marks an unscored segment.

    For each system:

    * unscored segments are dropped;
    * ``binary_scores[i]`` is ``True`` when the kept segment's score is
      exactly ``0.0``;
    * ``indices[i]`` is that segment's position among the system's rows in
      the file.

    The result is pickled under ``CACHE``, keyed by language pair, and
    reused on later calls.

    Args:
        lang_pair: Language pair whose annotations to load.

    Returns:
        Mapping from MT system name to its ``PartialBinaryHuman``.
    """
    # Key the cache by language pair so one pair's results are never served
    # for another.
    cached = CACHE / f"wmt21_{lang_pair}_human.pkl"
    if cached.exists():
        # NOTE: unpickling runs arbitrary code; only trust cache files this
        # module wrote itself.
        with cached.open("rb") as fin:
            return pickle.load(fin)

    annotation_file = WMT21_NEWS / "human-scores" / f"{lang_pair}.mqm.seg.score"

    # Group scores by system in a single pass. File order is preserved within
    # each system, so the enumerate() positions below index the system's rows.
    per_system: Dict[str, list] = {}
    with annotation_file.open("r", encoding="utf-8", newline="") as fin:
        reader = csv.DictReader(
            fin,
            fieldnames=["mt-system", "score"],
            delimiter="\t",
            quoting=csv.QUOTE_NONE,
        )
        for row in reader:
            raw = row["score"]
            score = None if raw == "None" else float(raw)
            per_system.setdefault(row["mt-system"], []).append(score)

    result = {}
    for system, raw_scores in per_system.items():
        kept = [(ix, s) for ix, s in enumerate(raw_scores) if s is not None]
        result[system] = PartialBinaryHuman(
            system=system,
            dataset="wmt21",
            # The builtins are what the removed np.bool / np.int aliases
            # pointed to; those aliases no longer exist in NumPy >= 1.24.
            binary_scores=np.array([s == 0.0 for _, s in kept], dtype=bool),
            indices=np.array([ix for ix, _ in kept], dtype=int),
        )

    # The cache directory may not exist yet on a fresh checkout.
    CACHE.mkdir(parents=True, exist_ok=True)
    with cached.open("wb") as fout:
        pickle.dump(obj=result, file=fout)

    return result


def load_metric(lang_pair="en-de") -> Dict[str, Dict[str, MetricScores]]:
    """Load WMT21 segment-level automatic metric scores for ``lang_pair``.

    Reads every ``metric-scores/{lang_pair}/*.seg.score`` file under
    ``WMT21_NEWS``; each is tab-separated ``<mt-system>\\t<score>`` rows.
    The metric name is the file name up to its first ``.``.

    Reference-based metrics (names containing ``"ref"``) are kept only when
    computed against ``best_ref(lang_pair)``. Other metrics are always kept.
    For ``"en-de"`` the extra ``Facebook-AI-Loss`` scores from
    ``load_fb_scored`` are attached as well.

    The result is pickled under ``CACHE``, keyed by language pair, and
    reused on later calls.

    Args:
        lang_pair: Language pair whose metric scores to load.

    Returns:
        ``{metric_name: {mt_system: MetricScores}}``.

    Raises:
        FileNotFoundError: if the metric-score folder for ``lang_pair`` does
            not exist. This prevents caching an empty result.
        ValueError: from ``best_ref`` when a reference-based metric is found
            for an unsupported language pair.
    """
    # Key the cache by language pair so one pair's results are never served
    # for another.
    cached = CACHE / f"wmt21_{lang_pair}_metrics.pkl"
    if cached.exists():
        # NOTE: unpickling runs arbitrary code; only trust cache files this
        # module wrote itself.
        with cached.open("rb") as fin:
            return pickle.load(fin)

    score_folder = WMT21_NEWS / "metric-scores" / lang_pair
    if not score_folder.is_dir():
        raise FileNotFoundError(f"metric score folder not found: {score_folder}")

    result = {}
    # sorted(): deterministic metric order regardless of filesystem listing.
    for score_file in sorted(score_folder.glob("*.seg.score")):
        metric_name = score_file.name.split(".")[0]

        # only load referenced metrics using the best reference
        # described in the paper
        if "ref" in metric_name and not metric_name.endswith(best_ref(lang_pair)):
            continue

        # Group scores by system in a single pass, preserving file order.
        per_system: Dict[str, List[float]] = {}
        with score_file.open("r", encoding="utf-8", newline="") as fin:
            reader = csv.DictReader(
                fin,
                fieldnames=["mt-system", "score"],
                delimiter="\t",
                quoting=csv.QUOTE_NONE,
            )
            for row in reader:
                per_system.setdefault(row["mt-system"], []).append(float(row["score"]))

        result[metric_name] = {
            mt_system: MetricScores(
                metric=metric_name,
                dataset="wmt21",
                system=mt_system,
                scores=np.array(scores),
            )
            for mt_system, scores in per_system.items()
        }

    # The extra Facebook-AI-Loss scores come from an en-de-only file, so
    # don't mix them into other language pairs' results.
    if lang_pair == "en-de":
        result["Facebook-AI-Loss"] = load_fb_scored()

    # The cache directory may not exist yet on a fresh checkout.
    CACHE.mkdir(parents=True, exist_ok=True)
    with cached.open("wb") as fout:
        pickle.dump(obj=result, file=fout)

    return result


def load_fb_scored(lang_pair="en-de"):
    """Load the extra ``Facebook-AI-Loss`` segment scores for WMT21.

    Reads ``wmt21_additional/wmt21_{lang_pair}_scored.json`` under
    ``DATA_ROOT``. The file maps each MT system name to a list of records,
    each with a ``"score2"`` field.

    The ``score2`` values are negated before being stored. Presumably this
    turns a lower-is-better loss into a higher-is-better score; confirm
    against whatever produced the file.

    Args:
        lang_pair: Language pair in the file name. The default reproduces the
            original hard-coded en-de file.

    Returns:
        Mapping from MT system name to its ``MetricScores``.
    """
    path = DATA_ROOT / "wmt21_additional" / f"wmt21_{lang_pair}_scored.json"
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
    with path.open("r", encoding="utf-8") as fin:
        data = json.load(fin)

    return {
        mt_system: MetricScores(
            metric="Facebook-AI-Loss",
            system=mt_system,
            dataset="wmt21",
            scores=-np.array([pair["score2"] for pair in pairs]),
        )
        for mt_system, pairs in data.items()
    }
