 #
 #     BenchIE: A Framework for Multi-Faceted Fact-Based Open Information Extraction Evaluation
 #
 #        File:  benchie.py
 #
 #     Authors: Deleted for purposes of anonymity
 #
 #     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 #
 # The software and its source code contain valuable trade secrets and shall be maintained in
 # confidence and treated as confidential information. The software may only be used for
 # evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 # license agreement or nondisclosure agreement with the proprietor of the software.
 # Any unauthorized publication, transfer to third parties, or duplication of the object or
 # source code---either totally or in part---is strictly prohibited.
 #
 #     Copyright (c) 2021 Proprietor: Deleted for purposes of anonymity
 #     All Rights Reserved.
 #
 # THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY
 # AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT
 # DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION.
 #
 # NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 # IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE
 # LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 # FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 # OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 # ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 # TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 # THE POSSIBILITY OF SUCH DAMAGES.
 #
 # For purposes of anonymity, the identity of the proprietor is not given herewith.
 # The identity of the proprietor will be given once the review of the
 # conference submission is completed.
 #
 # THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 #

from gold_annotations import GoldAnnotations
from oie_extractions import OIEExtractions

import numpy as np

from scores import Scores

import utils

import pdb

class Benchie:
    """
        The main class for BenchIE.

        Holds one set of gold annotations, the extractions of any number of OIE systems (keyed by the system's name),
        one Scores object per OIE system and, once compute_error_stats() has run, per-system error statistics.
    """

    def __init__(self):
        # Gold annotations that every OIE system is evaluated against.
        self.gold_annotations = GoldAnnotations()
        # OIE system name -> OIEExtractions loaded for that system.
        self.oie_system_extractions = {}
        # OIE system name -> Scores object for that system.
        self.scores = {}
        # OIE system name -> dict of error statistics (filled by compute_error_stats()).
        self.error_stats = {}

    def load_gold_annotations(self, filename: str, load_mode: str = "full"):
        """
            Load the gold annotations. For details, see goldannotations.load_gold_annotations().

            Args
            ----
                filename: str
                    The path of the file in which the gold annotations are written.
                load_mode: str
                    The mode of loading the data. There are only two possible options: "full" or "minimal" ("full" is the default).
                    In "full" mode, everything in the golden annotations is loaded. In "minimal" mode, only the tokens that are not optional
                    are loaded (the optional tokens are omitted).
        """
        self.gold_annotations.load_gold_annotations(filename=filename, load_mode=load_mode)

    def add_oie_system_extractions(self, oie_system_name: str, filename: str):
        """
            Load and add the extractions produced by an OIE system (for details of loading, see oie_extractions.load_oie_extractions()).
            In the end, the extractions are loaded in self.oie_system_extractions where the key is the OIE system's name and the value is
            the system's OIEExtractions object. A fresh Scores object is registered for the system as well.

            Args
            ----
                oie_system_name: str
                    The name of the OIE system.
                filename: str
                    The path of the file where the OIE extractions of the particular OIE system are written

            Raises
            ------
                ValueError
                    If extractions for oie_system_name have already been added. The check happens before the file is
                    read, so nothing is loaded and the existing state is left untouched.
        """
        # Reject duplicates before doing the (potentially expensive) loading work.
        if oie_system_name in self.oie_system_extractions:
            raise ValueError("Extractions for the OIE system '{}' have already been added.".format(oie_system_name))

        extractions = OIEExtractions()
        extractions.set_oie_system_name(oie_system_name)
        extractions.load_oie_extractions(filename)
        extractions.compute_stats()
        self.oie_system_extractions[oie_system_name] = extractions
        self.scores[oie_system_name] = Scores()

    def __compute_precision(self, oie_system_name: str, match_type: str):
        # Delegate to the system's Scores object, comparing its extractions against the gold annotations.
        extractions = self.oie_system_extractions[oie_system_name].extractions
        gold_annotations = self.gold_annotations.golden_annotations
        self.scores[oie_system_name].compute_precision(extractions, gold_annotations, match_type)

    def __compute_recall(self, oie_system_name: str, match_type: str):
        # Delegate to the system's Scores object, comparing its extractions against the gold annotations.
        extractions = self.oie_system_extractions[oie_system_name].extractions
        gold_annotations = self.gold_annotations.golden_annotations
        self.scores[oie_system_name].compute_recall(extractions, gold_annotations, match_type)

    def __compute_f1(self, oie_system_name: str):
        self.scores[oie_system_name].compute_f1()

    def compute_precision(self, match_type: str = "slot"):
        """
            Compute the precision for each OIE system w.r.t. the golden annotations.

            Args
            ----
                match_type: str
                    Matching type for true positive. Can be either "slot" or "lexical" for per-slot matching and lexical matching respectively.
                    The default is "slot".
        """
        for oie_system in self.oie_system_extractions:
            self.__compute_precision(oie_system, match_type)

    def compute_recall(self, match_type: str = "slot"):
        """
            Compute recall for each OIE system w.r.t. the golden annotations.

            Args
            ----
                match_type: str
                    Matching type for true positive. Can be either "slot" or "lexical" for per-slot matching and lexical matching respectively.
                    The default is "slot".
        """
        for oie_system in self.oie_system_extractions:
            self.__compute_recall(oie_system, match_type)

    def compute_f1(self):
        """
            Compute the F1 score for each OIE system w.r.t. the golden annotations. Note that the precision and recall
            should already be calculated upfront, otherwise the resulting F1 score will be 0.
        """
        for oie_system in self.oie_system_extractions:
            self.__compute_f1(oie_system)

    def print_scores(self):
        """
            Print all scores (precision, recall, f1 score), rounded to two decimals, for all OIE systems. Note that if you didn't
            previously invoke the functions compute_precision(), compute_recall() and compute_f1(), this function will print 0.0 as results.
        """
        for oie_system in self.oie_system_extractions:
            system_scores = self.scores[oie_system]
            print(oie_system + " precision: " + str(round(system_scores.precision, 2)))
            print(oie_system + " recall: " + str(round(system_scores.recall, 2)))
            print(oie_system + " f1: " + str(round(system_scores.f1, 2)))
            print("===============")

    def get_subset(self, sent_ids):
        """
            Given a list of sentence IDs, return a subset of BenchIE with OIE extractions and golden annotations that are from the
            sentences provided in sent_ids. The subset starts with fresh (empty) Scores objects and no error stats.

            Args
            ----
            sent_ids: list
                A list of sentence IDs. This should be a subset of the already existing IDs

            Returns
            -------
            subset: Benchie
                A Benchie object containing only the OIE extractions and gold annotations from the sentences matching sent_ids
        """
        subset = Benchie()

        subset.gold_annotations = self.gold_annotations.get_subset_annotations(sent_ids)
        for oie_system, system_extractions in self.oie_system_extractions.items():
            subset.oie_system_extractions[oie_system] = system_extractions.get_subset_extractions(sent_ids)
            subset.scores[oie_system] = Scores()

        return subset

    @staticmethod
    def _fraction(numerator, denominator):
        """
            Return numerator / denominator, or 0.0 when the denominator is zero (e.g. a system without erroneous triples).
        """
        if not denominator:
            return 0.0
        return numerator / denominator

    @staticmethod
    def _error_stats_from_mismatches(max_mismatch_slots) -> dict:
        """
            Build the error-statistics dictionary described in compute_error_stats() from per-triple slot-match patterns.

            Args
            ----
                max_mismatch_slots: sequence
                    One entry per erroneous triple, as returned by utils.get_max_slot_mismatches(). Only the first three
                    values of each entry are read: (subject, relation, object), where 1 means the slot matched and
                    0 means it did not.

            Returns
            -------
                errors_dict: dict
                    The statistics described in compute_error_stats(). When there are no erroneous triples, all counts
                    are 0 and all fractions are 0.0.
        """
        error_count = len(max_mismatch_slots)

        # Per-slot mismatches: each column sum counts the triples in which that slot matched.
        if error_count == 0:
            # np.sum over an empty sequence yields a scalar that cannot be indexed per slot.
            subj_error_count = rel_error_count = obj_error_count = 0
        else:
            slots_match_count = np.sum(max_mismatch_slots, axis=0).astype(int)
            subj_error_count = error_count - slots_match_count[0]
            rel_error_count = error_count - slots_match_count[1]
            obj_error_count = error_count - slots_match_count[2]

        # Slot-pattern histogram: the (subj, rel, obj) pattern read as a binary number is its index, so
        # index 0 is (0, 0, 0) (errors in every slot) ... index 6 is (1, 1, 0) (errors in obj only) and
        # index 7 is (1, 1, 1) (no errors at all -> should always be zero).
        slot_error_counts = np.zeros(8, dtype=int)
        for slots in max_mismatch_slots:
            subj_ok, rel_ok, obj_ok = slots[0], slots[1], slots[2]
            # Entries that are not strictly 0/1 are not counted; the consistency check below reports them.
            if subj_ok in (0, 1) and rel_ok in (0, 1) and obj_ok in (0, 1):
                slot_error_counts[int(4 * subj_ok + 2 * rel_ok + obj_ok)] += 1
        errors_sum = np.sum(slot_error_counts)

        # Consistency checks
        if slot_error_counts[7] > 0:
            print("WARNING: There is a triple considered an error which is correct")
        if errors_sum != error_count:
            print("WARNING: Mismatches and total number of errors are not equal")

        errors_dict = {}
        # Keys listed in histogram-index order (index 7, "no errors", is intentionally not reported).
        pattern_keys = ("(0, 0, 0)", "(0, 0, 1)", "(0, 1, 0)", "(0, 1, 1)", "(1, 0, 0)", "(1, 0, 1)", "(1, 1, 0)")
        for index, key in enumerate(pattern_keys):
            errors_dict[key] = Benchie._fraction(slot_error_counts[index], errors_sum)

        errors_dict['subj_err_count'] = subj_error_count
        errors_dict['subj_err_frac'] = Benchie._fraction(subj_error_count, error_count)
        errors_dict['rel_err_count'] = rel_error_count
        errors_dict['rel_err_frac'] = Benchie._fraction(rel_error_count, error_count)
        errors_dict['obj_err_count'] = obj_error_count
        errors_dict['obj_err_frac'] = Benchie._fraction(obj_error_count, error_count)

        errors_dict['error_count'] = error_count
        errors_dict['errors_sum'] = errors_sum

        return errors_dict

    def compute_error_stats(self):
        """
            Compute error stats for every OIE system w.r.t. BenchIE. This works for slot matching only, not for lexical matching.

            Errors are written in self.error_stats, which is a dictionary with key: OIE system; value: dict of stats. Each dict of stats has
            the following entries (which are strings):
                * "(0, 0, 0)" -> fraction of errors in every slot
                * "(0, 0, 1)" -> fraction of errors in subj and rel
                * "(0, 1, 0)" -> fraction of errors in subj and object
                * "(0, 1, 1)" -> fraction of errors in subj only
                * "(1, 0, 0)" -> fraction of errors in rel and object
                * "(1, 0, 1)" -> fraction of errors in rel only
                * "(1, 1, 0)" -> fraction of errors in obj only
                * "error_count" -> number of triples that have at least one error in at least one slot
                * "errors_sum" -> number of erroneous triples whose slot pattern was counted in the fractions above
                                  (equals "error_count" unless a warning is printed)
                * "subj_err_count" -> number of triples that have at least one error in the subject
                * "subj_err_frac" -> fraction of triples (subj_err_count / error_count) that have at least one error in the subject
                * "rel_err_count" -> number of triples that have at least one error in the relation
                * "rel_err_frac" -> fraction of triples (rel_err_count / error_count) that have at least one error in the relation
                * "obj_err_count" -> number of triples that have at least one error in the object
                * "obj_err_frac" -> fraction of triples (obj_err_count / error_count) that have at least one error in the object

            If a system has no erroneous triples, all counts are 0 and all fractions are 0.0.
        """
        golden_annotations = self.gold_annotations.golden_annotations
        for oie_system, system_extractions in self.oie_system_extractions.items():
            max_mismatch_slots = utils.get_max_slot_mismatches(system_extractions.extractions, golden_annotations)
            self.error_stats[oie_system] = self._error_stats_from_mismatches(max_mismatch_slots)

    def compute_scores(self, match_type: str = "slot"):
        """
            Compute precision, recall, f1 score and (for slot matching only) error stats.

            Args
            ----
                match_type: str
                    Can be either "slot" or "lexical". Determines the type of matching to be used
        """
        self.compute_precision(match_type)
        self.compute_recall(match_type)
        self.compute_f1()
        if match_type == "slot":
            self.compute_error_stats()

    def print_error_analysis(self, oie_system_name: str):
        """
            Placeholder: not implemented yet; currently does nothing and returns None.
        """

    def subj_errors_count(self, oie_system_name: str):
        """
            Placeholder: not implemented yet; currently does nothing and returns None.
        """

    def rel_errors_count(self, oie_system_name: str):
        """
            Placeholder: not implemented yet; currently does nothing and returns None.
        """

    def obj_errors_count(self, oie_system_name: str):
        """
            Placeholder: not implemented yet; currently does nothing and returns None.
        """

    def slot_errors_count(self, oie_system_name: str):
        """
            Placeholder: not implemented yet; currently does nothing and returns None.
        """

    def print_incorrect_triples(self, oie_system_name: str):
        """
            Placeholder: not implemented yet; currently does nothing and returns None.
        """

    def print_correct_triples(self, oie_system_name: str):
        """
            Placeholder: not implemented yet; currently does nothing and returns None.
        """