"""
Data augmentation (transformations) operations used to generate 
synthetic training data for the `FactCC` and `FactCCX` models.
"""

import random

import spacy
import signal

# from google.cloud import translate
from sentence_transformers import SentenceTransformer, util
from LexRank import degree_centrality_scores
import numpy as np
import collections
import logging
import json
import unidecode
# Raise the threshold of the "spacy" logger to ERROR so that lower-severity
# messages logged under that name are not emitted.
logger = logging.getLogger("spacy")
logger.setLevel(logging.ERROR)


class TimeoutException(Exception):
    """Custom exception signalling that a time-limited step ran out of time."""


def timeout_handler(signum, frame):
    """Signal handler that turns a delivered signal into a ``TimeoutException``.

    Args:
        signum: Number of the received signal (ignored).
        frame: Stack frame that was interrupted (ignored).

    Raises:
        TimeoutException: Always.
    """
    raise TimeoutException()


# Install `timeout_handler` for SIGALRM at import time, so that an alarm
# scheduled with `signal.alarm(n)` raises `TimeoutException` when it fires.
# NOTE(review): `signal.SIGALRM` / `signal.alarm` are documented as Unix-only;
# confirm this module is never imported on Windows.
signal.signal(signal.SIGALRM, timeout_handler)

# Label strings for correct (True) and incorrect (False) examples.
LABEL_MAP = {True: "CORRECT", False: "INCORRECT"}


def align_ws(old_token, new_token):
    """Make ``new_token`` carry the same trailing space as ``old_token``.

    Only a single trailing space character is considered: one is appended
    when ``old_token`` ends with a space and ``new_token`` does not, and one
    is removed in the opposite case.

    Empty strings are handled (the previous ``token[-1]`` indexing raised
    ``IndexError`` for them); an empty token is treated as having no trailing
    space.

    Args:
        old_token: Text of the token being replaced.
        new_token: Replacement text.

    Returns:
        ``new_token`` with its trailing space aligned to ``old_token``.
    """
    # Normalise to "no trailing space", then re-add one if the old token had it.
    core = new_token[:-1] if new_token.endswith(" ") else new_token
    return core + " " if old_token.endswith(" ") else core

def update_example_with_negative(eid=None, text=None, positive_example=None, label=None, extraction_span=None,
                     backtranslation=None, augmentation=None, augmentation_span=None, noise=None):
    """Build the JSON-style record for an example that carries augmentation info.

    The positive example is nested as ``{"text": positive_example}``. Note
    that ``label`` is accepted but is not stored in the returned record.

    Returns:
        dict with the keys ``id``, ``text``, ``positive_example``,
        ``extraction_span``, ``backtranslation``, ``augmentation``,
        ``augmentation_span`` and ``noise``.
    """
    record = {"id": eid, "text": text}
    record["positive_example"] = {"text": positive_example}
    # Remaining fields are copied through unchanged, in the original key order.
    record.update(
        extraction_span=extraction_span,
        backtranslation=backtranslation,
        augmentation=augmentation,
        augmentation_span=augmentation_span,
        noise=noise,
    )
    return record


def make_new_example(eid=None, text=None, positive_example=None, sentences=None):
    """Build the JSON-style record for a freshly created example.

    Returns:
        dict with the keys ``id``, ``text``, ``positive_example`` and
        ``sentences``, holding the corresponding arguments unchanged.
    """
    field_names = ("id", "text", "positive_example", "sentences")
    field_values = (eid, text, positive_example, sentences)
    return dict(zip(field_names, field_values))


class Transformation:
    """Base class for all data transformations."""

    def __init__(self):
        # Every transformation owns a spaCy pipeline for its NLP sub-steps.
        nlp = spacy.load("en_core_web_sm")
        self.spacy = nlp

    def transform(self, example):
        """Apply the transformation to ``example``.

        The base implementation does nothing and returns ``None``; subclasses
        override it.
        """
        return None


# NOTE(review): this import sits mid-file rather than with the imports at the
# top; kept in place here.
from transformers import BertTokenizer

# Module-level tokenizer, created once at import time. The transformations
# below call it with `max_length=512, truncation=True` and decode the result
# to obtain the truncated form of a text.
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")


class FormatTransformation(Transformation):
    """Wrap an existing (text, summary) pair into an example record.

    No sampling is performed: the whole summary becomes the claim and the
    example is labelled as correct.
    """

    def __init__(self):
        super().__init__()

    def transform(self, example):
        """Build the example record for ``example``.

        Args:
            example: Mapping with the keys ``"id"``, ``"text"`` and
                ``"summary"``.

        Returns:
            dict produced by ``make_new_example`` (``text`` is the parsed
            document) extended with ``claim`` (parsed summary, newlines
            replaced by spaces), ``label``, ``backtranslation`` and ``noise``.
        """
        page_doc = self.spacy(example["text"], disable=["tagger"])
        # The summary used to be parsed twice, with the second result unused.
        claim = self.spacy(example["summary"].replace("\n", " "))

        # make_new_example() only accepts eid/text/positive_example/sentences;
        # passing claim/label/... as keywords raised TypeError on every call.
        # Build the base record first, then attach the extra fields.
        new_example = make_new_example(eid=example["id"], text=page_doc)
        new_example.update(
            claim=claim,
            label=LABEL_MAP[True],
            backtranslation=False,
            noise=False,
        )
        return new_example


class SampleSentences512(Transformation):
    """Sample claim sentences that survive truncation to 512 BERT tokens.

    A sentence is eligible when it has at least ``min_sent_len`` spaCy tokens
    and its tokenised-then-decoded form occurs in the tokenised-then-decoded
    form of the document truncated to 512 tokens.
    """

    def __init__(self, min_sent_len=8, num_claims=4):
        """
        Args:
            min_sent_len: Minimum number of spaCy tokens in a candidate sentence.
            num_claims: Number of claims sampled per document (default 4, the
                value that used to be hard-coded).
        """
        super().__init__()
        self.min_sent_len = min_sent_len
        self.num_claims = num_claims

    @staticmethod
    def _detokenize(text):
        # Round-trip through the BERT tokenizer, truncating to 512 tokens.
        tok = bert_tokenizer(text, max_length=512, truncation=True)
        return bert_tokenizer.decode(tok.input_ids, skip_special_tokens=True)

    def transform(self, example):
        """Create ``num_claims`` examples from randomly sampled sentences.

        Sampling is uniform with replacement over the eligible sentences, so
        the same sentence may be picked more than once (as before).

        Args:
            example: Mapping with the keys ``"id"`` and ``"text"``.

        Returns:
            List of records from ``make_new_example``; an empty list when the
            document has no eligible sentence. The previous rejection loop
            never terminated in that case (or raised ``IndexError`` when no
            sentence was long enough).
        """
        assert example["text"] is not None, "Text must be available"

        page_id = example["id"]
        page_text = example["text"].replace("\n", " ")
        page_text_detok = self._detokenize(page_text)

        page_doc = self.spacy(page_text, disable=["tagger"])
        eligible = [
            sent for sent in page_doc.sents
            if len(sent) >= self.min_sent_len
            and self._detokenize(sent.text) in page_text_detok
        ]
        if not eligible:
            return []

        claims = random.choices(eligible, k=self.num_claims)
        return [
            make_new_example(eid=page_id, text=page_doc.text, positive_example=claim.text)
            for claim in claims
        ]


class SummarySentences(Transformation):
    """Turn every sentence of the summary into its own positive example."""

    def __init__(self, min_sent_len=8):
        super().__init__()
        # Stored for interface parity; `transform` does not filter by length.
        self.min_sent_len = min_sent_len

    def transform(self, example):
        """Create one example per summary sentence.

        Args:
            example: Mapping with the keys ``"id"``, ``"text"`` and
                ``"summary"``; newlines in both texts are replaced by spaces.

        Returns:
            List of records from ``make_new_example``, one per sentence that
            spaCy finds in the summary.
        """
        assert example["text"] is not None, "Text must be available"

        page_id = example["id"]
        article = self.spacy(example["text"].replace("\n", " "), disable=["tagger"])
        summary = self.spacy(example["summary"].replace("\n", " "), disable=["tagger"])

        return [
            make_new_example(eid=page_id, text=article.text, positive_example=sentence.text)
            for sentence in summary.sents
        ]


class SampleSentences(Transformation):
    """Embed the document as a spaCy object and sample one sentence as claim."""

    def __init__(self, min_sent_len=8):
        super().__init__()
        self.min_sent_len = min_sent_len

    def transform(self, example):
        """Sample one sufficiently long sentence of the document as the claim.

        Args:
            example: Mapping with the keys ``"id"`` and ``"text"``.

        Returns:
            dict produced by ``make_new_example`` (``text`` is the parsed
            document) extended with ``claim``, ``label``, ``extraction_span``
            (first and last token index of the sentence), ``backtranslation``
            and ``noise``; ``None`` when no sentence has at least
            ``min_sent_len`` tokens (``random.choice`` used to raise
            ``IndexError`` there).
        """
        assert example["text"] is not None, "Text must be available"

        # split into sentences
        page_id = example["id"]
        page_text = example["text"].replace("\n", " ")
        page_doc = self.spacy(page_text, disable=["tagger"])
        sents = [sent for sent in page_doc.sents if len(sent) >= self.min_sent_len]
        if not sents:
            return None

        # sample claim
        claim = random.choice(sents)

        # make_new_example() only accepts eid/text/positive_example/sentences;
        # passing claim/label/... as keywords raised TypeError on every call.
        new_example = make_new_example(eid=page_id, text=page_doc)
        new_example.update(
            claim=self.spacy(claim.text),
            label=LABEL_MAP[True],
            extraction_span=(claim.start, claim.end - 1),
            backtranslation=False,
            noise=False,
        )
        return new_example


class SentenceLexRank(Transformation):
    """Pick the most central document sentence as the claim.

    Centrality is computed by ``degree_centrality_scores`` over the pairwise
    cosine similarities of the sentence embeddings. If the computation does
    not finish within 5 seconds, a random sentence is used instead.
    """

    def __init__(self, min_sent_len=8):
        super().__init__()
        self.min_sent_len = min_sent_len
        # Number of documents for which the random fallback was used.
        self.count_at_random = 0
        self.model_sentence_transf = SentenceTransformer('paraphrase-mpnet-base-v2')

    def transform(self, example):
        """Select a claim sentence for ``example``.

        Args:
            example: Mapping with the keys ``"id"`` and ``"text"``.

        Returns:
            dict produced by ``make_new_example`` (``text`` is the parsed
            document) extended with ``claim``, ``label``, ``extraction_span``,
            ``backtranslation`` and ``noise``; ``None`` when no sentence has
            at least ``min_sent_len`` tokens.
        """
        assert example["text"] is not None, "Text must be available"

        # split into sentences
        page_id = example["id"]
        page_text = example["text"].replace("\n", " ")
        page_doc = self.spacy(page_text, disable=["tagger"])
        sents = [sent for sent in page_doc.sents if len(sent) >= self.min_sent_len]
        if not sents:
            # Neither the ranking nor the random fallback can pick from nothing.
            return None

        sents_text = [sent.text for sent in sents]

        signal.alarm(5)
        try:
            embs = self.model_sentence_transf.encode(sents_text, convert_to_tensor=True)

            # Compute the pair-wise cosine similarities
            cos_scores = util.pytorch_cos_sim(embs, embs).cpu().detach().numpy()

            # Compute the centrality for each sentence
            centrality_scores = degree_centrality_scores(cos_scores, threshold=None)

            # argsort on the negated scores puts the highest score first
            claim = sents[np.argsort(-centrality_scores)[0]]
        except TimeoutException:
            claim = random.choice(sents)
            self.count_at_random += 1
            print("choose at random", self.count_at_random)
        finally:
            # Always cancel the alarm: previously an unrelated exception left
            # it armed, so it could fire later in unrelated code.
            signal.alarm(0)

        # make_new_example() only accepts eid/text/positive_example/sentences;
        # passing claim/label/... as keywords raised TypeError on every call.
        new_example = make_new_example(eid=page_id, text=page_doc)
        new_example.update(
            claim=self.spacy(claim.text),
            label=LABEL_MAP[True],
            extraction_span=(claim.start, claim.end - 1),
            backtranslation=False,
            noise=False,
        )
        return new_example


class SentenceLexRank512DevSum(Transformation):
    """Attach the five document sentences most similar to the claim.

    Sentences are scored by the cosine similarity between their embedding and
    the embedding of ``example["claim"]``.
    """

    def __init__(self, min_sent_len=8):
        super().__init__()
        self.min_sent_len = min_sent_len
        # Number of examples dropped (timeout or no usable sentence).
        self.count_at_random = 0
        self.model_sentence_transf = SentenceTransformer('paraphrase-mpnet-base-v2')

    def transform(self, example):
        """Add a ``"sentences"`` entry to ``example``.

        Args:
            example: Mapping with the keys ``"text"`` and ``"claim"``; it is
                modified in place.

        Returns:
            ``example`` with ``example['sentences']`` set to the texts of the
            (up to) five best-scoring sentences in document order, or ``None``
            when the scoring exceeds 10 seconds or no sentence has at least
            ``min_sent_len`` tokens.
        """
        assert example["text"] is not None, "Text must be available"

        # split into sentences
        page_text = example["text"].replace("\n", " ")
        summary = example["claim"].replace("\n", " ")

        page_doc = self.spacy(page_text, disable=["tagger"])
        sents = [sent for sent in page_doc.sents if len(sent) >= self.min_sent_len]
        if not sents:
            # Nothing to rank; bail out instead of failing inside the encoder.
            self.count_at_random += 1
            print("return none, no sents", self.count_at_random)
            return None

        sents_text = [sent.text for sent in sents]

        signal.alarm(10)
        try:
            embs = self.model_sentence_transf.encode(sents_text, convert_to_tensor=True)
            embs_sum = self.model_sentence_transf.encode([summary], convert_to_tensor=True)

            # Similarity of every document sentence to the claim
            cos_scores = util.pytorch_cos_sim(embs, embs_sum).cpu().detach().numpy()
            centrality_scores = np.mean(cos_scores, axis=1)

            # argsort on the negated scores puts the highest score first
            most_central_sentence_indices = np.argsort(-centrality_scores)

            # Keep the five best sentences, emitted in document order.
            best_sents_list = [sents[idx].text for idx in sorted(most_central_sentence_indices[:5])]
        except TimeoutException:
            self.count_at_random += 1
            print("return none error", self.count_at_random)
            return None
        finally:
            # Always cancel the alarm: previously an unrelated exception left
            # it armed, so it could fire later in unrelated code.
            signal.alarm(0)

        example['sentences'] = best_sents_list

        return example


class SentenceLexRank512Dev(Transformation):
    """Attach the five most central document sentences to the example.

    Centrality is computed by ``degree_centrality_scores`` over the pairwise
    cosine similarities of the sentence embeddings.
    """

    def __init__(self, min_sent_len=8):
        super().__init__()
        self.min_sent_len = min_sent_len
        # Number of examples dropped (timeout or no usable sentence).
        self.count_at_random = 0
        self.model_sentence_transf = SentenceTransformer('paraphrase-mpnet-base-v2')

    def transform(self, example):
        """Add a ``"sentences"`` entry to ``example``.

        Args:
            example: Mapping with the key ``"text"``; it is modified in place.

        Returns:
            ``example`` with ``example['sentences']`` set to the texts of the
            (up to) five most central sentences in document order, or ``None``
            when the scoring exceeds 10 seconds or no sentence has at least
            ``min_sent_len`` tokens.
        """
        assert example["text"] is not None, "Text must be available"

        # split into sentences
        page_text = example["text"].replace("\n", " ")

        page_doc = self.spacy(page_text, disable=["tagger"])
        sents = [sent for sent in page_doc.sents if len(sent) >= self.min_sent_len]
        if not sents:
            # Nothing to rank; bail out instead of failing on an empty ranking.
            self.count_at_random += 1
            print("return none, no sents", self.count_at_random)
            return None

        sents_text = [sent.text for sent in sents]

        signal.alarm(10)
        try:
            embs = self.model_sentence_transf.encode(sents_text, convert_to_tensor=True)

            # Compute the pair-wise cosine similarities
            cos_scores = util.pytorch_cos_sim(embs, embs).cpu().detach().numpy()

            # Compute the centrality for each sentence
            centrality_scores = degree_centrality_scores(cos_scores, threshold=None)

            # argsort on the negated scores puts the highest score first
            most_central_sentence_indices = np.argsort(-centrality_scores)

            # Keep the five best sentences, emitted in document order.
            best_sents_list = [sents[idx].text for idx in sorted(most_central_sentence_indices[:5])]
        except TimeoutException:
            self.count_at_random += 1
            print("return none error", self.count_at_random)
            return None
        finally:
            # Always cancel the alarm: previously an unrelated exception left
            # it armed, so it could fire later in unrelated code.
            signal.alarm(0)

        example['sentences'] = best_sents_list

        return example


class SentenceLexRankSum512_2(Transformation):
    """Use every summary sentence as a claim, with five supporting sentences.

    Document sentences are scored by their mean cosine similarity to the
    summary sentences; the five best are attached to every generated example.
    """

    def __init__(self, min_sent_len=8):
        super().__init__()
        self.min_sent_len = min_sent_len
        # Number of documents dropped (timeout or nothing usable).
        self.count_at_random = 0
        self.model_sentence_transf = SentenceTransformer('paraphrase-mpnet-base-v2')

    def transform(self, example):
        """Create one example per summary sentence.

        Args:
            example: Mapping with the keys ``"id"``, ``"text"`` and
                ``"summary"``.

        Returns:
            List of records from ``make_new_example`` (one per summary
            sentence, each carrying the same ``sentences`` list in document
            order), or ``None`` when the summary has no sentence, no document
            sentence has at least ``min_sent_len`` tokens, or the scoring
            exceeds 10 seconds.
        """
        assert example["text"] is not None, "Text must be available"

        # split into sentences
        page_id = example["id"]
        page_text = example["text"].replace("\n", " ")
        summaries = example["summary"].replace("\n", " ")

        # The 512-token truncated text was computed here but never used; dropped.
        page_doc = self.spacy(page_text, disable=["tagger"])
        sents = [sent for sent in page_doc.sents if len(sent) >= self.min_sent_len]

        page_sum = self.spacy(summaries, disable=["tagger"])
        claims = list(page_sum.sents)

        if not claims:
            self.count_at_random += 1
            print("return none, no textrank claim in 512", self.count_at_random)
            return None
        if not sents:
            # Nothing to rank; bail out instead of failing inside the encoder.
            self.count_at_random += 1
            print("return none, no sents", self.count_at_random)
            return None

        sents_text = [sent.text for sent in sents]
        sents_text_sum = [sent.text for sent in claims]

        signal.alarm(10)
        try:
            embs = self.model_sentence_transf.encode(sents_text, convert_to_tensor=True)
            embs_sum = self.model_sentence_transf.encode(sents_text_sum, convert_to_tensor=True)

            # Similarity of every document sentence to every summary sentence
            cos_scores = util.pytorch_cos_sim(embs, embs_sum).cpu().detach().numpy()
            centrality_scores = np.mean(cos_scores, axis=1)

            # argsort on the negated scores puts the highest score first
            most_central_sentence_indices = np.argsort(-centrality_scores)

            # Keep the five best sentences, emitted in document order.
            best_sents_list = [sents[idx].text for idx in sorted(most_central_sentence_indices[:5])]
        except TimeoutException:
            self.count_at_random += 1
            print("return none error", self.count_at_random)
            return None
        finally:
            # Always cancel the alarm: previously an unrelated exception left
            # it armed, so it could fire later in unrelated code.
            signal.alarm(0)

        return [
            make_new_example(eid=page_id, text=page_doc.text,
                             positive_example=claim.text, sentences=best_sents_list)
            for claim in claims
        ]


class SelectSentencesContrastive(Transformation):
    """Attach the five document sentences closest to the positive example.

    Document sentences are scored by their mean cosine similarity to the
    sentences of ``example["positive_example"]``.
    """

    def __init__(self, min_sent_len=8):
        super().__init__()
        self.min_sent_len = min_sent_len
        # Number of examples dropped (timeout or nothing usable).
        self.count_at_random = 0
        self.model_sentence_transf = SentenceTransformer('paraphrase-mpnet-base-v2')

    def transform(self, example):
        """Add a ``"sentences"`` entry to ``example``.

        Args:
            example: Mapping with the keys ``"text"`` and
                ``"positive_example"``; it is modified in place.

        Returns:
            ``example`` with ``example['sentences']`` set to the texts of the
            (up to) five best-scoring sentences in document order, or ``None``
            when there is nothing to compare or the scoring exceeds 10 seconds.
        """
        assert example["text"] is not None, "Text must be available"

        page_text = example["text"]
        summaries = example["positive_example"]

        page_doc = self.spacy(page_text, disable=["tagger"])
        sents = [sent for sent in page_doc.sents if len(sent) >= self.min_sent_len]

        page_sum = self.spacy(summaries, disable=["tagger"])
        sents_sum = list(page_sum.sents)

        if not sents or not sents_sum:
            # The "no sents" outcome is detected before encoding: the old check
            # came after code that already fails on empty input.
            self.count_at_random += 1
            print("return none, no sents", self.count_at_random)
            return None

        sents_text = [sent.text for sent in sents]
        sents_text_sum = [sent.text for sent in sents_sum]

        signal.alarm(10)
        try:
            embs = self.model_sentence_transf.encode(sents_text, convert_to_tensor=True)
            embs_sum = self.model_sentence_transf.encode(sents_text_sum, convert_to_tensor=True)

            # Similarity of every document sentence to every reference sentence
            cos_scores = util.pytorch_cos_sim(embs, embs_sum).cpu().detach().numpy()
            centrality_scores = np.mean(cos_scores, axis=1)

            # argsort on the negated scores puts the highest score first
            most_central_sentence_indices = np.argsort(-centrality_scores)

            # Keep the five best sentences, emitted in document order.
            best_sents_list = [sents[idx].text for idx in sorted(most_central_sentence_indices[:5])]
        except TimeoutException:
            self.count_at_random += 1
            print("return none error", self.count_at_random)
            return None
        finally:
            # Always cancel the alarm: previously an unrelated exception left
            # it armed, so it could fire later in unrelated code.
            signal.alarm(0)

        example['sentences'] = best_sents_list
        return example


class SelectSentencesScore(Transformation):
    """Rank all article sentences by similarity to the summary, with scores.

    Unlike the sibling transformations, every article sentence is kept (no
    minimum length, no top-5 cut) and the ranking is stored as a JSON string.
    """

    def __init__(self, min_sent_len=8):
        super().__init__()
        # Stored for interface parity; `transform` does not filter by length.
        self.min_sent_len = min_sent_len
        # Number of examples dropped (timeout or nothing usable).
        self.count_at_random = 0
        self.model_sentence_transf = SentenceTransformer('paraphrase-mpnet-base-v2')

    def transform(self, example):
        """Add a JSON-encoded sentence ranking to ``example``.

        Args:
            example: Mapping with the keys ``"article"`` and ``"summary"``.
                It is modified in place: both texts are replaced by their
                ``unidecode`` transliteration.

        Returns:
            ``example`` with ``example['sentences']`` set to a JSON list of
            ``[sentence_text, sentence_index, score]`` entries ordered from
            highest to lowest score, or ``None`` when there is nothing to
            compare or the scoring exceeds 15 seconds.
        """
        assert example["article"] is not None, "Text must be available"

        example["article"] = unidecode.unidecode(example["article"])
        example["summary"] = unidecode.unidecode(example["summary"])

        page_text = example["article"]
        summaries = example["summary"]

        page_doc = self.spacy(page_text, disable=["tagger"])
        sents = list(page_doc.sents)

        page_sum = self.spacy(summaries, disable=["tagger"])
        sents_sum = list(page_sum.sents)

        if not sents or not sents_sum:
            # The "no sents" outcome is detected before encoding: the old check
            # came after code that already fails on empty input.
            self.count_at_random += 1
            print("return none, no sents", self.count_at_random)
            return None

        sents_text = [sent.text for sent in sents]
        sents_text_sum = [sent.text for sent in sents_sum]

        signal.alarm(15)
        try:
            embs = self.model_sentence_transf.encode(sents_text, convert_to_tensor=True)
            embs_sum = self.model_sentence_transf.encode(sents_text_sum, convert_to_tensor=True)

            # Similarity of every article sentence to every summary sentence
            cos_scores = util.pytorch_cos_sim(embs, embs_sum).cpu().detach().numpy()
            centrality_scores = np.mean(cos_scores, axis=1)

            # argsort on the negated scores puts the highest score first
            most_central_sentence_indices = np.argsort(-centrality_scores)

            # Cast numpy scalars to plain int/float so json.dumps accepts them.
            best_sents = [
                (sents[idx].text, int(idx), float(centrality_scores[idx]))
                for idx in most_central_sentence_indices
            ]
        except TimeoutException:
            self.count_at_random += 1
            print("return none error", self.count_at_random)
            return None
        finally:
            # Always cancel the alarm: previously an unrelated exception left
            # it armed, so it could fire later in unrelated code.
            signal.alarm(0)

        example['sentences'] = json.dumps(best_sents)
        return example


class SelectSentences(Transformation):
    """Attach the five document sentences closest to the claim.

    Document sentences are scored by their mean cosine similarity to the
    sentences of ``example["claim"]``.
    """

    def __init__(self, min_sent_len=8):
        super().__init__()
        self.min_sent_len = min_sent_len
        # Number of examples dropped (timeout or nothing usable).
        self.count_at_random = 0
        self.model_sentence_transf = SentenceTransformer('paraphrase-mpnet-base-v2')

    def transform(self, example):
        """Add a ``"sentences"`` entry to ``example``.

        Args:
            example: Mapping with the keys ``"text"`` and ``"claim"``; it is
                modified in place.

        Returns:
            ``example`` with ``example['sentences']`` set to the texts of the
            (up to) five best-scoring sentences in document order, or ``None``
            when there is nothing to compare or the scoring exceeds 10 seconds.
        """
        assert example["text"] is not None, "Text must be available"

        page_text = example["text"]
        summaries = example["claim"]

        page_doc = self.spacy(page_text, disable=["tagger"])
        sents = [sent for sent in page_doc.sents if len(sent) >= self.min_sent_len]

        page_sum = self.spacy(summaries, disable=["tagger"])
        sents_sum = list(page_sum.sents)

        if not sents or not sents_sum:
            # The "no sents" outcome is detected before encoding: the old check
            # came after code that already fails on empty input.
            self.count_at_random += 1
            print("return none, no sents", self.count_at_random)
            return None

        sents_text = [sent.text for sent in sents]
        sents_text_sum = [sent.text for sent in sents_sum]

        signal.alarm(10)
        try:
            embs = self.model_sentence_transf.encode(sents_text, convert_to_tensor=True)
            embs_sum = self.model_sentence_transf.encode(sents_text_sum, convert_to_tensor=True)

            # Similarity of every document sentence to every claim sentence
            cos_scores = util.pytorch_cos_sim(embs, embs_sum).cpu().detach().numpy()
            centrality_scores = np.mean(cos_scores, axis=1)

            # argsort on the negated scores puts the highest score first
            most_central_sentence_indices = np.argsort(-centrality_scores)

            # Keep the five best sentences, emitted in document order.
            best_sents_list = [sents[idx].text for idx in sorted(most_central_sentence_indices[:5])]
        except TimeoutException:
            self.count_at_random += 1
            print("return none error", self.count_at_random)
            return None
        finally:
            # Always cancel the alarm: previously an unrelated exception left
            # it armed, so it could fire later in unrelated code.
            signal.alarm(0)

        example['sentences'] = best_sents_list
        return example


class SentenceLexRankSum512(Transformation):
    """Pick summary-like document sentences that survive 512-token truncation.

    Document sentences are scored by their mean cosine similarity to the
    summary sentences. The five best are attached as ``sentences``; up to four
    of those five whose tokenised-then-decoded form occurs in the document
    truncated to 512 BERT tokens become claims.
    """

    def __init__(self, min_sent_len=8):
        super().__init__()
        self.min_sent_len = min_sent_len
        # Number of documents dropped (timeout or nothing usable).
        self.count_at_random = 0
        self.model_sentence_transf = SentenceTransformer('paraphrase-mpnet-base-v2')

    @staticmethod
    def _detokenize(text):
        # Round-trip through the BERT tokenizer, truncating to 512 tokens.
        tok = bert_tokenizer(text, max_length=512, truncation=True)
        return bert_tokenizer.decode(tok.input_ids, skip_special_tokens=True)

    def transform(self, example):
        """Create up to four examples from the best-scoring sentences.

        Args:
            example: Mapping with the keys ``"id"``, ``"text"`` and
                ``"summary"``.

        Returns:
            List of records from ``make_new_example`` (one per claim, each
            carrying the same ``sentences`` list in document order), or
            ``None`` when there is nothing to compare, the scoring exceeds
            10 seconds, or none of the five best sentences lies inside the
            truncated document.
        """
        assert example["text"] is not None, "Text must be available"

        # split into sentences
        page_id = example["id"]
        page_text = example["text"].replace("\n", " ")
        summaries = example["summary"].replace("\n", " ")

        page_text_detok = self._detokenize(page_text)

        page_doc = self.spacy(page_text, disable=["tagger"])
        sents = [sent for sent in page_doc.sents if len(sent) >= self.min_sent_len]

        page_sum = self.spacy(summaries, disable=["tagger"])
        sents_sum = list(page_sum.sents)

        if not sents or not sents_sum:
            # Nothing to rank; bail out instead of failing inside the encoder.
            self.count_at_random += 1
            print("return none, no sents", self.count_at_random)
            return None

        sents_text = [sent.text for sent in sents]
        sents_text_sum = [sent.text for sent in sents_sum]

        signal.alarm(10)
        try:
            embs = self.model_sentence_transf.encode(sents_text, convert_to_tensor=True)
            embs_sum = self.model_sentence_transf.encode(sents_text_sum, convert_to_tensor=True)

            # Similarity of every document sentence to every summary sentence
            cos_scores = util.pytorch_cos_sim(embs, embs_sum).cpu().detach().numpy()
            centrality_scores = np.mean(cos_scores, axis=1)

            # argsort on the negated scores puts the highest score first
            top_indices = np.argsort(-centrality_scores)[:5]

            # Claims: the first four (by score) of the top five that are
            # still present after truncation.
            claims = [
                sents[idx] for idx in top_indices
                if self._detokenize(sents[idx].text) in page_text_detok
            ][:4]

            # Supporting sentences: the top five, emitted in document order.
            best_sents_list = [sents[idx].text for idx in sorted(top_indices)]
        except TimeoutException:
            self.count_at_random += 1
            print("return none error", self.count_at_random)
            return None
        finally:
            # Always cancel the alarm: previously an unrelated exception left
            # it armed, so it could fire later in unrelated code.
            signal.alarm(0)

        if not claims:
            self.count_at_random += 1
            print("return none, no textrank claim in 512", self.count_at_random)
            return None

        return [
            make_new_example(eid=page_id, text=page_doc.text,
                             positive_example=claim.text, sentences=best_sents_list)
            for claim in claims
        ]


class SentenceLexRankSum512_3(Transformation):
    """Pick summary-like document sentences that survive 512-token truncation.

    Document sentences are scored by their mean cosine similarity to the
    summary sentences. The five best are attached as ``sentences``; up to four
    of those five whose tokenised-then-decoded form occurs in the document
    truncated to 512 BERT tokens become claims.
    """

    def __init__(self, min_sent_len=8):
        super().__init__()
        self.min_sent_len = min_sent_len
        # Number of documents dropped (timeout or nothing usable).
        self.count_at_random = 0
        self.model_sentence_transf = SentenceTransformer('paraphrase-mpnet-base-v2')

    @staticmethod
    def _detokenize(text):
        # Round-trip through the BERT tokenizer, truncating to 512 tokens.
        tok = bert_tokenizer(text, max_length=512, truncation=True)
        return bert_tokenizer.decode(tok.input_ids, skip_special_tokens=True)

    def transform(self, example):
        """Create up to four examples from the best-scoring sentences.

        Args:
            example: Mapping with the keys ``"id"``, ``"text"`` and
                ``"summary"``.

        Returns:
            List of records from ``make_new_example`` (one per claim, each
            carrying the same ``sentences`` list in document order), or
            ``None`` when there is nothing to compare, the scoring exceeds
            10 seconds, or none of the five best sentences lies inside the
            truncated document.
        """
        assert example["text"] is not None, "Text must be available"

        # split into sentences
        page_id = example["id"]
        page_text = example["text"].replace("\n", " ")
        summaries = example["summary"].replace("\n", " ")

        page_text_detok = self._detokenize(page_text)

        page_doc = self.spacy(page_text, disable=["tagger"])
        sents = [sent for sent in page_doc.sents if len(sent) >= self.min_sent_len]

        page_sum = self.spacy(summaries, disable=["tagger"])
        sents_sum = list(page_sum.sents)

        if not sents or not sents_sum:
            # Nothing to rank; bail out instead of failing inside the encoder.
            self.count_at_random += 1
            print("return none, no sents", self.count_at_random)
            return None

        sents_text = [sent.text for sent in sents]
        sents_text_sum = [sent.text for sent in sents_sum]

        signal.alarm(10)
        try:
            embs = self.model_sentence_transf.encode(sents_text, convert_to_tensor=True)
            embs_sum = self.model_sentence_transf.encode(sents_text_sum, convert_to_tensor=True)

            # Similarity of every document sentence to every summary sentence
            cos_scores = util.pytorch_cos_sim(embs, embs_sum).cpu().detach().numpy()
            centrality_scores = np.mean(cos_scores, axis=1)

            # argsort on the negated scores puts the highest score first
            top_indices = np.argsort(-centrality_scores)[:5]

            # Claims: the first four (by score) of the top five that are
            # still present after truncation.
            claims = [
                sents[idx] for idx in top_indices
                if self._detokenize(sents[idx].text) in page_text_detok
            ][:4]

            # Supporting sentences: the top five, emitted in document order.
            best_sents_list = [sents[idx].text for idx in sorted(top_indices)]
        except TimeoutException:
            self.count_at_random += 1
            print("return none error", self.count_at_random)
            return None
        finally:
            # Always cancel the alarm: previously an unrelated exception left
            # it armed, so it could fire later in unrelated code.
            signal.alarm(0)

        if not claims:
            self.count_at_random += 1
            print("return none, no textrank claim in 512", self.count_at_random)
            return None

        return [
            make_new_example(eid=page_id, text=page_doc.text,
                             positive_example=claim.text, sentences=best_sents_list)
            for claim in claims
        ]



class SentenceLexRank512(Transformation):
    """Pick claim sentences from a document by LexRank centrality.

    The document is split into sentences with spaCy, the sentences are embedded
    with a sentence-transformer model and ranked by degree centrality over
    their pair-wise cosine similarities. Up to four of the ten most central
    sentences that also appear in the tokenizer-truncated (max_length=512)
    version of the document are emitted as positive claims.
    """

    def __init__(self, min_sent_len=8):
        super().__init__()
        # Sentences with fewer spaCy tokens than this are ignored.
        self.min_sent_len = min_sent_len
        # Number of documents for which no example could be produced.
        self.count_at_random = 0
        self.model_sentence_transf = SentenceTransformer('paraphrase-mpnet-base-v2')

    def transform(self, example):
        """Build one new example per selected claim sentence.

        Args:
            example: dict with at least the keys "id" and "text".

        Returns:
            A list of dicts built by `make_new_example` (one per claim), each
            carrying the selected central sentences in document order under
            "sentences"; or None when the ranking times out or no central
            sentence is found in the truncated text.
        """
        assert example["text"] is not None, "Text must be available"

        page_id = example["id"]
        page_text = example["text"].replace("\n", " ")

        # Round-trip through the tokenizer to get the text as it looks after
        # truncation; claims must be found inside this truncated text.
        tok = bert_tokenizer(page_text, return_length=True, max_length=512, truncation=True)
        page_text_detok = bert_tokenizer.decode(tok.input_ids, skip_special_tokens=True)

        # split into sentences
        page_doc = self.spacy(page_text, disable=["tagger"])
        sents = [sent for sent in page_doc.sents if len(sent) >= self.min_sent_len]
        sents_text = [sent.text for sent in sents]

        # Bound the ranking step to 10 seconds; the module-level SIGALRM
        # handler raises TimeoutException when the alarm fires.
        signal.alarm(10)
        try:
            embs = self.model_sentence_transf.encode(sents_text, convert_to_tensor=True)

            # Compute the pair-wise cosine similarities
            cos_scores = util.pytorch_cos_sim(embs, embs).cpu().detach().numpy()

            # Compute the centrality for each sentence
            centrality_scores = degree_centrality_scores(cos_scores, threshold=None)

            # We argsort so that the first element is the sentence with the highest score
            most_central_sentence_indices = np.argsort(-centrality_scores)

            claims = []
            best_sents = {}
            for idx in most_central_sentence_indices:
                candidate = sents[idx]
                best_sents[idx] = candidate.text

                tok = bert_tokenizer(candidate.text, return_length=True, max_length=512, truncation=True)
                claim_text = bert_tokenizer.decode(tok.input_ids, skip_special_tokens=True)

                # Only sentences visible in the truncated text can be claims,
                # and at most four claims are kept.
                if claim_text in page_text_detok and len(claims) < 4:
                    claims.append(candidate)

                if len(best_sents) == 10:
                    break

            # Restore document order for the selected sentences.
            best_sents_list = [best_sents[idx] for idx in sorted(best_sents)]
        except TimeoutException:
            self.count_at_random += 1
            print("return none error", self.count_at_random)
            return None
        finally:
            # Always cancel the pending alarm, including when an unrelated
            # exception propagates; otherwise SIGALRM would fire later and
            # raise TimeoutException in whatever code happens to be running.
            signal.alarm(0)

        if not claims:
            self.count_at_random += 1
            print("return none, no textrank claim in 512", self.count_at_random)
            return None

        return [
            make_new_example(eid=page_id, text=page_doc.text,
                             positive_example=claim.text, sentences=best_sents_list)
            for claim in claims
        ]


class NegateSentences(Transformation):
    """Create a negative example by adding or removing a negation in the claim."""

    def __init__(self):
        super().__init__()
        # Tokens that can take (or lose) a following "not" / "n't".
        self.__negatable_tokens = ("are", "is", "was", "were", "have", "has", "had",
                                   "do", "does", "did", "can", "ca", "could", "may",
                                   "might", "must", "shall", "should", "will", "would")

    def transform(self, example):
        """Append a negated variant of the positive claim to the example.

        Args:
            example: dict whose "positive_example" entry is the claim text.

        Returns:
            The same dict with a new entry appended to its "negative_examples"
            list, or None when the claim could not be negated.
        """
        new_claim, aug_span = self.__negate_sentences(self.spacy(example["positive_example"]))

        if not new_claim:
            return None

        new_neg = {"text": new_claim.text, "augmentation": self.__class__.__name__, "augmentation_span": aug_span}
        example.setdefault('negative_examples', []).append(new_neg)
        return example

    def __negate_sentences(self, claim):
        """Flip the polarity of one randomly chosen negatable token.

        Args:
            claim: spaCy Doc of the claim.

        Returns:
            (new_claim, augmentation_span) where new_claim is a spaCy Doc and
            augmentation_span is a (start, end) pair of token indices, or
            (None, None) when nothing could be changed.
        """
        # find negatable token, return None if no candidates found
        candidate_tokens = [token for token in claim if token.text in self.__negatable_tokens]
        if not candidate_tokens:
            return None, None

        # choose random token to negate
        negated_token = random.choice(candidate_tokens)
        negated_ix = negated_token.i

        # if the preceding token is negatable as well, negate that one instead
        if negated_ix > 0 and claim[negated_ix - 1].text in self.__negatable_tokens:
            negated_ix -= 1
            negated_token = claim[negated_ix]

        # check whether token is already negated
        is_negative = False
        if negated_ix < len(claim) - 1:
            following = claim[negated_ix + 1].text
            if following in ("not", "n't"):
                is_negative = True
            elif following == "no":
                return None, None

        claim_tokens = [token.text_with_ws for token in claim]
        verb = negated_token.text
        if is_negative:
            # Remove the negation. The contracted form "ca" (from "ca n't")
            # has to be expanded back to "can"; the check must look at the
            # verb itself, not at the negation token.
            negation_token = claim[negated_ix + 1]
            if negation_token.text == "n't" and verb.lower() == "ca":
                verb = "can" if verb.islower() else "Can"
            # The verb takes over the whitespace that followed the negation, so
            # "is n't." -> "is." and "can not go" (from "cannot go") -> "can go".
            claim_tokens[negated_ix] = verb + negation_token.whitespace_
            claim_tokens.pop(negated_ix + 1)
        else:
            # These verbs are only negated with a full "not".
            if verb.lower() in ("am", "may", "might", "must", "shall", "will"):
                negation = "not"
            else:
                negation = random.choice(["not", "n't"])

            if negation == "n't":
                if verb.lower() == "can":
                    verb = "ca" if verb.islower() else "Ca"
                # contraction attaches directly to the verb
                claim_tokens[negated_ix] = verb
            else:
                claim_tokens[negated_ix] = verb + " "
            # The negation inherits the verb's original trailing whitespace, so
            # a verb followed directly by punctuation is not mangled.
            claim_tokens.insert(negated_ix + 1, negation + negated_token.whitespace_)

        # create new claim object
        new_claim = self.spacy("".join(claim_tokens))
        augmentation_span = (negated_ix, negated_ix if is_negative else negated_ix + 1)

        if new_claim.text == claim.text:
            return None, None
        return new_claim, augmentation_span


class PronounSwap(Transformation):
    """Create a negative example by swapping one pronoun for another of the same class."""

    def __init__(self, prob_swap=0.5):
        # NOTE(review): prob_swap is accepted but not used anywhere in this class.
        super().__init__()

        # "our" / "ourselves" were previously misspelled as "out" / "outselves",
        # which made the common word "out" a swap target and left the real
        # pronouns unrecognised.
        # NOTE(review): "your" is listed twice under POSSESSIVE, which doubles
        # its sampling weight, and "herself" is absent from REFLEXIVE --
        # confirm whether both are intended.
        self.class2pronoun_map = {
            "SUBJECT": ["you", "he", "she", "we", "they"],
            "OBJECT": ["me", "you", "him", "her", "us", "them"],
            "POSSESSIVE": ["my", "your", "his", "her", "its", "our", "your", "their"],
            "REFLEXIVE": ["myself", "yourself", "himself", "itself", "ourselves", "yourselves", "themselves"]
        }

        # Reverse lookup. A pronoun listed in several classes ("you", "her")
        # maps to the class that appears last in class2pronoun_map.
        self.pronoun2class_map = {pronoun: key for (key, values) in self.class2pronoun_map.items() for pronoun in
                                  values}
        self.pronouns = set(self.pronoun2class_map)

    def transform(self, example):
        """Append a pronoun-swapped variant of the positive claim to the example.

        Args:
            example: dict whose "positive_example" entry is the claim text.

        Returns:
            The same dict with a new entry appended to its "negative_examples"
            list, or None when no pronoun could be swapped.
        """
        new_claim, aug_span = self.__swap_pronouns(self.spacy(example["positive_example"]))

        if not new_claim:
            return None

        new_neg = {"text": new_claim.text, "augmentation": self.__class__.__name__, "augmentation_span": aug_span}
        example.setdefault('negative_examples', []).append(new_neg)
        return example

    def __swap_pronouns(self, claim):
        """Replace one randomly chosen pronoun in the claim.

        Args:
            claim: spaCy Doc of the claim.

        Returns:
            (new_claim, augmentation_span) where new_claim is a spaCy Doc and
            augmentation_span is the (index, index) of the replaced token, or
            (None, None) when nothing could be changed.
        """
        # find pronouns
        claim_pronouns = [token for token in claim if token.text.lower() in self.pronouns]
        if not claim_pronouns:
            return None, None

        # find pronoun replacement
        chosen_token = random.choice(claim_pronouns)
        chosen_ix = chosen_token.i
        chosen_lower = chosen_token.text.lower()
        chosen_class = self.pronoun2class_map[chosen_lower]

        candidate_tokens = [token for token in self.class2pronoun_map[chosen_class] if token != chosen_lower]
        if not candidate_tokens:
            return None, None

        # swap pronoun, keeping the original trailing whitespace and casing
        swapped_token = align_ws(chosen_token.text_with_ws, random.choice(candidate_tokens))
        if not chosen_token.text.islower():
            swapped_token = swapped_token.capitalize()

        claim_tokens = [token.text_with_ws for token in claim]
        claim_tokens[chosen_ix] = swapped_token

        # create new claim object
        new_claim = self.spacy("".join(claim_tokens))
        augmentation_span = (chosen_ix, chosen_ix)

        if claim.text == new_claim.text:
            return None, None
        return new_claim, augmentation_span


class NERSwap(Transformation):
    """Parent class for transformations that swap named entities.

    An entity in the claim is replaced by a different entity of one of the
    configured categories taken from the source text.
    """

    def __init__(self):
        super().__init__()
        # Entity labels eligible for swapping; empty here, set by subclasses.
        self.categories = ()

    def transform(self, example):
        """Append an entity-swapped variant of the positive claim to the example.

        Args:
            example: dict with the source document under "text" and the claim
                under "positive_example".

        Returns:
            The same dict with a new entry appended to its "negative_examples"
            list, or None when no swap was possible.
        """
        source_doc = self.spacy(example["text"])
        claim_doc = self.spacy(example["positive_example"])
        new_claim, aug_span = self.__swap_entities(source_doc, claim_doc)

        if not new_claim:
            return None

        negative = {"text": new_claim.text, "augmentation": self.__class__.__name__, "augmentation_span": aug_span}
        example.setdefault('negative_examples', []).append(negative)
        return example

    def __swap_entities(self, text, claim):
        """Replace one claim entity with an entity found in the source text.

        Args:
            text: spaCy Doc of the source document.
            claim: spaCy Doc of the claim.

        Returns:
            (new_claim, augmentation_span) where new_claim is a spaCy Doc and
            augmentation_span is a (start, end) pair of token indices covering
            the inserted entity, or (None, None) when nothing could be changed.
        """
        # collect entities of the configured categories
        source_ents = [ent for ent in text.ents if ent.label_ in self.categories]
        claim_ents = [ent for ent in claim.ents if ent.label_ in self.categories]
        if not (claim_ents and source_ents):
            return None, None

        # pick the entity to replace
        target = random.choice(claim_ents)
        target_text = target.text

        # replacements must differ from the target, and neither string may
        # contain the other
        replacements = []
        for ent in source_ents:
            ent_text = ent.text
            if ent_text == target_text:
                continue
            if ent_text in target_text or target_text in ent_text:
                continue
            replacements.append(ent)
        if not replacements:
            return None, None

        # splice the replacement into the claim tokens
        replacement = random.choice(replacements)
        pieces = [token.text_with_ws for token in claim]
        new_piece = align_ws(target.text_with_ws, replacement.text_with_ws)
        rebuilt = pieces[:target.start] + [new_piece] + pieces[target.end:]

        # create new claim object
        new_claim = self.spacy("".join(rebuilt))
        if new_claim.text == claim.text:
            return None, None

        span_start = target.start
        return new_claim, (span_start, span_start + len(replacement) - 1)


class EntitySwap(NERSwap):
    """NER swap restricted to entities (people, companies, locations, etc.)."""

    def __init__(self):
        super().__init__()
        self.categories = (
            "PERSON", "ORG", "NORP", "FAC", "GPE", "LOC", "PRODUCT",
            "WORK_OF_ART", "EVENT",
        )


class NumberSwap(NERSwap):
    """NER swap restricted to numbers (excluding dates)."""

    def __init__(self):
        super().__init__()
        self.categories = (
            "PERCENT",
            "MONEY",
            "QUANTITY",
            "CARDINAL",
        )


class DateSwap(NERSwap):
    """NER swap restricted to dates and time."""

    def __init__(self):
        super().__init__()
        self.categories = (
            "DATE",
            "TIME",
        )


class AddNoise(Transformation):
    """Inject noise into claims by randomly deleting or duplicating tokens."""

    def __init__(self, noise_prob=0.05, delete_prob=0.8):
        super().__init__()

        # Probability that a token outside the augmented span is touched.
        self.noise_prob = noise_prob
        # Given that a token is touched: probability of deleting it
        # (otherwise it is duplicated).
        self.delete_prob = delete_prob
        self.spacy = spacy.load("en_core_web_sm")

    def transform(self, example):
        """Return a noisy copy of the example.

        Args:
            example: dict with non-None "text" and "claim" entries and an
                "augmentation_span" entry (a (start, end) token-index pair or
                a falsy value).

        Returns:
            A shallow copy of the example with "claim" replaced by the noisy
            claim (the spaCy Doc produced by the noise step), an updated
            "augmentation_span" and "noise" set to True; or None when the
            noise step left the claim unchanged or empty.
        """
        assert example["text"] is not None, "Text must be available"
        assert example["claim"] is not None, "Claim must be available"

        new_example = dict(example)
        claim = self.spacy(new_example["claim"])
        new_claim, aug_span = self.__add_noise(claim, new_example["augmentation_span"])

        if not new_claim:
            return None

        new_example["claim"] = new_claim
        new_example["augmentation_span"] = aug_span
        new_example["noise"] = True
        return new_example

    def __add_noise(self, claim, aug_span):
        """Randomly delete or duplicate tokens outside the augmented span.

        Args:
            claim: spaCy Doc of the claim.
            aug_span: (start, end) inclusive token indices that must not be
                modified, or a falsy value when there is no such span.

        Returns:
            (new_claim, aug_span) where new_claim is a spaCy Doc and aug_span
            is the span shifted to its position in the noisy claim, or
            (None, None) when the text did not change.
        """
        claim_tokens = [token.text_with_ws for token in claim]

        # Span membership is always tested against the ORIGINAL span, because
        # `ix` indexes the original token list. The displacement caused by
        # edits before the span is accumulated separately and applied once at
        # the end. (Shifting the span while iterating and comparing it to the
        # unshifted `ix` protects the wrong tokens.)
        has_span = bool(aug_span)
        if has_span:
            span_start, span_end = aug_span
        shift = 0

        new_claim = []
        for ix, token in enumerate(claim_tokens):
            # don't modify text inside an augmented span
            if has_span and span_start <= ix <= span_end:
                new_claim.append(token)
                continue

            # decide whether to add noise
            if random.random() >= self.noise_prob:
                new_claim.append(token)
                continue

            precedes_span = has_span and ix < span_start

            # decide whether to delete or replicate the token
            if random.random() < self.delete_prob:
                if precedes_span:
                    shift -= 1
                # keep a separating space where the deleted token used to be
                if new_claim and new_claim[-1][-1] != " ":
                    new_claim[-1] = new_claim[-1] + " "
                continue

            if precedes_span:
                shift += 1
            new_claim.append(token)
            new_claim.append(token)

        new_claim = self.spacy("".join(new_claim))

        if claim.text == new_claim.text:
            return None, None

        if has_span:
            aug_span = (span_start + shift, span_end + shift)
        return new_claim, aug_span
