 #
 #     BenchIE: A Framework for Multi-Faceted Fact-Based Open Information Extraction Evaluation
 #
 #        File:  scores.py
 #
 #     Authors: Deleted for purposes of anonymity
 #
 #     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 #
 # The software and its source code contain valuable trade secrets and shall be maintained in
 # confidence and treated as confidential information. The software may only be used for
 # evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 # license agreement or nondisclosure agreement with the proprietor of the software.
 # Any unauthorized publication, transfer to third parties, or duplication of the object or
 # source code---either totally or in part---is strictly prohibited.
 #
 #     Copyright (c) 2021 Proprietor: Deleted for purposes of anonymity
 #     All Rights Reserved.
 #
 # THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY
 # AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT
 # DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION.
 #
 # NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 # IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE
 # LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 # FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 # OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 # ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 # TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 # THE POSSIBILITY OF SUCH DAMAGES.
 #
 # For purposes of anonymity, the identity of the proprietor is not given herewith.
 # The identity of the proprietor will be given once the review of the
 # conference submission is completed.
 #
 # THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 #

import numpy as np
import pdb

class Scores():
    """
        Computes and stores precision, recall and F1 of OIE extractions w.r.t. golden annotations.

        The golden annotations group triples into "triple synsets" (lists of triples having the same
        meaning). A synset counts as one true positive if at least one extraction matches one of its
        triples; an extraction that matches no triple of its sentence counts as a false positive; a
        synset that is matched by no extraction counts as a false negative.

        Extractions whose sentence id does not appear in the golden annotations are ignored.
    """

    def __init__(self) -> None:
        self.precision = 0.0
        self.recall = 0.0
        self.f1 = 0.0

    def compute_precision(self, extractions: list, golden_annotations: dict, match_type: str):
        """
            Computes the precision of the extractions w.r.t. the golden annotations and stores it in self.precision.

            Args
            ----
                extractions: list
                    List of triples, each of them written in the following format: [sent_id, subj, rel, obj]
                golden_annotations: dict
                    key: sentence id (sent_id from triple)
                    value: list of lists. Each list represents a triple synset (i.e., list of triples having the same meaning).
                match_type: str
                    Matching type for true positive. Can be either "slot" or "lexical" for per-slot matching and lexical matching respectively.

            Returns
            -------
                0.0 if there are neither true positives nor false positives (kept for backward compatibility);
                None otherwise. In both cases the result is stored in self.precision.

            Raises
            ------
                ValueError: if match_type is neither "slot" nor "lexical" (checked only once an extraction
                            belonging to an annotated sentence is encountered)
        """
        # Compute TPs: every golden synset hit by at least one extraction counts once
        scores = self.__compute_scores(extractions, golden_annotations, match_type)
        true_pos = sum(int(np.count_nonzero(synset_hits)) for synset_hits in scores.values())

        # Compute FPs: every extraction (of an annotated sentence) that matches no golden triple
        false_pos = 0
        for triple in extractions:
            if triple[0] not in golden_annotations:
                continue
            if self.__find_synset_ind(triple, golden_annotations, match_type) == -1:
                false_pos += 1

        if true_pos == 0 and false_pos == 0:
            # Reset explicitly so that a previously computed precision does not leak into compute_f1
            self.precision = 0.0
            return 0.0

        self.precision = true_pos / (true_pos + false_pos)

    def compute_recall(self, extractions: list, golden_annotations: dict, match_type: str):
        """
            Computes the recall of the extractions w.r.t. the golden annotations and stores it in self.recall.

            Args
            ----
                extractions: list
                    List of triples, each of them written in the following format: [sent_id, subj, rel, obj]
                golden_annotations: dict
                    key: sentence id (sent_id from triple)
                    value: list of lists. Each list represents a triple synset (i.e., list of triples having the same meaning).
                match_type: str
                    Matching type for true positive. Can be either "slot" or "lexical" for per-slot matching and lexical matching respectively.

            Raises
            ------
                ValueError: if match_type is neither "slot" nor "lexical" (checked only once an extraction
                            belonging to an annotated sentence is encountered)
        """
        scores = self.__compute_scores(extractions, golden_annotations, match_type)
        true_pos = 0
        false_negs = 0

        # Compute TPs and FNs: a synset is a TP if it was hit, FN otherwise
        for synset_hits in scores.values():
            hits = int(np.count_nonzero(synset_hits))
            true_pos += hits
            false_negs += synset_hits.shape[0] - hits

        total = true_pos + false_negs
        # No golden synsets at all: define recall as 0.0 instead of dividing by zero
        self.recall = true_pos / total if total > 0 else 0.0

    def compute_f1(self):
        """
            Computes the F1 score. Note that precision and recall should be already computed; otherwise it will compute f1 = 0.0
        """
        denominator = self.precision + self.recall
        if denominator == 0:
            self.f1 = 0.0
        else:
            self.f1 = 2 * (self.precision * self.recall) / denominator

    def __compute_scores(self, extractions: list, golden_annotations: dict, match_type: str) -> dict:
        """
            Marks, for each sentence, which golden triple synsets are matched by at least one extraction.

            Returns
            -------
                dict:
                    key: sentence id
                    value: numpy array with one entry per triple synset (1 if matched, 0 otherwise)
        """
        scores = self.__get_empty_synset_scores(golden_annotations)

        for triple in extractions:
            sent_id = triple[0]
            if sent_id not in golden_annotations:
                continue
            tp_synset_ind = self.__find_synset_ind(triple, golden_annotations, match_type)
            # Guard against -1: it would otherwise silently mark the LAST synset as matched
            if tp_synset_ind != -1:
                scores[sent_id][tp_synset_ind] = 1

        return scores

    def __get_empty_synset_scores(self, golden_annotations: dict) -> dict:
        """
            Creates an all-zeros score array (one entry per triple synset) for every annotated sentence.
        """
        return {sent_id: np.zeros(len(synsets)) for sent_id, synsets in golden_annotations.items()}

    def __find_synset_ind(self, triple: list, golden_annotations: dict, match_type: str) -> int:
        """
            Single place where an extraction is matched against the golden annotations, so that the
            true-positive predicates and the synset-index lookups can never disagree.

            All slots (of both the extraction and the golden triples) are stripped of surrounding
            whitespace before comparison.

            Args
            ----
                triple: list
                    The object 'triple' is in the following format [sent_id, subj, rel, obj]
                golden_annotations: dict
                    key: sentence id (sent_id from triple)
                    value: list of lists. Each list represents a triple synset (i.e., list of triples having the same meaning).
                match_type: str
                    "slot": subject, relation and object must each be equal
                    "lexical": the space-joined strings "subj rel obj" must be equal

            Returns:
                int: index of the first matching triple synset, -1 if it is not TP

            Raises:
                ValueError: if match_type is neither "slot" nor "lexical"
        """
        if match_type == "slot":
            lexical = False
        elif match_type == "lexical":
            lexical = True
        else:
            raise ValueError("Unknown match_type: " + repr(match_type) + " (expected \"slot\" or \"lexical\")")

        extracted = (triple[1].strip(), triple[2].strip(), triple[3].strip())
        if lexical:
            extracted = " ".join(extracted)

        sent_id = triple[0]
        for i, triple_synset in enumerate(golden_annotations[sent_id]):
            for tr in triple_synset:
                gold = (tr[0].strip(), tr[1].strip(), tr[2].strip())
                if lexical:
                    gold = " ".join(gold)
                if extracted == gold:
                    return i
        return -1

    def __get_true_positive_synset_ind(self, triple: list, golden_annotations: dict) -> int:
        """
            Args
            ----
                triple: list 
                    The object 'triple' is in the following format [sent_id, subj, rel, obj]
                golden_annotations: dict
                    key: sentence id (sent_id from triple)
                    value: list of lists. Each list represents a triple synset (i.e., list of triples having the same meaning).
            
            Returns:
                int: index of the triple synset (per-slot matching), -1 if it is not TP
        """
        return self.__find_synset_ind(triple, golden_annotations, "slot")

    def __get_lexical_true_positive_synset_ind(self, triple: list, golden_annotations: dict) -> int:
        """
            Args
            ----
                triple: list 
                    The object 'triple' is in the following format [sent_id, subj, rel, obj]
                golden_annotations: dict
                    key: sentence id (sent_id from triple)
                    value: list of lists. Each list represents a triple synset (i.e., list of triples having the same meaning).
            
            Returns:
                int: index of the triple synset (lexical matching), -1 if it is not TP
        """
        return self.__find_synset_ind(triple, golden_annotations, "lexical")

    def is_true_positive(self, triple: list , golden_annotations: dict) -> bool:
        """
            Checks if an extracted triple is considered as true positive w.r.t. the golden annotations. Matching is performed on per-slot level.

            Args
            ----
                triple: list 
                    The object 'triple' is in the following format [sent_id, subj, rel, obj]
                golden_annotations: dict
                    key: sentence id (sent_id from triple)
                    value: list of lists. Each list represents a triple synset (i.e., list of triples having the same meaning).
            
            Returns
            -------
                boolean:
                    True: if the triple is found in the golden annotations
                    False: otherwise
        """
        return self.__find_synset_ind(triple, golden_annotations, "slot") != -1

    def is_true_positive_lexical(self, triple: list, golden_annotations: dict) -> bool:
        """
            Checks if an extracted triple is considered as true positive w.r.t. the golden annotations. Matching is performed on lexical level.

            Args
            ----
                triple: list 
                    The object 'triple' is in the following format [sent_id, subj, rel, obj]
                golden_annotations: dict
                    key: sentence id (sent_id from triple)
                    value: list of lists. Each list represents a triple synset (i.e., list of triples having the same meaning).
            
            Returns
            -------
                boolean:
                    True: if the triple is found in the golden annotations
                    False: otherwise
        """
        return self.__find_synset_ind(triple, golden_annotations, "lexical") != -1

    def print_scores(self, oie_system: str):
        """
            Print all scores (precision, recall, f1 score).

            Args
            ----
                oie_system: str
                    The name of the OIE system
        """
        
        print(oie_system + " precision: " + str(self.precision))
        print(oie_system + " recall: " + str(self.recall))
        print(oie_system + " f1: " + str(self.f1))
        print("===============")
