
import json
import pickle
from pathlib import Path
from typing import Iterable, Optional, Tuple

import numpy as np

from datasets.base import BinaryHuman, MetricScores


# The ``data`` directory is a sibling of the directory containing this module.
DATA_ROOT = Path(__file__).parents[1].joinpath("data")
# Raw Spot-the-Bot JSON exports.
SPOT_THE_BOT = DATA_ROOT.joinpath("spot_the_bot")
# Pickled, preprocessed results written by the loaders below.
CACHE = DATA_ROOT.joinpath("cache")


def _binarize_annotation(annotations: Tuple[Optional[bool]]) -> Optional[bool]:
    filtered = [a for a in annotations if a is not None]
    if len(filtered) == 0:
        return None
    else:
        return all(filtered)


def load_annotated():
    """Load Spot-the-Bot human labels and metric scores, grouped by domain.

    The first call parses ``spot_the_bot_gen.json`` under ``SPOT_THE_BOT`` and
    pickles the result to ``CACHE / "stb-annotated.pkl"``. Later calls return
    the unpickled cache directly.

    Returns:
        A dict mapping each domain to::

            {
                "human": {bot_name: BinaryHuman},
                "metric": {metric: {bot_name: MetricScores}},
            }

        Items whose judgments all binarize to ``None`` are dropped. Their metric
        scores are dropped too, so each bot's metric scores come from the same
        items as its human labels.
    """
    cached = CACHE / "stb-annotated.pkl"
    if cached.exists():
        # NOTE: unpickling is only acceptable because this file is written by
        # this function itself; never point CACHE at untrusted data.
        with cached.open('rb') as fin:
            return pickle.load(fin)

    # Decode explicitly as UTF-8 rather than relying on the platform locale.
    with (SPOT_THE_BOT / "spot_the_bot_gen.json").open("r", encoding="utf-8") as fin:
        raw = json.load(fin)

    output = {}
    for domain, items in raw.items():
        human = {}   # bot_name -> list of binarized labels
        metrics = {}  # metric name -> bot_name -> list of scores

        for item in items:
            bot_name = item['system_name']
            human_label = _binarize_annotation(item['human_judgments'])
            if human_label is None:
                # Skip this item's metric scores as well, keeping them
                # consistent with the labels that were kept.
                continue

            human.setdefault(bot_name, []).append(human_label)
            for metric, score in item['metric_scores'].items():
                metrics.setdefault(metric, {}).setdefault(bot_name, []).append(score)

        dataset = f"stb-{domain}"
        output[domain] = {
            'human': {
                bot_name: BinaryHuman(
                    system=bot_name,
                    dataset=dataset,
                    # The ``np.bool`` alias was removed in NumPy 1.24; the
                    # builtin ``bool`` is the same dtype on every version.
                    binary_scores=np.array(labels, dtype=bool),
                )
                for bot_name, labels in human.items()
            },
            "metric": {
                metric: {
                    bot_name: MetricScores(
                        metric=metric,
                        system=bot_name,
                        dataset=dataset,
                        scores=np.array(scores),
                    )
                    for bot_name, scores in per_bot.items()
                }
                for metric, per_bot in metrics.items()
            },
        }

    # The cache directory may not exist yet on a fresh checkout.
    cached.parent.mkdir(parents=True, exist_ok=True)
    # Write to a temp file and swap it in, so an interrupted dump never leaves
    # a truncated pickle that would break every later call.
    tmp = cached.with_name(cached.name + ".tmp")
    with tmp.open('wb') as fout:
        pickle.dump(obj=output, file=fout)
    tmp.replace(cached)

    return output


def load_additional():
    """Load the additional Spot-the-Bot metric scores, grouped by domain.

    The first call parses ``bot_bot_aj_scored.json`` under ``SPOT_THE_BOT`` and
    pickles the result to ``CACHE / "stb-additional.pkl"``. Later calls return
    the unpickled cache directly.

    Returns:
        A dict mapping each domain to ``{metric: {bot_name: MetricScores}}``.
        Each ``MetricScores`` has dataset name ``"stb-<domain>-additional"``.
        Domains with no items map to an empty dict.
    """
    cached = CACHE / "stb-additional.pkl"
    if cached.exists():
        # NOTE: unpickling is only acceptable because this file is written by
        # this function itself; never point CACHE at untrusted data.
        with cached.open('rb') as fin:
            return pickle.load(fin)

    # Decode explicitly as UTF-8 rather than relying on the platform locale.
    with (SPOT_THE_BOT / "bot_bot_aj_scored.json").open("r", encoding="utf-8") as fin:
        raw = json.load(fin)

    output = {}
    for domain, items in raw.items():
        metrics = {}  # metric name -> bot_name -> list of scores
        for item in items:
            bot_name = item['system_name']
            for metric, score in item['metric_scores'].items():
                metrics.setdefault(metric, {}).setdefault(bot_name, []).append(score)

        dataset = f"stb-{domain}-additional"
        output[domain] = {
            metric: {
                bot_name: MetricScores(
                    metric=metric,
                    system=bot_name,
                    dataset=dataset,
                    scores=np.array(scores),
                )
                for bot_name, scores in per_bot.items()
            }
            for metric, per_bot in metrics.items()
        }

    # The cache directory may not exist yet on a fresh checkout.
    cached.parent.mkdir(parents=True, exist_ok=True)
    # Write to a temp file and swap it in, so an interrupted dump never leaves
    # a truncated pickle that would break every later call.
    tmp = cached.with_name(cached.name + ".tmp")
    with tmp.open('wb') as fout:
        pickle.dump(obj=output, file=fout)
    tmp.replace(cached)

    return output
