"""
Loads and preprocesses a dataset of your choice, saving it in a normalized format,
ready to apply downstream substitution functions.
"""
import argparse
import gzip
import json
import typing

import wget

from src.classes.qadataset import *
from src.utils import argparse_str2bool

# NB: Feel free to add custom datasets here.
#
# Registry of the datasets this script knows about. Each key is a name
# accepted by the ``--dataset`` flag; each value is a pair of
# ``(dataset class, URL or filesystem path of the raw data)``.

# Locations shared by several entries below, named once to avoid repetition.
_NQ_DEV_PATH = "/data/timchen0618/open_domain_data/NQ/dev.json"
# NOTE(review): "shu" presumably stands for "shuffled" -- confirm.
_NQ_DEV_SHU_PATH = "/data/timchen0618/open_domain_data/NQ/dev_shu_gold_higher.json"

DATASETS = {
    # MRQA Datasets available at: https://github.com/mrqa/MRQA-Shared-Task-2019
    "MRQANaturalQuestionsTrain": (
        MRQANaturalQuetsionsDataset,
        "https://s3.us-east-2.amazonaws.com/mrqa/release/v2/train/NaturalQuestionsShort.jsonl.gz",
    ),
    "MRQANaturalQuestionsDev": (
        MRQANaturalQuetsionsDataset,
        "https://s3.us-east-2.amazonaws.com/mrqa/release/v2/dev/NaturalQuestionsShort.jsonl.gz",
    ),
    # Natural Questions files addressed by relative path.
    "NQTOP1": (NQTop1Dataset, "../NQ/ret_perf_data/dev_top1_only_toppas.json"),
    "NQTOP1Train8000": (NQTop1Dataset, "../NQ/train8000_top1_only_toppas.json"),
    # Entries that all read the same NQ dev file through different classes.
    "NQ": (NQTop1Dataset, _NQ_DEV_PATH),
    "NQTOP25": (NQTop25Dataset, _NQ_DEV_PATH),
    "NQTOP50": (NQTop50Dataset, _NQ_DEV_PATH),
    "NQTOP75": (NQTop75Dataset, _NQ_DEV_PATH),
    "NQTOP100": (NQTop100Dataset, _NQ_DEV_PATH),
    # "SHU" entries read the dev_shu_gold_higher.json variant of the dev file.
    "NQTOP25SHU": (NQTop25ShuDataset, _NQ_DEV_SHU_PATH),
    "NQTOP50SHU": (NQTop50ShuDataset, _NQ_DEV_SHU_PATH),
    "NQTOP75SHU": (NQTop75ShuDataset, _NQ_DEV_SHU_PATH),
    "NQTOP100SHU": (NQTop100ShuDataset, _NQ_DEV_SHU_PATH),
}


def load_and_preprocess_dataset(args):
    """Build the dataset selected by ``args.dataset`` and preprocess it.

    The dataset class and its data location are looked up in ``DATASETS``;
    the class is printed, instantiated through its ``new`` constructor and
    then asked to preprocess itself.

    Args:
        args: Namespace-like object. The attributes read are ``dataset``
            (must be a key of ``DATASETS``), ``wikidata``, ``ner_model``
            and ``debug``; the last three are forwarded unchanged to the
            dataset's ``preprocess`` method.

    Raises:
        KeyError: If ``args.dataset`` is not a key of ``DATASETS``.
    """
    selected_name = args.dataset
    dataset_cls, data_location = DATASETS[selected_name]
    print(dataset_cls)
    # Construct and preprocess in one chain; nothing is returned to the caller.
    dataset_cls.new(selected_name, data_location).preprocess(
        args.wikidata,
        args.ner_model,
        args.debug,
    )


if __name__ == "__main__":
    # Command-line entry point: parse the flags, then hand the resulting
    # namespace to load_and_preprocess_dataset.
    dataset_names = list(DATASETS.keys())

    cli_parser = argparse.ArgumentParser()

    # Which registered dataset to process (mandatory).
    cli_parser.add_argument(
        "-d", "--dataset",
        required=True,
        choices=dataset_names,
        help=f"Name of the dataset. Must be one of {dataset_names}",
    )

    # Location of the wikidata entity info file.
    cli_parser.add_argument(
        "-w", "--wikidata",
        default="wikidata/entity_info.json.gz",
        help="Path to wikidata entity info file generated in Stage 2.",
    )

    # Location of the NER / entity-linking model directory.
    cli_parser.add_argument(
        "-m", "--ner-model",
        default="models/kc-ner-model/",
        help="Path to the directory of our SpaCy Named Entity Recognition and Entity Linking model, downloaded during setup.",
    )

    # Optional-value flag: absent -> False, bare ``--debug`` -> True (const),
    # ``--debug VALUE`` -> VALUE converted by argparse_str2bool.
    cli_parser.add_argument(
        "--debug",
        nargs="?",
        const=True,
        default=False,
        type=argparse_str2bool,
        help="If set to True, only 100 examples are processed, to speed up debugging.",
    )

    cli_args = cli_parser.parse_args()
    load_and_preprocess_dataset(cli_args)
