 #
 #     BenchIE: A Framework for Multi-Faceted Fact-Based Open Information Extraction Evaluation
 #
 #        File:  utils.py
 #
 #     Authors: Deleted for purposes of anonymity
 #
 #     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 #
 # The software and its source code contain valuable trade secrets and shall be maintained in
 # confidence and treated as confidential information. The software may only be used for
 # evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 # license agreement or nondisclosure agreement with the proprietor of the software.
 # Any unauthorized publication, transfer to third parties, or duplication of the object or
 # source code---either totally or in part---is strictly prohibited.
 #
 #     Copyright (c) 2021 Proprietor: Deleted for purposes of anonymity
 #     All Rights Reserved.
 #
 # THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY
 # AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT
 # DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION.
 #
 # NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 # IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE
 # LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 # FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 # OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 # ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 # TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 # THE POSSIBILITY OF SUCH DAMAGES.
 #
 # For purposes of anonymity, the identity of the proprietor is not given herewith.
 # The identity of the proprietor will be given once the review of the
 # conference submission is completed.
 #
 # THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 #

import pdb

import numpy as np

def triple2string(triple:list) -> str:
    """
        Serialize the first three slots of a triple into one tab-separated string.

        Args
        ----
            triple: list
                List of strings [subj, rel, obj]; only positions 0, 1 and 2 are read

        Returns
        -------
            triple_str: str
                The subject, relation and object joined by single tab characters
    """
    subj, rel, obj = triple[0], triple[1], triple[2]
    return "\t".join((subj, rel, obj))

def write_to_carb_format(triples: list, sentences: dict, write_filename: str):
    """
        Given a list of triples and a dictionary of sentences, write the extractions in CaRB format.

        Each triple produces one line made of the sentence, the constant confidence "1.00",
        the relation, the subject and the object, separated by tabs (note that the relation
        is written before the subject).

        Args
        ----
            triples: list
                List of triples. Each triple is a list of 4 elements: [sent_id, subj, predicate, object]
            sentences: dict
                Dictionary of sentences. Key: sent_id; value: sentence (str)
            write_filename: str
                Filename where the final result is written (opened in 'w' mode, so an existing
                file is overwritten)

        Raises
        ------
            KeyError
                If the sent_id of a triple is not a key of `sentences`
    """
    # The context manager guarantees the file is closed even when a lookup or write fails midway
    with open(write_filename, 'w') as write_file:
        for triple in triples:
            sent = sentences[triple[0]]
            subj, rel, obj = triple[1], triple[2], triple[3]
            write_file.write(sent + "\t1.00\t" + rel + "\t" + subj + "\t" + obj + "\n")

def print_correct_triples(extractions, golden_annotations, input_sentences):
    """
        Print, grouped by sentence, the extractions that are found in the golden annotations.

        Whenever the sentence of the current extraction differs from the one of the previous
        extraction, the sentence text is printed (preceded by a blank line) as a header.

        Args
        ----
            extractions: iterable
                Triples in the format [sent_id, subj, rel, obj]; sent_id must be convertible to int
            golden_annotations: dict
                key: sentence id; value: list of triple synsets (see is_true_positive)
            input_sentences: sequence
                Sentence texts; the sentence of an extraction is read at index int(sent_id) - 1
    """
    # Start with "no sentence yet" so that the header of the very first sentence is printed as
    # well (starting at index 0 silently skipped the header of sent_id 1).
    current_sent_ind = None
    for triple in extractions:
        sent_id = int(triple[0])
        sent_ind = sent_id - 1

        if sent_ind != current_sent_ind:
            current_sent_ind = sent_ind
            print("\n" + input_sentences[current_sent_ind])

        # NOTE(review): membership is tested with the int-converted id, while is_true_positive
        # looks up golden_annotations with the raw triple[0] -- this presumably relies on
        # triple[0] already being an int; confirm against the callers.
        if sent_id not in golden_annotations:
            continue

        if is_true_positive(triple, golden_annotations):
            print(triple)

def is_true_positive(triple, golden_annotations):
    """
        Check whether an extracted triple exactly matches any golden triple of its sentence.

        Args
        ----
            triple: list
                Extraction in the format [sent_id, subj, rel, obj]
            golden_annotations: dict
                key: sentence id (sent_id from triple)
                value: list of triple synsets; each synset is a list of golden triples
                       [subj, rel, obj] having the same meaning

        Returns
        -------
            boolean:
                True: if subject, relation and object all equal those of some golden triple
                False: otherwise

        Raises
        ------
            KeyError
                If the sent_id of the triple is not a key of golden_annotations
    """
    sent_id, subj, rel, obj = triple[0], triple[1], triple[2], triple[3]

    # A single exact match in any synset of the sentence is enough
    return any(
        subj == gold[0] and rel == gold[1] and obj == gold[2]
        for synset in golden_annotations[sent_id]
        for gold in synset
    )

def print_incorrect_triples(extractions, golden_annotations, input_sentences):
    """
        Print, grouped by sentence, the extractions that are NOT found in the golden annotations.

        Whenever the sentence of the current extraction differs from the one of the previous
        extraction, the sentence text is printed (preceded by a blank line) as a header.
        Extractions whose sentence id is not a key of golden_annotations are skipped.

        Args
        ----
            extractions: iterable
                Triples in the format [sent_id, subj, rel, obj]; sent_id must be convertible to int
            golden_annotations: dict
                key: sentence id; value: list of triple synsets (see is_true_positive)
            input_sentences: sequence
                Sentence texts; the sentence of an extraction is read at index int(sent_id) - 1
    """
    # Start with "no sentence yet" so that the header of the very first sentence is printed as
    # well (starting at index 0 silently skipped the header of sent_id 1).
    current_sent_ind = None
    for triple in extractions:
        sent_id = int(triple[0])
        sent_ind = sent_id - 1

        # A leftover debugger breakpoint used to stop execution here on every extraction;
        # it has been removed so the function can run unattended.
        if sent_ind != current_sent_ind:
            current_sent_ind = sent_ind
            print("\n" + input_sentences[current_sent_ind])

        # NOTE(review): membership is tested with the int-converted id, while is_true_positive
        # looks up golden_annotations with the raw triple[0] -- this presumably relies on
        # triple[0] already being an int; confirm against the callers.
        if sent_id not in golden_annotations:
            continue

        if not is_true_positive(triple, golden_annotations):
            print(triple)

def get_max_slot_matches(openie_extractions, golden_annotations):
    """
        Collect, for every extraction, the slot-match pattern of its best matching golden triple.

        Args
        ----
            openie_extractions: iterable
                Extractions in the format [sent_id, subj, rel, obj]
            golden_annotations: dict
                key: sentence id; value: list of fact synsets (lists of golden triples)

        Returns
        -------
            max_match_slots: np.ndarray
                Float array of shape (n, 3). Each row is the first entry of 'max_match_slots'
                reported by get_slot_matches_stats for one extraction. Extractions for which
                that entry is empty are skipped, so n may be smaller than the number of
                extractions (and the shape is (0, 3) when nothing is collected).
    """
    # Rows are gathered in a list and converted once at the end; the result used to be built
    # with repeated np.append calls and was then never returned to the caller.
    rows = []
    for tup in openie_extractions:
        sent_id = tup[0]
        triple = tup[1:]
        stats = get_slot_matches_stats(triple, sent_id, golden_annotations)
        best_slots = stats['max_match_slots']
        if best_slots.size == 0:
            continue
        rows.append(best_slots[0].tolist())

    # reshape keeps the (0, 3) shape for an empty result; float dtype matches get_max_slot_mismatches
    return np.array(rows, dtype=float).reshape(-1, 3)

def get_max_slot_mismatches(openie_extractions, golden_annotations):
    """
        Collect, for every incorrect extraction, the slot-match pattern of its best matching
        golden triple.

        Args
        ----
            openie_extractions: iterable
                Extractions in the format [sent_id, subj, rel, obj]
            golden_annotations: dict
                key: sentence id; value: list of fact synsets (lists of golden triples)

        Returns
        -------
            max_mismatch_slots: np.ndarray
                Float array of shape (n, 3). Each row is the first entry of 'max_match_slots'
                reported by get_slot_matches_stats for one extraction. Extractions are skipped
                when that entry is empty or when their 'max_match' is 3 (all three slots match,
                i.e. the extraction is not an error). The shape is (0, 3) when nothing is
                collected.
    """
    # Gather the rows in a list and build the array once: np.append copies the whole array on
    # every call, which made the original loop quadratic in the number of extractions.
    rows = []
    for tup in openie_extractions:
        sent_id = tup[0]
        triple = tup[1:]
        stats = get_slot_matches_stats(triple, sent_id, golden_annotations)
        best_slots = stats['max_match_slots']
        if best_slots.size == 0:
            continue
        if stats['max_match'] == 3:
            continue
        rows.append(best_slots[0].tolist())

    # reshape keeps the (0, 3) shape for an empty result
    return np.array(rows, dtype=float).reshape(-1, 3)

def get_slot_matches(triple, fact_synset):
    """
        Compare a triple slot by slot with every golden triple of a fact synset.

        NOTE: this module defines get_slot_matches a second time further down with the same
        logic; the later definition is the one bound to the name once the module is loaded.

        Args
        ----
            triple: sequence
                Extracted triple; positions 0, 1 and 2 are compared
            fact_synset: sequence
                Golden triples; positions 0, 1 and 2 of each are compared

        Returns
        -------
            matches_count: np.ndarray
                Integer array with one entry per golden triple: the number of equal slots
            matches_slots: np.ndarray
                Integer array of shape (len(fact_synset), 3) holding 1 where the slot is equal
                and 0 otherwise
    """
    synset_size = len(fact_synset)
    matches_count = np.zeros(synset_size, dtype=int)
    matches_slots = np.zeros((synset_size, 3), dtype=int)

    for row, gold in enumerate(fact_synset):
        for slot in range(3):
            if gold[slot] == triple[slot]:
                matches_slots[row, slot] = 1
        # The count is simply the number of slots flagged for this golden triple
        matches_count[row] = matches_slots[row].sum()

    return matches_count, matches_slots

def get_slot_matches_stats(triple, sent_id, golden_extractions):
    """
        Compute slot-match statistics of a triple against all fact synsets of its sentence.

        Args
        ----
            triple: sequence
                Extracted triple [subj, rel, obj]
            sent_id:
                Key of the sentence in golden_extractions
            golden_extractions: dict
                key: sentence id; value: list of fact synsets (lists of golden triples)

        Returns
        -------
            matches_dict: dict
                'matches_counts':  list with, per fact synset, the array of match counts
                                   returned by get_slot_matches
                'matches_max':     int array with the highest match count of each fact synset
                'matches_slots':   list with, per fact synset, the per-slot match array
                                   returned by get_slot_matches
                'max_match':       highest match count over all fact synsets (int), or -1 when
                                   the sentence has no fact synsets
                'max_match_slots': array with the per-slot match pattern of every golden triple
                                   that reaches 'max_match' (repeated patterns are kept); an
                                   empty array when there is none

        Raises
        ------
            KeyError
                If sent_id is not a key of golden_extractions
    """
    fact_synsets = golden_extractions[sent_id]

    # Count the per-slot matches of the triple with every triple in every fact synset
    matches_counts = []
    matches = []
    matches_max = np.zeros(len(fact_synsets)).astype(int)
    for i, fact_synset in enumerate(fact_synsets):
        slot_matches_count, matches_slots = get_slot_matches(triple, fact_synset)
        matches_counts.append(slot_matches_count)
        matches.append(matches_slots)
        matches_max[i] = max(slot_matches_count)

    matches_dict = {}
    matches_dict['matches_counts'] = matches_counts
    matches_dict['matches_max'] = matches_max
    matches_dict['matches_slots'] = matches
    if len(matches_max) > 0:
        matches_dict['max_match'] = int(np.max(matches_max))
    else:
        # No fact synsets for this sentence: -1 marks "nothing to compare against"
        matches_dict['max_match'] = -1
    max_match = matches_dict['max_match']

    # Collect the slot patterns of all golden triples reaching the maximum match count.
    # The former "DUPLICATE!" debug output was removed: its membership test ran right after
    # the append and therefore could never fire.
    max_mismatches = []
    best_synset_indices = np.argwhere(matches_max == max_match).squeeze(axis=1)
    for ind in best_synset_indices:
        synset_slots = matches[ind]
        row_totals = np.sum(synset_slots, axis=1)
        best_rows = np.argwhere(row_totals == max_match).squeeze(axis=1).tolist()
        for slots in np.take(synset_slots, best_rows, axis=0):
            max_mismatches.append(slots.tolist())
    matches_dict['max_match_slots'] = np.array(max_mismatches)

    return matches_dict

def get_slot_matches(triple, fact_synset):
    """
        For an input triple and a fact synset, report how many and which slots of the triple
        are equal to the slots of each golden triple from the fact synset.

        NOTE: an earlier definition of get_slot_matches with the same logic exists in this
        module; being defined later, this one replaces it when the module is loaded.

        Args
        ----
            triple: sequence
                Extracted triple; positions 0, 1 and 2 are compared
            fact_synset: sequence
                Golden triples; positions 0, 1 and 2 of each are compared

        Returns
        -------
            matches_count: np.ndarray
                Integer array with the number of equal slots per golden triple
            matches_slots: np.ndarray
                Integer array of shape (len(fact_synset), 3); 1 marks an equal slot, 0 a
                different one
    """
    # One 0/1 flag per slot (subject, relation, object) for each golden triple
    slot_flags = [
        [1 if t_gold[slot] == triple[slot] else 0 for slot in range(3)]
        for t_gold in fact_synset
    ]

    # The explicit reshape keeps the (0, 3) shape when the synset is empty
    matches_slots = np.array(slot_flags, dtype=int).reshape(len(fact_synset), 3)
    matches_count = matches_slots.sum(axis=1).astype(int)

    return matches_count, matches_slots

def error_analysis(openie_extractions, golden_annotations):
    """
        Print and return the distribution of slot errors over the incorrect extractions.

        The incorrect extractions are described by the rows returned by
        get_max_slot_mismatches: one [subj, rel, obj] pattern per extraction, where 1 means
        the slot matches the closest golden triple and 0 means it does not.

        Args
        ----
            openie_extractions: iterable
                Extractions in the format [sent_id, subj, rel, obj]
            golden_annotations: dict
                key: sentence id; value: list of fact synsets (lists of golden triples)

        Returns
        -------
            errors_dict: dict
                Maps each pattern "[0, 0, 0]" ... "[1, 1, 0]" to the fraction of incorrect
                extractions showing that pattern. When there are no incorrect extractions
                every fraction is 0.0 (instead of the NaN a 0/0 division would produce).
    """
    def _ratio(count, total):
        # Share of `count` in `total`; defined as 0.0 when there is nothing to divide by
        return count / total if total != 0 else 0.0

    max_mismatch_slots = get_max_slot_mismatches(openie_extractions, golden_annotations)
    error_count = len(max_mismatch_slots)

    # Per-slot mismatches: number of errors minus the number of rows where the slot matched
    slots_match_count = np.sum(max_mismatch_slots, axis=0).astype(int)
    subj_error_count = error_count - slots_match_count[0]
    rel_error_count = error_count - slots_match_count[1]
    obj_error_count = error_count - slots_match_count[2]

    # Slot-pair mismatches
    # index 0: errors in every slot      [0, 0, 0]
    # index 1: errors in subj and rel    [0, 0, 1]
    # index 2: errors in subj and object [0, 1, 0]
    # index 3: errors in subj only       [0, 1, 1]
    # index 4: errors in rel and object  [1, 0, 0]
    # index 5: errors in rel only        [1, 0, 1]
    # index 6: errors in obj only        [1, 1, 0]
    # index 7: no errors                 [1, 1, 1] -> should always be zero
    pattern_index = {
        (subj, rel, obj): 4 * subj + 2 * rel + obj
        for subj in (0, 1)
        for rel in (0, 1)
        for obj in (0, 1)
    }

    slot_error_counts = np.zeros(8).astype(int)
    for slots in max_mismatch_slots:
        index = pattern_index.get((slots[0], slots[1], slots[2]))
        # Rows that are not a plain 0/1 pattern are not counted; the check below reports them
        if index is not None:
            slot_error_counts[index] += 1
    errors_sum = np.sum(slot_error_counts)

    # Tests
    if slot_error_counts[7] > 0:
        print("WARNING: There is a triple conidered an error which is correct")
    if errors_sum != len(max_mismatch_slots):
        print("WARNING: Mismatches and total number of errors are not equal")

    pattern_labels = [
        ("Errors in every slot      [0, 0, 0]:\t", "[0, 0, 0]"),
        ("Errors in subj and rel    [0, 0, 1]:\t", "[0, 0, 1]"),
        ("Errors in subj and object [0, 1, 0]:\t", "[0, 1, 0]"),
        ("Errors in subj only       [0, 1, 1]:\t", "[0, 1, 1]"),
        ("Errors in rel and object  [1, 0, 0]:\t", "[1, 0, 0]"),
        ("Errors in rel only        [1, 0, 1]:\t", "[1, 0, 1]"),
        ("Errors in obj only        [1, 1, 0]:\t", "[1, 1, 0]"),
    ]
    pattern_ratios = [
        _ratio(slot_error_counts[index], errors_sum) for index in range(len(pattern_labels))
    ]

    print("---------------")
    print("Errors in subject:\t" + str(_ratio(subj_error_count, error_count)))
    print("Errors in relation:\t" + str(_ratio(rel_error_count, error_count)))
    print("Errors in object:\t" + str(_ratio(obj_error_count, error_count)))
    for (label, _), ratio in zip(pattern_labels, pattern_ratios):
        print(label + str(ratio))

    errors_dict = {}
    for (_, key), ratio in zip(pattern_labels, pattern_ratios):
        errors_dict[key] = ratio

    return errors_dict