import argparse
import json
import pandas as pd
from datasets import load_dataset


def load_gold(gold_dir: str, offensive_only: bool):
    """Load a gold-annotated file and reformat every record for evaluation.

    Args:
        gold_dir: Path to a ``.jsonl`` file (read via ``datasets.load_dataset``)
            or a ``.json`` file (read via ``json.load``).
        offensive_only: If True, keep only records whose ``offensive`` field is
            truthy.

    Returns:
        A list of dicts with keys ``text_id``, ``comment``, ``title``,
        ``offensive`` (bool), ``off_span`` and ``tgt_span``.  The span entries
        are the positions whose vote count reaches the threshold.

    Raises:
        ValueError: If the path ends in neither ``jsonl`` nor ``json``.
    """
    def _vote_to_idx(vote: list, threshold=2):
        # Keep every position whose vote count is at least ``threshold``.
        # NOTE(review): presumably one annotator-vote count per character of
        # the comment — confirm against the dataset description.
        return [i for i, v in enumerate(vote) if v >= threshold]

    if gold_dir.endswith("jsonl"):
        test_dataset = load_dataset("json", data_files=gold_dir)["train"]
    elif gold_dir.endswith("json"):
        # Use a context manager so the handle is closed even if parsing fails.
        with open(gold_dir, "r", encoding="utf-8") as f:
            test_dataset = json.load(f)
    else:
        raise ValueError(f"Unsupported file type: {gold_dir}")

    # Reformat the test data; one shared record builder for both filter modes.
    gold = []
    for data in test_dataset:
        is_offensive = bool(data["offensive"])
        if offensive_only and not is_offensive:
            continue
        gold.append(
            {
                "text_id": data["text_id"],
                "comment": data["comment"],
                "title": data["title"],
                "offensive": is_offensive,
                "off_span": _vote_to_idx(data["off_span_list"]),
                "tgt_span": _vote_to_idx(data["tgt_span_list"]),
            }
        )
    return gold


def startend_to_idx(start: int, end: int):
    """Return every integer from ``start`` (inclusive) up to ``end`` (exclusive).

    The result is empty when ``end <= start``.
    """
    return list(range(start, end))


def load_jaimeen_output(pred_data_dir: str, gold: list):
    """Convert the "jaimeen" model's JSON predictions into the shared format.

    Args:
        pred_data_dir: Path to a JSON object whose values are prediction
            sequences. ``pred[0]`` is read as the predicted offensive span
            text, and ``pred[3] == 1`` as the offensive label.
        gold: Records as produced by ``load_gold``, paired with the
            predictions by position.

    Returns:
        A list of dicts with keys ``guid``, ``comment``, ``title``,
        ``offensive`` (bool), ``off_span`` (character indices into
        ``comment``; empty when the span text is not found) and ``tgt_span``
        (always ``None``).
    """
    with open(pred_data_dir, encoding="utf-8") as f:
        jaimeen_output = json.load(f)

    # Reformat the jaimeen output.
    # NOTE(review): predictions are paired with ``gold`` purely by iteration
    # order, and ``zip`` silently stops at the shorter input. Confirm that
    # both files share one ordering, especially when gold is filtered.
    jaimeen_pred = []
    for pred, g in zip(jaimeen_output.values(), gold):
        comment = g["comment"]
        span_text = pred[0]
        start_idx = comment.find(span_text)
        # str.find returns -1 when the span is absent. Without this guard the
        # span would become the bogus range [-1, len(span_text) - 1).
        if start_idx == -1:
            off_span = []
        else:
            off_span = startend_to_idx(start_idx, start_idx + len(span_text))

        jaimeen_pred.append(
            {
                # load_gold records carry "text_id"; fall back to "guid" for
                # gold lists built elsewhere.
                "guid": g.get("text_id", g.get("guid")),
                "comment": comment,
                "title": g["title"],
                "offensive": pred[3] == 1,
                "off_span": off_span,
                "tgt_span": None,
            }
        )
    return jaimeen_pred


def load_jongwon_output(pred_output_dir: str, gold: list):
    """Convert the "jongwon" model's JSON predictions into the shared format.

    Args:
        pred_output_dir: Path to a JSON object whose values are dicts with a
            ``"text"`` entry holding the predicted offensive span text.
        gold: Records as produced by ``load_gold``, paired with the
            predictions by position.

    Returns:
        A list of dicts with keys ``id``, ``comment``, ``title``,
        ``offensive`` (always ``None``), ``off_span`` (character indices into
        ``comment``; empty when the span text is not found) and ``tgt_span``
        (always ``None``).
    """
    with open(pred_output_dir, encoding="utf-8") as f:
        model_output = json.load(f)

    # NOTE(review): predictions are paired with ``gold`` purely by iteration
    # order, and ``zip`` silently stops at the shorter input. Confirm that
    # both files share one ordering.
    preds = []
    for o, g in zip(model_output.values(), gold):
        comment = g["comment"]
        span_text = o["text"]
        start_idx = comment.find(span_text)
        # str.find returns -1 when the span is absent. Emit an empty span
        # instead of the bogus range [-1, len(span_text) - 1).
        if start_idx == -1:
            off_span = []
        else:
            off_span = startend_to_idx(start_idx, start_idx + len(span_text))

        preds.append(
            {
                # load_gold records carry "text_id"; fall back to "guid" for
                # gold lists built elsewhere.
                "id": g.get("text_id", g.get("guid")),
                "comment": comment,
                "title": g["title"],
                "offensive": None,
                "off_span": off_span,
                "tgt_span": None,
            }
        )
    return preds


def format_output(pred_dir: str, gold_dir: str, mode: int, offensive_only: bool):
    """Load one model's predictions and reformat them against the gold set.

    Args:
        pred_dir: Path to the model's prediction file. Mode 0 expects a TSV
            with ``text`` and ``pred`` columns; modes 1 and 2 expect JSON.
        gold_dir: Path to the gold ``.json``/``.jsonl`` file, passed to
            ``load_gold``.
        mode: Which model produced ``pred_dir``: 0 = younghoon,
            1 = jaimeen, 2 = jongwon.
        offensive_only: Forwarded to ``load_gold`` to keep only offensive
            gold records.

    Returns:
        A list of per-comment prediction dicts.

    Raises:
        ValueError: If ``mode`` is not 0, 1 or 2.
    """
    def _bio_to_idx(comment, tags):
        # Map space-separated tags onto character offsets of ``comment``.
        # Every character of a word whose tag is not "O" is collected.
        idx = []
        cur = 0  # offset of the current word's first character
        for word, tag in zip(comment.split(" "), tags.split(" ")):
            if tag != "O":
                idx.extend(range(cur, cur + len(word)))
            cur += len(word) + 1  # step over the word and its single separating space
        return idx

    gold = load_gold(gold_dir, offensive_only)
    if mode == 0:  # younghoon model
        # Read every cell as a literal string. Otherwise empty cells, or texts
        # such as "NA", become NaN floats and break the string handling below.
        model_output = pd.read_csv(pred_dir, sep="\t", dtype=str, keep_default_na=False)
        preds = [
            {
                # load_gold records expose the id as "text_id" and the text as "comment".
                "id": g["text_id"],
                "text": g["comment"],
                "title": g["title"],
                # NOTE(review): assumes offensive tags contain the substring "off".
                "offensive": "off" in o["pred"],
                "off_span": _bio_to_idx(o["text"], o["pred"]),
                "tgt_span": None,
            }
            for (_, o), g in zip(model_output.iterrows(), gold)
        ]

    elif mode == 1:  # jaimeen model
        preds = load_jaimeen_output(pred_dir, gold)

    elif mode == 2:  # jongwon model
        preds = load_jongwon_output(pred_dir, gold)

    else:
        raise ValueError(f"Unsupported mode : {mode}")

    return preds


def save_output(preds, output_dir):
    """Write ``preds`` to ``output_dir`` as UTF-8 JSON.

    Non-ASCII characters are written as-is (``ensure_ascii=False``) rather
    than as ``\\u`` escapes.

    Args:
        preds: JSON-serialisable predictions, e.g. the list returned by
            ``format_output``.
        output_dir: Destination file path. Despite the name, this is a file,
            not a directory.
    """
    # Serialise the argument that was passed in, and close the file
    # deterministically so the JSON is fully flushed to disk.
    with open(output_dir, "w", encoding="UTF-8") as f:
        json.dump(preds, f, ensure_ascii=False)


if __name__ == "__main__":
    # Command-line entry point: reformat one model's predictions against the
    # gold set and write the result as JSON.
    cli = argparse.ArgumentParser()

    # Every option except --offensive_only is required; they are registered in
    # the same order as before so that --help output is unchanged.
    required_options = (
        ("--pred_dir", str),
        ("--gold_dir", str),
        ("--mode", int),
        ("--output_dir", str),
    )
    for flag, flag_type in required_options:
        cli.add_argument(flag, type=flag_type, required=True)
    cli.add_argument("--offensive_only", action="store_true")
    cli_args = cli.parse_args()

    output = format_output(
        pred_dir=cli_args.pred_dir,
        gold_dir=cli_args.gold_dir,
        mode=cli_args.mode,
        offensive_only=cli_args.offensive_only,
    )
    save_output(output, cli_args.output_dir)
    print(f"{cli_args.pred_dir} reformatted and saved in {cli_args.output_dir}")
