import os
import json
import spacy

from tqdm import tqdm
from uuid import UUID
from random import Random

# Seeded RNG so the generated question UUIDs are reproducible across runs
# (Random(42) is equivalent to Random() followed by .seed(42)).
rd = Random(42)

ONTONOTES_DIR = "ontonotes/"
# "en" was only a shortcut link in spaCy 2.x and is no longer supported in
# spaCy 3; loading the model package by name works in both versions.
nlp = spacy.load("en_core_web_sm")
set2fmt = "debug"  # change this as required

# Surface strings treated as pronouns when choosing a cluster's antecedent:
# a mention whose text is in this list is not preferred as the antecedent if a
# later non-pronoun mention exists. Multi-word entries ("each other", "no one")
# match mentions whose words are joined by single spaces. Entries are unique.
pronoun_list = [
    "mine",
    "us",
    "you",
    "him",
    "whoever",
    "whomever",
    "themselves",
    "there",
    "your",
    "these",
    "where",
    "myself",
    "whose",
    "someone",
    "ourselves",
    "his",
    "whichever",
    "everybody",
    "yourselves",
    "anybody",
    "which",
    "our",
    "herself",
    "ours",
    "its",
    "my",
    "hers",
    "their",
    "her",
    "whosever",
    "whom",
    "yourself",
    "both",
    "she",
    "me",
    "himself",
    "itself",
    "I",
    "theirs",
    "those",
    "we",
    "he",
    "them",
    "who",
    "they",
    "somebody",
    "each other",
    "something",
    "it",
    "yours",
    "that",
    "others",
    "neither",
    "none",
    "wherever",
    "some",
    "thyself",
    "no one",
    "whereon",
    "thy",
    "whence",
    "whereof",
    "ye",
    "theirself",
    "whatever",
    "whatnot",
    "whether",
    "thee",
    "whosesoever",
    "anyone",
    "several",
    "many",
    "whereunto",
    "ourself",
    "thine",
    "anything",
    "such",
    "any",
    "all",
    "aught",
    "nobody",
    "somewhat",
    "either",
    "whoso",
    "themself",
    "suchlike",
    "whomsoever",
    "whichsoever",
    "wherewith",
    "everything",
    "idem",
    "nothing",
    "one another",
    "this",
    "wheresoever",
    "as",
    "most",
    "whosoever",
    "another",
    "naught",
    "thou",
    "whereto",
    "nought",
    "one",
    "other",
    "theirselves",
    "whatsoever",
    "yon",
    "whereby",
    "whomso",
    "everyone",
    "enough",
    "few",
    "wherefrom",
    "wherein",
    "whereinto",
    "wherewithal",
    "yonder",
    "what",
    "ought",
    "each",
]


def sort_cluster(cluster):
    """Sort a cluster's mentions in place by start index and return the cluster.

    Args:
        cluster: a list of mention spans; each span is indexable and ``span[0]``
            is the sort key (its start index).

    Returns:
        The same ``cluster`` list, now sorted. (The previous version returned
        the result of ``list.sort``, which is always ``None``.)
    """
    cluster.sort(key=lambda span: span[0])
    return cluster


def sort_clusters(clusters):
    """Sort every cluster in *clusters* in place by mention start and return *clusters*.

    The outer list is returned unchanged in order; only each cluster's
    mentions are reordered.
    """
    for cluster in clusters:
        # sort_cluster sorts in place; its return value is deliberately unused.
        sort_cluster(cluster)
    return clusters


def convert_indices(text):
    """Return the (start, end) character offsets of each word in ``" ".join(text)``.

    ``end`` is exclusive, and consecutive words are assumed to be separated by
    exactly one space. An empty input yields an empty list.
    """
    offsets = []
    cursor = 0
    for token in text:
        stop = cursor + len(token)
        offsets.append((cursor, stop))
        cursor = stop + 1  # step over the single joining space
    return offsets


def enclosed(idx, boundary):
    """Return True if ``idx`` lies in the half-open interval ``[boundary[0], boundary[1])``."""
    return boundary[0] <= idx < boundary[1]


# the jsonlines are generated from https://github.com/kentonl/e2e-coref/blob/master/minimize.py
# One JSON document per line. The file is streamed rather than read with
# readlines(), and blank lines are skipped because json.loads("") would raise.
with open(
    os.path.join(ONTONOTES_DIR, "{}.jsonlines".format(set2fmt)), "r", encoding="utf-8"
) as f:
    data = [json.loads(row) for row in f if row.strip()]

# Pronoun lookup table, built once; a frozenset gives O(1) membership tests.
_PRONOUNS = frozenset(pronoun_list)


def _is_pronoun(mention):
    """Return True if *mention* is one of ``pronoun_list``.

    A sentence-initial capital is tolerated ("Their" matches "their") so that
    mentions opening a sentence are still recognised. All-caps tokens such as
    "US" or "IT" are deliberately not lower-cased.
    """
    decapitalized = mention[:1].lower() + mention[1:]
    return mention in _PRONOUNS or decapitalized in _PRONOUNS


def _build_cluster(cluster, context, str_spans):
    """Convert one coreference cluster into an antecedent/referents dict.

    Args:
        cluster: list of ``[start_word, end_word]`` pairs; both ends are used
            as inclusive word indices into ``str_spans``.
        context: the document text (words joined by single spaces).
        str_spans: per-word ``(start_char, end_char)`` offsets into ``context``.

    Returns:
        ``{"antecedent": {"name", "span"}, "referents": [{"name", "span",
        "word_span"}, ...]}``, or ``None`` for an empty cluster. The antecedent
        is the first mention that is not a pronoun. For example, in 'For their
        investigation, the FBI...' the antecedent is "FBI", not "their". If
        every mention is a pronoun, the first mention is used.
    """
    referents = []
    for span in cluster:
        start_idx, end_idx = str_spans[span[0]][0], str_spans[span[1]][1]
        referents.append(
            {
                "name": context[start_idx:end_idx],
                "span": (start_idx, end_idx),
                "word_span": (span[0], span[1]),
            }
        )
    if not referents:
        return None
    antecedent = next(
        (ref for ref in referents if not _is_pronoun(ref["name"])), referents[0]
    )
    return {
        "antecedent": {"name": antecedent["name"], "span": antecedent["span"]},
        "referents": referents,
    }


def _build_qas(doc, squad_clusters):
    """Create one SQuAD question per mention that lies inside a single sentence.

    The question is the sentence text with that single mention wrapped in
    ``<ref> ... </ref>``. The answer is the mention's cluster antecedent, given
    by its character offset in the full document context.
    """
    qas = []
    for sent in doc.sents:
        sent_boundary = (sent.start_char, sent.end_char)
        sent_text = sent.text
        for cluster in squad_clusters:
            for ref in cluster["referents"]:
                ref_start, ref_end = ref["span"]
                # ref_end is exclusive, so it may equal the sentence's end
                # offset; `enclosed` is half-open and would reject that case.
                if not (
                    enclosed(ref_start, sent_boundary)
                    and ref_end <= sent_boundary[1]
                ):
                    continue
                # Mark the mention by position so that only this occurrence
                # is tagged and the surrounding spaces are preserved.
                rel_start = ref_start - sent_boundary[0]
                rel_end = ref_end - sent_boundary[0]
                question = (
                    sent_text[:rel_start]
                    + "<ref> "
                    + ref["name"]
                    + " </ref>"
                    + sent_text[rel_end:]
                )
                qas.append(
                    {
                        "id": UUID(int=rd.getrandbits(128)).hex,
                        "question": question,
                        "mention_span": ref["word_span"],
                        "answers": [
                            {
                                "answer_start": cluster["antecedent"]["span"][0],
                                "text": cluster["antecedent"]["name"],
                            }
                        ],
                    }
                )
    return qas


squad_data = []
for dp in tqdm(data):
    clusters = sort_clusters(dp["clusters"])
    text = [word for sent in dp["sentences"] for word in sent]
    str_spans = convert_indices(text)
    context = " ".join(text)

    squad_clusters = []
    for cluster in clusters:
        cluster_obj = _build_cluster(cluster, context, str_spans)
        if cluster_obj is not None:
            squad_clusters.append(cluster_obj)

    qas = _build_qas(nlp(context), squad_clusters)
    squad_data.append(
        {"title": dp["doc_key"], "paragraphs": [{"context": context, "qas": qas}]}
    )

# Serialise every converted document as a SQuAD v1.1-style file next to the input.
output_path = os.path.join(ONTONOTES_DIR, "{}.json".format(set2fmt))
payload = {"version": "1.1", "data": squad_data}
print("Flushing to disk...")
with open(output_path, "w", encoding="utf-8") as out_file:
    json.dump(payload, out_file, indent=4)
