#
 #     MILIE: Modular & Iterative Multilingual Open Information Extraction
 #
 #
 #
 #     Authors: Deleted for purposes of anonymity
 #
 #     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 #
 # The software and its source code contain valuable trade secrets and shall be maintained in
 # confidence and treated as confidential information. The software may only be used for
 # evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 # license agreement or nondisclosure agreement with the proprietor of the software.
 # Any unauthorized publication, transfer to third parties, or duplication of the object or
 # source code---either totally or in part---is strictly prohibited.
 #
 #     Copyright (c) 2021 Proprietor: Deleted for purposes of anonymity
 #     All Rights Reserved.
 #
 # THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY
 # AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT
 # DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION.
 #
 # NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 # IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE
 # LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 # FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 # OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 # ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 # TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 # THE POSSIBILITY OF SUCH DAMAGES.
 #
 # For purposes of anonymity, the identity of the proprietor is not given herewith.
 # The identity of the proprietor will be given once the review of the
 # conference submission is completed.
 #
 # THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 #

import codecs
import json
import numpy as np

import matplotlib



def write_list_to_file(list_to_write, file_to_write):
    """
    Dump the entries of a list into a UTF-8 text file, one entry per line.

    ``None`` entries are skipped. Nothing is written (and the file is not
    even opened) when the list itself is ``None``, is empty, or consists
    only of ``None`` entries.

    :param list_to_write: the list whose entries should be written
    :param file_to_write: path of the file to write to
    :return: 0 on success, 1 if there was nothing to write
    """
    if list_to_write is None:
        return 1
    assert isinstance(list_to_write, list), list_to_write

    # An empty list also satisfies this check (0 == 0).
    total_entries = len(list_to_write)
    if total_entries == list_to_write.count(None):
        return 1

    with codecs.open(file_to_write, 'w', encoding='utf8') as out:
        for entry in list_to_write:
            if entry is not None:
                out.write(str(entry) + '\n')
    return 0


def write_json_to_file(json_object, file_to_write):
    """
    Serialize a JSON-compatible object into a file, pretty-printed with an
    indentation of 4 and followed by a trailing newline.

    :param json_object: the object to serialize; ``None`` is rejected
    :param file_to_write: path of the file to write to
    :return: 0 on success, 1 if ``json_object`` is None
    """
    if json_object is None:
        return 1

    with open(file_to_write, "w") as writer:
        serialized = json.dumps(json_object, indent=4)
        writer.write(serialized + "\n")
    return 0


def write_list_to_jsonl_file(list_to_write, file_to_write):
    """
    Write a list to a UTF-8 file in JSON Lines format: every entry is
    serialized with ``json.dumps`` and placed on its own line.

    :param list_to_write: the entries to serialize, one per output line
    :param file_to_write: path of the file to write to
    :return: 0 on success, 1 if ``list_to_write`` is None
    """
    if list_to_write is None:
        return 1

    with codecs.open(file_to_write, 'w', encoding='utf8') as out:
        for record in list_to_write:
            encoded = json.dumps(record)
            out.write(encoded + "\n")
    return 0


def write_conll(list_to_write, file_to_write):
    """
    Save sentences in CoNLL format: one token per line with its columns
    separated by tabs, and an empty line after every sentence.

    Expected input, for example:
    list_to_write = [
        [('I', 'O'), ('live', 'O'), ('in', 'O'), ('Heidel', 'B'), ('##berg', 'I')],
        [('Heidel', 'B'), ('##berg', 'I'), ('is', 'O'), ('in', 'O'), ('Germany', 'B')]
    ]

    ``None`` sentences are skipped. Nothing is written when the list is
    ``None``, empty, or contains only ``None`` sentences.

    :param list_to_write: list of sentences, each a list of string tuples
    :param file_to_write: str, path to save
    :return: 0 on success, 1 if there was nothing to write
    """
    if list_to_write is None:
        return 1
    assert isinstance(list_to_write, list), list_to_write

    # An empty list also satisfies this check (0 == 0).
    sentence_count = len(list_to_write)
    if sentence_count == list_to_write.count(None):
        return 1

    with codecs.open(file_to_write, 'w', encoding='utf8') as out:
        for sentence in list_to_write:
            if sentence is not None:
                for columns in sentence:
                    row = '\t'.join(columns)
                    out.write(row + '\n')
                # blank line marks the sentence boundary
                out.write('\n')
    return 0


def read_lines_in_list(file_to_read):
    """
    Load a UTF-8 text file as a list of its lines.

    Only trailing ``'\\n'`` characters are removed from each line; any other
    whitespace is kept.

    :param file_to_read: path of the file to be read
    :return: a list with one string per line of the file
    """
    with codecs.open(file_to_read, 'r', encoding='utf8') as handle:
        return [raw_line.rstrip('\n') for raw_line in handle]


def read_json(json_to_read):
    """
    Read a json file.

    :param json_to_read: path of the json file to read
    :return: the deserialized json object
    """
    with open(json_to_read, "r", encoding="utf8") as reader:
        # json.load parses from a file object; json.loads only accepts
        # str/bytes and would raise a TypeError when handed the reader.
        json_object = json.load(reader)
    return json_object


def read_conll(conll_to_read):
    """
    Read a file in conll format.

    Every non-empty line is one token whose columns are separated by tabs;
    an empty line ends the current sentence. A final sentence that is not
    followed by an empty line is returned as well.

    for example:
    conll = [
        [('I', "O"), ('live', 'O'), ('in', 'O'), ('Heidel', 'B'), ('##berg', 'I')],
        [('Heidel', 'B'), ('##berg', 'I'), ('is', 'O'), ('in', 'O'), ('Germany', 'B')]
    ]

    :param conll_to_read: str, path to file
    :return: list of sentences, each sentence being a list of tuples
    """
    conll = []
    sent = []
    with codecs.open(conll_to_read, 'r', encoding='utf8') as file:
        for line in file:
            # Also strip '\r' so that files with Windows line endings do not
            # leak it into the last column or hide sentence boundaries.
            line = line.rstrip('\r\n')
            if line:
                sent.append(tuple(line.split('\t')))
            elif sent:
                conll.append(sent)
                sent = []

    # Flush the last sentence when the file does not end with a blank line.
    if sent:
        conll.append(sent)

    return conll


def sublist_start_index(search_for_this, find_here):
    """
    Locate the first position at which one list occurs, contiguously and in
    order, inside another list.

    :param search_for_this: the list to look for
    :param find_here: the list that is searched
    :return: the start index of the first occurrence in find_here, or None
        if search_for_this is empty, longer than find_here, or not contained
    """
    needle_len = len(search_for_this)
    if needle_len <= 0 or needle_len > len(find_here):
        return None

    first_item = search_for_this[0]
    if first_item not in find_here:
        return None

    # No match can start before the first occurrence of the first item, nor
    # so late that the remaining tail is shorter than the searched list.
    earliest_start = find_here.index(first_item)
    latest_start = len(find_here) - needle_len
    candidate_starts = range(earliest_start, latest_start + 1)
    return next(
        (start for start in candidate_starts
         if find_here[start:start + needle_len] == search_for_this),
        None,
    )


def compute_softmax(scores, alpha=1.0):
    """
    Computes softmax probability over raw logits.

    The normalization runs over all entries of ``scores`` (no axis is
    selected), so the returned probabilities sum to 1 in total.

    :param scores: a numpy array (or array-like, e.g. a list) with logits
    :param alpha: multiplier applied to the logits before the softmax,
        values >1.0 approach argmax, values <1.0 approach uniform
    :return: a numpy array with probability scores
    """
    # Accept plain Python sequences too; ndarray input is passed through as is.
    scores = np.asarray(scores) * float(alpha)
    # Shift by the maximum so that np.exp cannot overflow; the result of the
    # softmax is unaffected by a constant shift.
    scores = scores - np.max(scores)
    scores = np.exp(scores)
    probs = scores / np.sum(scores)

    return probs


#https://github.com/huggingface/transformers/blob/f88c104d8f79e78a98c8ce6c1f4a78db73142eab/transformers/tokenization_utils.py#L882
def truncate_sequences(part_a, part_b, num_tokens_to_remove=0, truncation_strategy='longest_first', truncate_end=True):
    """Truncates a sequence pair by removing num_tokens_to_remove tokens.

    The input lists are not modified; shortened copies are returned. If
    num_tokens_to_remove is zero or negative, both lists are returned
    unchanged regardless of the chosen strategy.

    :param part_a: list, the first sequence
    :param part_b: list, the second sequence
    :param num_tokens_to_remove: total number of tokens to remove
    :param truncation_strategy: string selected in the following options:

            - 'longest_first' (default) Remove one token at a time, each time
              from whichever sequence is currently longer (from the second
              sequence when both have the same length).
            - 'only_first': Only truncate the first sequence. Fails an assertion
              if the first sequence is shorter than or equal to num_tokens_to_remove.
            - 'only_second': Only truncate the second sequence. Fails an assertion
              if the second sequence is shorter than or equal to num_tokens_to_remove.
            - 'do_not_truncate': Does not truncate (raises an error if tokens
              would have to be removed)
    :param truncate_end: if True tokens are removed from the end of a
        sequence, otherwise from its beginning
    :return: tuple (part_a, part_b) with the truncated sequences
    :raises ValueError: if truncation is required but the strategy is
        'do_not_truncate' or not one of the known strategies
    """
    assert isinstance(part_a, list) and isinstance(part_b, list)

    if num_tokens_to_remove <= 0:
        # Same 2-tuple shape as every other return path, so that callers can
        # always unpack the result the same way.
        return part_a, part_b

    if truncation_strategy == 'longest_first':
        for _ in range(num_tokens_to_remove):
            if len(part_a) > len(part_b):
                part_a = part_a[:-1] if truncate_end else part_a[1:]
            else:
                part_b = part_b[:-1] if truncate_end else part_b[1:]
    elif truncation_strategy == 'only_first':
        assert len(part_a) > num_tokens_to_remove
        part_a = part_a[:-num_tokens_to_remove] if truncate_end else part_a[num_tokens_to_remove:]
    elif truncation_strategy == 'only_second':
        assert len(part_b) > num_tokens_to_remove
        part_b = part_b[:-num_tokens_to_remove] if truncate_end else part_b[num_tokens_to_remove:]
    elif truncation_strategy == 'do_not_truncate':
        raise ValueError("Input sequence are too long for max_length. Please select a truncation strategy.")
    else:
        raise ValueError("Truncation_strategy should be selected in ['longest_first', 'only_first', 'only_second', 'do_not_truncate']")
    return part_a, part_b


#https://github.com/joeynmt/joeynmt/blob/master/joeynmt/plotting.py
def plot_attention(scores, column_labels, row_labels, output_path=None):
    """
    Visualize self-attention heatmap per head.

    The heads are drawn on a fixed grid of 3 rows and 4 columns, therefore
    ``scores`` has to contain exactly 12 heads. Note that this function
    selects the 'Agg' backend and changes the global matplotlib rcParams
    (font and tick label size settings).

    :param scores: (np.ndarray) attention scores with three dimensions,
        the first one indexing the heads
    :param column_labels: labels for columns (e.g. target tokens)
    :param row_labels: labels for rows (e.g. source tokens)
    :param output_path: path to save to; nothing is saved if None
    :return: pyplot figure (it has already been closed in pyplot)
    """
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib import rcParams
    rcParams['font.family'] = "sans-serif"
    rcParams['font.sans-serif'] = ['IPAexGothic', 'IPAPGothic', 'IPAGothic']
    rcParams['font.weight'] = "regular"
    x_sent_len = len(column_labels)
    y_sent_len = len(row_labels)

    # Unpacking also enforces that scores has exactly three dimensions.
    num_heads, _, _ = scores.shape
    num_rows = 3
    num_cols = 4
    assert num_heads == num_rows*num_cols, scores.shape

    # Size of the region that is actually plotted for every head.
    plot_height, plot_width = scores[0, :y_sent_len, :x_sent_len].shape

    fig, axs = plt.subplots(nrows=num_rows, ncols=num_cols, figsize=(num_cols*3, num_rows*3))
    # axs.flat walks the grid row by row, so n is the head index.
    for n, ax in enumerate(axs.flat):
        nth_scores = scores[n, :y_sent_len, :x_sent_len]
        # check that cut off part didn't have any attention
        assert np.sum(scores[n, y_sent_len:, x_sent_len:]) == 0

        ax.imshow(nth_scores, cmap='viridis', aspect='equal', origin='upper', vmin=0., vmax=1.)
        ax.set_title('head {}'.format(n+1))

    # automatic label size
    labelsize = max(min(25 * (10 / max(x_sent_len, y_sent_len)), 10.5), 5)
    # font config
    rcParams['xtick.labelsize'] = labelsize
    rcParams['ytick.labelsize'] = labelsize

    x_ticks = np.arange(plot_width)
    y_ticks = np.arange(plot_height)
    # One label per tick; surplus labels had no tick to attach to anyway.
    x_tick_labels = list(column_labels)[:plot_width]
    y_tick_labels = list(row_labels)[:plot_height]

    for ax in axs.flat:
        # Fix the tick positions first: tick labels may only be set once
        # the tick locations are fixed, otherwise matplotlib warns and the
        # labels can end up at unexpected positions.
        ax.set_xticks(x_ticks, minor=False)
        ax.set_yticks(y_ticks, minor=False)
        # labels
        ax.set_xticklabels(x_tick_labels, minor=False, rotation="vertical")
        ax.set_yticklabels(y_tick_labels, minor=False)

    # Hide x labels and tick labels for top plots and y ticks for right plots.
    for ax in axs.flat:
        ax.label_outer()
    fig.tight_layout()

    if output_path is not None:
        fig.savefig(output_path)

    plt.close(fig)

    return fig
