"""
This script generates an adversarial dataset from the original dataset.
For every question, we find the context sentence that is most similar to the question
(currently by cosine similarity of sentence embeddings; a token-overlap variant is also provided).
We then remove all the other sentences in the context and keep only this most similar one,
thus generating a new adversarial set.
"""
import copy
import gzip
import json
import nltk
import os
import pandas
import shutil
import sys
import torch
from sentence_transformers import SentenceTransformer
from spacy.lang.en import English
import uuid

import os
# Request at most 6 threads from the OpenMP / MKL runtimes.
# NOTE(review): these variables are set after `torch` has already been imported
# above; confirm the limits still take effect, or move them before that import.
os.environ["OMP_NUM_THREADS"] = "6"
os.environ["MKL_NUM_THREADS"] = "6"

# Blank English spaCy pipeline: only its tokenizer and a rule-based sentence
# splitter ("sentencizer") are used by the helpers below.
# NOTE(review): `create_tokenizer` / `create_pipe` look like the spaCy 2.x API;
# confirm the pinned spaCy version before upgrading.
nlp = English()
tokenizer = nlp.Defaults.create_tokenizer(nlp)
nlp.add_pipe(nlp.create_pipe('sentencizer'))
# Sentence-embedding model, loaded onto the CUDA device when this module is
# imported (so importing the module requires a GPU).
model = SentenceTransformer("bert-base-nli-mean-tokens", device='cuda')


def get_context_tokens(context):
    """Tokenize *context* with the module-level spaCy tokenizer.

    Returns:
        A list of ``[token_text, token.idx]`` pairs, one per token, in
        document order (``idx`` is spaCy's character offset of the token).
    """
    return [[str(tok), tok.idx] for tok in tokenizer(context)]


def get_detected_answers(ans, context):
    """Locate the first token-aligned, case-insensitive occurrence of *ans*.

    Args:
        ans: Answer text to look for.
        context: Text in which to search.

    Returns:
        ``{"text": ans, "answer_start": offset}`` where ``offset`` is the
        character offset (``token.idx``) of the first token at which the
        lower-cased context starts with the lower-cased answer, or ``-1``
        when no occurrence begins on a token boundary.
    """
    ans_dict = {"text": ans, "answer_start": -1}

    # Lower-case once up front rather than once per candidate offset.
    # NOTE(review): offsets are matched against the lower-cased text; for the
    # few characters whose lower-case form has a different length this may not
    # line up with offsets in the original string.
    lowered_context = context.lower()
    lowered_ans = ans.lower()

    # Only token start offsets can be accepted, so test those directly
    # instead of enumerating every character position in the context.
    for token in tokenizer(context):
        if lowered_context.startswith(lowered_ans, token.idx):
            ans_dict["answer_start"] = token.idx
            break

    return ans_dict


def get_all_sentences(context):
    """Split *context* into sentences using the module-level spaCy pipeline.

    Returns:
        A list with the text of each sentence span, in document order.
    """
    return [span.text for span in nlp(context).sents]


def find_lexical_overlap_sentence(sents, q_tokens):
    """Return the sentence sharing the most tokens with the question.

    Overlap is the number of tokens in a sentence (repeats included) whose
    exact text also appears among the question tokens; the token "the" is
    never counted. Matching is case-sensitive.

    Args:
        sents: Candidate sentence strings.
        q_tokens: Question tokens as sequences whose first item is the token
            text (the shape produced by ``get_context_tokens``).

    Returns:
        The first sentence with the strictly highest overlap, or ``""`` when
        no sentence shares any counted token with the question.
    """
    # Build the lookup set once; dropping "the" here excludes it from every
    # overlap count without re-checking it per token.
    question_vocab = {token[0] for token in q_tokens}
    question_vocab.discard("the")

    best_overlap = 0
    max_overlap_sentence = ""
    # Strict ">" means that on ties only the earliest sentence is kept.
    for sent in sents:
        overlap = sum(1 for token in tokenizer(sent) if str(token) in question_vocab)
        if overlap > best_overlap:
            best_overlap = overlap
            max_overlap_sentence = sent
    return max_overlap_sentence


def find_lexical_overlap_sentence_embedding(sents, question):
    """Return the sentence whose embedding is most similar to the question's.

    Similarity is the cosine similarity between embeddings produced by the
    module-level sentence-embedding ``model`` (despite the function name, no
    token overlap is computed here).

    Args:
        sents: Candidate sentence strings.
        question: Question string.

    Returns:
        The most similar sentence; on ties the last such sentence wins.
        Returns ``""`` when *sents* is empty.
    """
    if not sents:
        # Nothing to rank; avoid encoding an empty batch.
        return ""

    cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-08)
    sentence_embeddings = model.encode(sents)
    # The question embedding is loop-invariant, so convert it only once.
    question_tensor = torch.FloatTensor(model.encode([question])[0])

    # Cosine similarity can be negative, so start below any possible value;
    # starting at 0 would select no sentence when all similarities are < 0.
    max_sim = float("-inf")
    most_similar_sentence = ""
    for sentence, sent_embedding in zip(sents, sentence_embeddings):
        cos_sim = cos(torch.FloatTensor(sent_embedding), question_tensor).item()
        if cos_sim >= max_sim:
            max_sim = cos_sim
            most_similar_sentence = sentence
    return most_similar_sentence


def generate_new_example(content, qa, stats=None):
    """Build a reduced copy of *content* for a single question.

    The context is replaced by the one sentence most similar to the question,
    and only answers that can be located in that sentence (case-insensitively,
    starting on a token boundary) are kept, with recomputed ``answer_start``.

    Args:
        content: Paragraph dict with at least ``"context"`` and ``"qas"``.
        qa: One question dict with ``"question"`` and ``"answers"`` (each
            answer having a ``"text"`` field).
        stats: Optional dict whose ``"has_anwer"`` entry is incremented for
            every answer kept. When omitted, the module-level ``count`` dict
            is used if it exists; otherwise nothing is counted.

    Returns:
        A new paragraph dict. Its ``"qas"`` holds a single copy of *qa* with
        the relocated answers, or is empty when no answer survived. The
        inputs are not modified.
    """
    if stats is None:
        # `count` is only created when the file runs as a script; look it up
        # defensively so the function also works when imported.
        stats = globals().get("count")

    sentences = get_all_sentences(content['context'])
    sentence_to_keep = find_lexical_overlap_sentence_embedding(sentences, qa["question"])
    lowered_sentence = sentence_to_keep.lower()

    kept_answers = []
    for answer in qa['answers']:
        # Cheap substring pre-check before the token-aligned search.
        if answer["text"].lower() not in lowered_sentence:
            continue
        ans_dict = get_detected_answers(answer["text"], sentence_to_keep)
        if ans_dict["answer_start"] != -1:
            if stats is not None:
                stats["has_anwer"] = stats.get("has_anwer", 0) + 1
            kept_answers.append(ans_dict)

    # Copy only the fields that are retained. "context" and "qas" are
    # overwritten below, so deep-copying them (every question of the
    # paragraph) would be wasted work. The placeholders keep the key order.
    new_content = {
        key: None if key in ("context", "qas") else copy.deepcopy(value)
        for key, value in content.items()
    }
    new_content['context'] = sentence_to_keep

    if kept_answers:
        new_qa = {
            key: None if key == "answers" else copy.deepcopy(value)
            for key, value in qa.items()
        }
        new_qa["answers"] = kept_answers
        new_content["qas"] = [new_qa]
    else:
        new_content["qas"] = []
    return new_content


if __name__ == '__main__':
    # Usage: <script> SRC_FOLDER OUT_FOLDER
    # Reads SRC_FOLDER/train-bart.json and writes
    # OUT_FOLDER/train-bart-lexical.json.
    if len(sys.argv) < 3:
        sys.exit("usage: {} SRC_FOLDER OUT_FOLDER".format(sys.argv[0]))
    src_folder = sys.argv[1]
    out_folder = sys.argv[2]

    # Module-level counters; generate_new_example() increments "has_anwer"
    # for every answer it keeps.
    count = {"has_anwer": 0, "total_questions": 0}

    in_path = os.path.join(src_folder, "train-bart.json")
    out_path = os.path.join(out_folder, "train-bart-lexical.json")

    # The output is opened first so an unwritable destination fails before
    # any processing; `with` closes both files even if an error occurs.
    with open(out_path, "w", encoding="utf-8") as f_out, \
            open(in_path, encoding="utf-8") as f_in:
        data = json.load(f_in)["data"]
        new_data = {"data": []}
        for paragraphs in data:
            new_para = {"title": paragraphs["title"], "url": "", "paragraphs": []}
            for context in paragraphs["paragraphs"]:
                for qa in context["qas"]:
                    print(qa["question"])
                    count["total_questions"] += 1
                    reduced_context = generate_new_example(context, qa)
                    # Keep only examples that still contain an answer.
                    if reduced_context["qas"]:
                        new_para["paragraphs"].append(reduced_context)
            # Drop articles for which no example survived.
            if new_para["paragraphs"]:
                new_data["data"].append(new_para)

        json.dump(new_data, f_out, indent=4)
    print(count)
