 #
 #     BenchIE: A Framework for Multi-Faceted Fact-Based Open Information Extraction Evaluation
 #
 #        File:  gold_annotations.py
 #
 #     Authors: Deleted for purposes of anonymity
 #
 #     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 #
 # The software and its source code contain valuable trade secrets and shall be maintained in
 # confidence and treated as confidential information. The software may only be used for
 # evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 # license agreement or nondisclosure agreement with the proprietor of the software.
 # Any unauthorized publication, transfer to third parties, or duplication of the object or
 # source code---either totally or in part---is strictly prohibited.
 #
 #     Copyright (c) 2021 Proprietor: Deleted for purposes of anonymity
 #     All Rights Reserved.
 #
 # THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY
 # AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT
 # DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION.
 #
 # NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 # IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE
 # LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 # FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 # OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 # ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 # TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 # THE POSSIBILITY OF SUCH DAMAGES.
 #
 # For purposes of anonymity, the identity of the proprietor is not given herewith.
 # The identity of the proprietor will be given once the review of the
 # conference submission is completed.
 #
 # THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 #

import pdb
import re
import itertools

import numpy as np

from utils import triple2string

class GoldAnnotations:
    """
        Golden annotations read from a gold annotation text file: the sentence IDs (in file order), the sentences and,
        for every sentence, its triple synsets (clusters), where every surface variant implied by the optional tokens
        (tokens in square brackets) is expanded into a triple of its own.
    """

    def __init__(self):
        """
            Default constructor creates empty list for the sentences ID, and empty dictionaries for the input sentences and
            the golden annotations
        """
        # Store sentence IDs written in the golden annotation input text file
        self.sent_ids = []

        # Store sentences as a dictionary, where key = sent id; value = sentence as a string
        self.sentences = {}

        # Store the golden annotations in a dictionary, where key = sent id; value = list of lists, where each list is
        # a triple synset, and each element is a triple
        self.golden_annotations = {}

    def load_gold_annotations(self, filename: str, load_mode: str = "full"):
        """
            Loads the input sentences IDs (in self.sent_ids), input sentences (in self.sentences) and the dictionary of golden
            annotations (in self.golden_annotations) from the golden annotations written in filename. Data loaded by a
            previous call is replaced, not merged; if loading fails, the previously loaded data is left untouched.

            Args
            ----
            filename: str
                The name of the file where the golden annotations are written
            load_mode: str
                The mode of loading the data. There are only two possible options: "full" or "minimal" ("full" is the default).
                In "full" mode, everything in the golden annotations is loaded and each optional token (in square brackets)
                may be either present or absent. In "minimal" mode, only the tokens that are not optional are loaded (the
                optional tokens are omitted).

            Raises
            ------
            ValueError
                If load_mode is neither "full" nor "minimal" (checked before the file is read), or if the file is
                malformed: duplicate sentence ID, annotation line before the first sentence line, or a line inside a
                sentence block that is neither a cluster header nor an extraction with at least three slots.
        """
        if load_mode not in ("full", "minimal"):
            raise ValueError("Unknown load_mode {!r}: expected 'full' or 'minimal'".format(load_mode))

        sent_ids, sentences, annotations = self.__parse_gold_file(filename)

        if load_mode == "minimal":
            # Drop the optional tokens up front; the expansion below then yields exactly one variant per triple
            for items in annotations.values():
                for item in items:
                    if isinstance(item, list):
                        item[:] = [self.__strip_optional_tokens(slot) for slot in item]

        # Create a gold dictionary, where the key is the sentence ID, the value is a list of clusters (triple synsets)
        golden_annotations = {sent_id: self.__split_into_clusters(annotations[sent_id]) for sent_id in sent_ids}
        self.__generate_golden_extractions_from_optional_tokens(golden_annotations)

        # Commit only once everything parsed successfully
        self.sent_ids = sent_ids
        self.sentences = sentences
        self.golden_annotations = golden_annotations

    def __parse_gold_file(self, filename: str) -> tuple:
        """
            Read the gold annotation file and group its lines by sentence.

            Every line containing "sent_id:" is a sentence line in the format 'sent_id:<ID><TAB><sentence>' and opens a
            new sentence block; all lines up to the next sentence line belong to that block. Within a block, empty lines
            are ignored, lines containing ' --> ' are extractions (split into their slots, e.g. subj, rel, obj), and any
            other line is kept as a string (e.g. a triple cluster header such as '1--> Cluster 2:').
            Grouping on the sentence lines, rather than on empty lines, keeps each annotation block attached to its own
            sentence ID regardless of how many empty lines separate the blocks.

            Args
            ----
            filename: str
                The name of the file where the golden annotations are written

            Returns
            -------
            (sent_ids, sentences, annotations): tuple
                sent_ids: list of the sentence IDs in file order; sentences: dict, key = sent id, value = sentence as a
                string; annotations: dict, key = sent id, value = list of the block's items (cluster header strings and
                lists of slots) in file order.

            Raises
            ------
            ValueError
                If a sentence ID occurs more than once, or a non-empty line appears before the first sentence line.
        """
        with open(filename, 'r', encoding='UTF-8') as f:
            file_lines = [line.strip() for line in f]

        sent_ids = []
        sentences = {}
        annotations = {}
        block = None  # items of the sentence block currently being read
        for line in file_lines:
            if "sent_id:" in line:
                columns = line.split("\t")
                sent_id = columns[0].split("sent_id:")[1]
                if sent_id in sentences:
                    raise ValueError("Duplicate sentence ID {!r} in {}".format(sent_id, filename))
                sent_ids.append(sent_id)
                sentences[sent_id] = columns[1]
                block = annotations[sent_id] = []
            elif not line:
                continue
            elif block is None:
                raise ValueError(
                    "Annotation line found before the first sentence line in {}: {!r}".format(filename, line)
                )
            elif ' --> ' in line:
                block.append(line.split(' --> '))  # extraction line -> list of slots
            else:
                block.append(line)  # cluster header

        return sent_ids, sentences, annotations

    def __split_into_clusters(self, items: list) -> list:
        """
            Split the items of one sentence block into triple synsets (clusters).

            A cluster header (a string containing 'Cluster') starts a new cluster and is itself dropped. Items appearing
            before the first header, if any, form a cluster of their own.

            Args
            ----
                items: list
                    Items of one sentence block: cluster header strings and lists of slots

            Returns
            -------
                clusters: list
                    List of clusters; each cluster is the list of its non-header items, in order
        """
        clusters = []
        for item in items:
            # Only strings can be headers, so an extraction with a slot literally equal to 'Cluster' stays an extraction
            is_header = isinstance(item, str) and 'Cluster' in item
            if is_header or not clusters:
                clusters.append([])
            if not is_header:
                clusters[-1].append(item)
        return clusters

    def __strip_optional_tokens(self, slot: str) -> str:
        """
            Remove the optional tokens/phrases (in square brackets) from a slot; used in "minimal" loading mode.

            Only square brackets mark optional tokens: parentheses are ordinary tokens and are kept, as in "full" mode.
            A slot without '[' is returned unchanged.
        """
        if '[' not in slot:
            return slot
        # Remove tokens/phrases in square brackets
        stripped = re.sub(r"\[.*?\]", "", slot).strip()
        # Replace multiple spaces with one
        return re.sub(' +', ' ', stripped)

    def __group_optional_tokens(self, slot: str) -> list:
        """
            Split a slot on single spaces, merging a multi-word optional phrase ('[of' ... 'Australia]') into one token.

            Args
            ----
                slot: str
                    Slot text; tokens in square brackets are optional

            Returns
            -------
                tokens: list
                    Tokens of the slot; an optional token or phrase keeps its brackets
        """
        words = slot.split(' ')
        tokens = []
        i = 0
        while i < len(words):
            word = words[i]
            has_open = '[' in word
            has_close = ']' in word
            if has_open and not has_close:
                # Start of a multi-word optional phrase: merge up to the first word that holds ']'
                close = next((j for j in range(i + 1, len(words)) if ']' in words[j]), None)
                if close is None:
                    # NOTE(review): an unterminated '[' word is dropped, as in the original tokenizer
                    i += 1
                else:
                    tokens.append(' '.join(words[i:close + 1]))
                    # Resume after the closing word so it is never counted twice
                    i = close + 1
            elif has_close and not has_open:
                # NOTE(review): a stray ']' word with no opening bracket is dropped, as in the original tokenizer
                i += 1
            else:
                tokens.append(word)
                i += 1
        return tokens

    def __generate_golden_extractions_from_optional_tokens(self, gold_dict: dict):
        """
            Generate all possible golden extractions from the optional tokens (in square brackets). Save results in gold_dict.

            Args
            ----
                gold_dict: key: sent id; value: list (cluster) of extractions; in each slot, the tokens in square brackets are optional.
                            Once the function is executed, each cluster holds (in place) one [subj, rel, obj] list of strings,
                            with the brackets removed, per combination of kept/omitted optional tokens. Slots after the
                            third one are ignored.

            Raises
            ------
                ValueError
                    If a cluster item is not a list of at least three slots.
        """
        for sent_id, clusters in gold_dict.items():
            for cluster in clusters:
                expanded = []
                for triple in cluster:
                    if not isinstance(triple, list) or len(triple) < 3:
                        raise ValueError(
                            "Malformed gold extraction for sentence {!r}: {!r} "
                            "(expected at least 3 slots separated by ' --> ')".format(sent_id, triple)
                        )
                    # For each of subject, relation and object: every token list obtainable by omitting optional tokens
                    slot_variants = [
                        self.__generate_slots_from_optional_tokens(self.__group_optional_tokens(slot))
                        for slot in triple[:3]
                    ]
                    for variant in itertools.product(*slot_variants):
                        # Merge each slot's tokens into a string and delete the brackets
                        expanded.append([' '.join(tokens).replace('[', '').replace(']', '') for tokens in variant])
                cluster[:] = expanded

    def __generate_slots_from_optional_tokens(self, slot: list) -> list:
        """
            Given a slot (a list of tokens) with potentially optional tokens (in square brackets), generate all possible combinations
            of the surface representation of the slot.

            Args
            ----
                slot: list
                    List of tokens with potentially optional tokens (in square brackets)

            Returns
            -------
                sub_all: list
                    List of all possible surface realizations of the slot: the full slot first, then the slot with
                    1, 2, ... optional tokens omitted (2 ** number_of_optional_tokens lists in total).
        """
        optional = [i for i, token in enumerate(slot) if '[' in token]
        sub_all = []
        for size in range(len(optional) + 1):
            for omitted in itertools.combinations(optional, size):
                omitted_set = set(omitted)
                sub_all.append([token for i, token in enumerate(slot) if i not in omitted_set])
        return sub_all

    def compute_stats(self) -> dict:
        """
            Compute certain stats about the golden annotations.

            Returns:
            --------
                stats: dict
                    Dictionary, where the key is the name of the statistics; value is the value of that particular statistic
                    A list of all possible stats:
                    * triple_synset_count: Total number of triple synset
                    * extractions_count: Total number of extractions (triples of a synset with the same triple2string
                      representation are counted once)
                    * extractions_per_sentence: Number of extractions per sentence
                    * triple_synsets_per_sentence: Number of triple synsets per sentence
                    * extractions_per_triple_synset: Number of extractions per triple synset
                    * avg_extraction_length: Avg. length of extraction (in space-separated tokens over the three slots)
                    Averages whose denominator is zero (e.g. nothing loaded) are reported as 0.0.
        """
        triple_synset_count = 0
        extractions_count = 0
        sent_count = len(self.sentences)
        extraction_length = []

        for synsets in self.golden_annotations.values():
            triple_synset_count += len(synsets)
            for synset in synsets:
                extraction_length.extend(
                    len(triple[0].split(" ")) + len(triple[1].split(" ")) + len(triple[2].split(" "))
                    for triple in synset
                )
                extractions_count += len({triple2string(triple) for triple in synset})

        def ratio(numerator, denominator):
            # Avoid ZeroDivisionError on empty annotations
            return numerator / denominator if denominator else 0.0

        return {
            'triple_synset_count': triple_synset_count,
            'extractions_count': extractions_count,
            'extractions_per_sentence': ratio(extractions_count, sent_count),
            'triple_synsets_per_sentence': ratio(triple_synset_count, sent_count),
            'extractions_per_triple_synset': ratio(extractions_count, triple_synset_count),
            'avg_extraction_length': np.mean(extraction_length) if extraction_length else 0.0,
        }

    def get_subset_annotations(self, sent_ids: list):
        """
            Return a subset of the golden annotations.

            The subset shares the sentence strings and annotation lists with this object (no deep copy).

            Args
            ----
                sent_ids: list
                    List of sentence ids; each must be present in this object (KeyError otherwise)

            Returns
            -------
                subset: GoldAnnotations
                    The subset of the golden annotation
        """
        subset = GoldAnnotations()

        for s_id in sent_ids:
            subset.sent_ids.append(s_id)
            subset.sentences[s_id] = self.sentences[s_id]
            subset.golden_annotations[s_id] = self.golden_annotations[s_id]

        return subset