import argparse
import json
import sys
from pathlib import Path

# Dataset names accepted by the --dataset command-line option; main() rejects
# any value that is not listed here.
RECOGNIZED_DATASETS = ["mediasum", "ultrachat"]


def convert_to_selfee_format(input_file: Path, output_file: Path, dataset: str):
    """Rewrite a JSON-lines file as question_id/text/answer/category records.

    Every non-blank line of ``input_file`` is parsed as one JSON object and
    written to ``output_file`` as a JSON object holding exactly the keys
    ``question_id``, ``text``, ``answer`` and ``category`` (always
    ``"generic"``), one object per line. The output path is printed to
    stdout once the conversion is finished.

    Args:
        input_file: Path of the JSON-lines file to read (decoded as UTF-8).
        output_file: Path of the JSON-lines file to write; it is created or
            truncated.
        dataset: Input schema name, compared case-insensitively against
            ``"mediasum"`` and ``"ultrachat"``.

    Raises:
        ValueError: If ``dataset`` is not a supported name (this also covers
            ``json.JSONDecodeError`` for a malformed non-blank line).
        KeyError: If a row lacks a field required by the chosen schema.
        IndexError: If an ``ultrachat`` row has an empty ``completions`` list.
    """
    dataset = dataset.lower()
    allowed_datasets = ["mediasum", "ultrachat"]
    if dataset not in allowed_datasets:
        raise ValueError(f"{dataset=} is not in {allowed_datasets=}")

    # JSON text is UTF-8; pin the encoding so the result does not depend on
    # the platform's locale. The output is only written, so "w" is enough.
    with open(input_file, encoding="utf-8") as inp, open(
        output_file, "w", encoding="utf-8"
    ) as out:
        for line in inp:
            # Tolerate blank lines (e.g. a trailing newline at end of file)
            # instead of failing inside json.loads.
            if not line.strip():
                continue

            row = json.loads(line)
            if dataset == "mediasum":
                record = {
                    "question_id": "|".join(
                        [row["doc_id"], row["topic"], row["model_name"]]
                    ),
                    "text": (
                        "Please summarize the following document on the topic {topic}:\n{document}".format(
                            topic=row["topic"], document=row["source_doc"]
                        )
                    ),
                    "answer": row["summary"],
                }
            else:  # "ultrachat" -- the only other name that passes validation
                record = {
                    "question_id": row["id"],
                    "text": row["instruction"],
                    # Only the first completion is kept as the answer.
                    "answer": row["completions"][0],
                }

            record["category"] = "generic"

            out.write(json.dumps(record) + "\n")
    print(output_file)


def main() -> int:
    """Parse the command line and run the dataset conversion.

    Reads ``--input_file``, ``--output_file`` and ``--dataset`` from
    ``sys.argv`` (all three are required) and forwards them to
    ``convert_to_selfee_format``.

    Returns:
        ``0`` on success, suitable for passing to ``sys.exit``.

    Raises:
        ValueError: If ``--dataset`` is not one of ``RECOGNIZED_DATASETS``.
        SystemExit: Raised by argparse when a required option is missing or
            the command line is otherwise malformed.
    """
    parser = argparse.ArgumentParser()
    # All options are mandatory: without them the conversion would only fail
    # later with a less helpful error (e.g. open(None)).
    parser.add_argument("--input_file", type=Path, required=True)
    parser.add_argument("--output_file", type=Path, required=True)
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help=f"Option from {RECOGNIZED_DATASETS}",
    )

    args = parser.parse_args()

    # Compare case-insensitively, matching convert_to_selfee_format, which
    # lowercases the dataset name before validating it.
    dataset = args.dataset.lower()
    if dataset not in RECOGNIZED_DATASETS:
        raise ValueError(f"{dataset=} not in {RECOGNIZED_DATASETS=}")

    convert_to_selfee_format(
        input_file=args.input_file,
        output_file=args.output_file,
        dataset=dataset,
    )

    return 0


# Run the CLI only when executed as a script (not on import); main()'s return
# value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
