import json
import pandas as pd
from argparse import ArgumentParser
from transformers import BertTokenizer, BertTokenizerFast


def bio_to_idx_subword(comment, tags, line_idx, tokenizer, args):
    """Convert subword-level BIO tags into character indices of ``comment``.

    ``comment`` is tokenized with ``tokenizer`` (offset mapping requested, no
    special tokens) and ``tags`` holds one space-separated tag per resulting
    token. Every maximal run of non-"O" tags is mapped to the character range
    from the start offset of its first token to the end offset of its last
    token, so characters lying between tokens of one run (e.g. spaces) are
    included.

    Args:
        comment: Comment text. When ``args.tokenized`` is set it is instead a
            space-separated sequence of tokenizer tokens, which is decoded back
            to text first; the returned indices then refer to that decoded text.
        tags: Space-separated BIO tags, one per token.
        line_idx: Row index, used only in the error message.
        tokenizer: Tokenizer that supports ``return_offsets_mapping`` (i.e.
            presumably a HuggingFace "fast" tokenizer — confirm for new callers).
        args: Namespace; only ``args.tokenized`` is read.

    Returns:
        List of character indices, run by run from left to right.

    Raises:
        ValueError: If the number of tokens differs from the number of tags.
    """
    if args.tokenized:
        # Rebuild plain text from the pre-tokenized input so offsets can be computed.
        comment = tokenizer.decode(tokenizer.convert_tokens_to_ids(comment.strip().split(" ")))

    encoded = tokenizer(comment, return_offsets_mapping=True, add_special_tokens=False)
    offset_mapping = encoded['offset_mapping']
    num_tokens = len(encoded['input_ids'])
    tag_list = tags.split(' ')

    # Explicit check instead of `assert`, which disappears under `python -O`.
    if num_tokens != len(tag_list):
        raise ValueError(f"{line_idx} has diff len : {num_tokens} != {len(tag_list)}")

    def char_range(first_tok, last_tok):
        # Characters from the start of `first_tok` up to the end of `last_tok`.
        return range(offset_mapping[first_tok][0], offset_mapping[last_tok][1])

    span_indices = []
    run_start = None
    for idx, tag in enumerate(tag_list):
        if tag != "O":
            if run_start is None:
                run_start = idx
        elif run_start is not None:
            span_indices.extend(char_range(run_start, idx - 1))
            run_start = None

    # Close a run that extends to the last token.
    if run_start is not None:
        span_indices.extend(char_range(run_start, len(tag_list) - 1))

    return span_indices


def bio_to_idx(comment, tags, line_idx):
    """Convert word-level BIO tags into character indices of ``comment``.

    ``comment`` is split on single spaces and ``tags`` holds one
    space-separated tag per word. The characters of every word whose tag is
    not "O" are selected. The space between two consecutive non-"O" words is
    selected too, so a multi-word span forms one contiguous character range.

    Args:
        comment: Comment text, words separated by single spaces.
        tags: Space-separated BIO tags, one per word.
        line_idx: Row index, used only in the error message.

    Returns:
        List of selected character indices in increasing order.

    Raises:
        ValueError: If the number of words differs from the number of tags.
    """
    words = comment.split(' ')
    tag_list = tags.split(' ')
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if len(words) != len(tag_list):
        raise ValueError(f"{line_idx} has diff len : {len(words)} != {len(tag_list)}")

    span_indices = []
    char_idx = 0  # offset of the current word within `comment`
    for idx, (word, tag) in enumerate(zip(words, tag_list)):
        word_end = char_idx + len(word)
        if tag != "O":
            span_indices.extend(range(char_idx, word_end))
            # Bridge the separating space when the next word is tagged as well.
            if idx + 1 < len(tag_list) and tag_list[idx + 1] != "O":
                span_indices.append(word_end)
        char_idx = word_end + 1  # skip the single separating space

    return span_indices


def bio_to_idx_depr(comment, tags):
    """Legacy word-level BIO → character-index conversion.

    Selects the characters of every word whose tag is not 'O'. Unlike
    ``bio_to_idx``, spaces between tagged words are never selected. Words and
    tags are paired with ``zip``, so extra items on either side are ignored.

    Args:
        comment: Text whose words are separated by single spaces.
        tags: Space-separated tags, one per word.

    Returns:
        List of selected character indices in increasing order.
    """
    offset = 0
    selected = []
    for token, label in zip(comment.split(' '), tags.split(' ')):
        stop = offset + len(token)
        if label != 'O':
            selected.extend(range(offset, stop))
        offset = stop + 1  # step over the separating space
    return selected


def convert_bio_to_idx(args, tokenizer):
    """Convert a 7-column prediction TSV into a list of result dicts.

    The file at ``args.input_path`` has a header row followed by tab-separated
    rows: guid, title, comment, pooled_gold, pooled_pred, bio_gold, bio_pred.
    The predicted BIO tags are converted to character indices with
    ``bio_to_idx_subword`` (when ``args.label_all_tokens``) or ``bio_to_idx``.

    Args:
        args: Namespace providing ``input_path``, ``mode``,
            ``label_all_tokens`` (and ``tokenized`` for the subword path).
        tokenizer: Passed to ``bio_to_idx_subword``; unused otherwise.

    Returns:
        One dict per data row with keys id, comment, title, offensive,
        off_span, tgt_span. Depending on ``args.mode``, the indices go into
        "off_span" (off_span) or "tgt_span" (tgt_span and group_span); the
        other span key is None.

    Raises:
        ValueError: If ``args.mode`` is unsupported or a row does not have
            exactly 7 tab-separated fields.
    """
    # Result key that receives the predicted indices for each mode.
    span_key_by_mode = {
        "off_span": "off_span",
        "tgt_span": "tgt_span",
        "group_span": "tgt_span",
    }
    # Validate the mode up front rather than only after reading a data row.
    if args.mode not in span_key_by_mode:
        raise ValueError(f"{args.mode} NOT SUPPORTED (off_span|tgt_span|group_span)")
    span_key = span_key_by_mode[args.mode]

    results = []
    with open(args.input_path, "r", encoding="utf-8") as f:
        next(f, None)  # skip the header; tolerate a completely empty file
        for line_idx, line in enumerate(f):
            # Strip only the line terminator: a full strip() would also remove
            # the tab separators of empty trailing fields and reject the row.
            tab_split = line.rstrip("\r\n").split("\t")
            if len(tab_split) != 7:
                raise ValueError(f"[ILL-FORMED][TAB-NO-7][{line_idx}][{line}]")

            guid, title, comment, _pooled_gold, pooled_pred, _bio_gold, bio_pred = tab_split
            # Keep trimming outer whitespace of the first/last fields, as before.
            guid = guid.lstrip()
            bio_pred = bio_pred.rstrip()

            if args.label_all_tokens:
                pred_indices = bio_to_idx_subword(comment, bio_pred, line_idx, tokenizer, args)
            else:
                pred_indices = bio_to_idx(comment, bio_pred, line_idx)

            record = {
                "id": guid,
                "comment": comment,
                "title": title,
                "offensive": pooled_pred,
                "off_span": None,
                "tgt_span": None,
            }
            record[span_key] = pred_indices  # updates in place; key order is unchanged
            results.append(record)
    return results


def load_old_output(pred_output_dir: str, golds: list) -> list:
    """Convert a legacy-format prediction TSV into the JSON result format.

    Args:
        pred_output_dir: Path to the tab-separated prediction file; its
            ``text`` and ``pred`` columns are read.
        golds: Gold records (dicts with ``text_id``, ``text``, ``title``),
            paired with prediction rows by position.

    Returns:
        One dict per paired row with keys id, comment, title, offensive,
        off_span (indices from ``bio_to_idx_depr``) and tgt_span (always None).

    Note:
        Rows are paired with ``zip``, so if the prediction file and ``golds``
        differ in length the extra items are silently dropped.
    """
    # dtype=str / keep_default_na=False keep every cell a plain string. By
    # default pandas turns empty cells and texts such as "NA" or "null" into
    # float NaN, which would crash the `in` test and `.split()` below.
    model_output = pd.read_csv(pred_output_dir, sep='\t', dtype=str, keep_default_na=False)
    preds = []
    for (_, row), gold in zip(model_output.iterrows(), golds):
        preds.append(
            {
                "id": gold['text_id'],
                "comment": gold['text'],
                "title": gold['title'],
                # Offensive if the substring "off" appears anywhere in the predicted tags.
                "offensive": 'off' in row['pred'],
                "off_span": bio_to_idx_depr(row['text'], row['pred']),
                "tgt_span": None,
            }
        )
    return preds


def main():
    """Command-line entry point: convert BIO predictions and write them as JSON.

    With ``--old_version`` the legacy TSV at ``--input_path`` is combined with
    the gold JSON at ``--gold_path`` via ``load_old_output``. Otherwise the
    7-column TSV is converted with ``convert_bio_to_idx``. The result list is
    written to ``--output_path`` as UTF-8 JSON.
    """
    parser = ArgumentParser()
    parser.add_argument("--input_path", type=str, default="../baseline/ckpt/v2.1/sp-off-0/test_prediction.tsv")
    parser.add_argument("--output_path", type=str, default="../baseline/ckpt/v2.1/sp-off-0/test_prediction.json")
    parser.add_argument("--mode", type=str, choices=["off_span", "tgt_span", "group_span"], required=True)

    parser.add_argument("--tokenized", action='store_true')
    parser.add_argument("--label_all_tokens", action='store_true')
    parser.add_argument("--pretrained_path", type=str, default="../klue-bert-base")

    parser.add_argument("--gold_path", type=str, default="../baseline/data/hatespan/test.json")
    parser.add_argument("--old_version", action='store_true')

    args = parser.parse_args()

    if args.old_version:
        with open(args.gold_path, "r", encoding="utf-8") as f_gold:
            gold_json = json.load(f_gold)
        results = load_old_output(args.input_path, gold_json)
    else:
        # The tokenizer is only consulted on the subword (--label_all_tokens)
        # path, so --pretrained_path need not exist for the other paths.
        tokenizer = BertTokenizerFast.from_pretrained(args.pretrained_path) if args.label_all_tokens else None
        results = convert_bio_to_idx(args, tokenizer)

    with open(args.output_path, "w", encoding="utf-8") as f_out:
        json.dump(results, f_out, ensure_ascii=False, indent=2)


# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
