# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE

import os
import random
import logging
from enum import Enum
import torch
import numpy as np
from transformers import XLMRobertaTokenizer, XLMRobertaModel
from xlm_ra import XNLUModel, XPairModel, get_intent_labels

# Task key -> (XLMRobertaModel, task-specific model class).
# xnlu / mtop / m_atis use XNLUModel, paws_x uses XPairModel. The "base_" keys
# map to the same classes as their large counterparts; in MODEL_PATH_MAP they
# point at the base rather than the large checkpoint.
MODEL_CLASSES = dict(
    xnlu=(XLMRobertaModel, XNLUModel),
    base_xnlu=(XLMRobertaModel, XNLUModel),
    mtop=(XLMRobertaModel, XNLUModel),
    base_mtop=(XLMRobertaModel, XNLUModel),
    m_atis=(XLMRobertaModel, XNLUModel),
    base_m_atis=(XLMRobertaModel, XNLUModel),
    paws_x=(XLMRobertaModel, XPairModel),
    base_paws_x=(XLMRobertaModel, XPairModel),
)

# Task key -> pretrained checkpoint directory. Paths are relative, so they are
# resolved against the current working directory at load time.
# "base_" keys use xlm-roberta-base; all other keys use xlm-roberta-large.
MODEL_PATH_MAP = dict(
    xnlu='../xlm-roberta-large/',
    base_xnlu='../xlm-roberta-base/',
    mtop='../xlm-roberta-large/',
    base_mtop='../xlm-roberta-base/',
    m_atis='../xlm-roberta-large/',
    base_m_atis='../xlm-roberta-base/',
    paws_x='../xlm-roberta-large/',
    base_paws_x='../xlm-roberta-base/',
)


class Tasks(Enum):
    """Supported task identifiers.

    Each member's value is the lowercase task key, the same string used as a
    (non-"base_") key in MODEL_CLASSES and MODEL_PATH_MAP.
    """

    XNLU = 'xnlu'      # task key 'xnlu'
    MTOP = 'mtop'      # task key 'mtop'
    PAWS_X = 'paws_x'  # task key 'paws_x'
    M_ATIS = 'm_atis'  # task key 'm_atis'


def load_tokenizer(model_name_or_path):
    """Return an ``XLMRobertaTokenizer`` loaded from ``model_name_or_path``.

    Thin wrapper around ``XLMRobertaTokenizer.from_pretrained``; the argument is
    passed through unchanged.
    """
    tokenizer = XLMRobertaTokenizer.from_pretrained(model_name_or_path)
    return tokenizer


def init_logger(args):
    """Configure root logging to write to ``<args.model_dir>/log.log`` and a stream.

    Creates ``args.model_dir`` if needed, including any missing parent
    directories. Logs go both to the file (opened in append mode, UTF-8) and to
    a ``logging.StreamHandler`` at INFO level.

    Note: ``logging.basicConfig`` does nothing if the root logger already has
    handlers, so normally only the first call in a process takes effect.
    """
    # makedirs(exist_ok=True) creates missing parents and tolerates the directory
    # already existing (or appearing concurrently); the previous exists()/mkdir()
    # pair raised FileNotFoundError for nested paths.
    os.makedirs(args.model_dir, exist_ok=True)
    log_path = os.path.join(args.model_dir, "log.log")
    logging.basicConfig(handlers=[logging.FileHandler(log_path, "a+", "utf-8"), logging.StreamHandler()],
                        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)


def set_seed(args):
    """Seed Python's ``random``, NumPy and PyTorch with ``args.seed``.

    All CUDA devices are seeded as well when ``args.cuda_device`` is not
    ``"cpu"`` and CUDA is available.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    wants_gpu = args.cuda_device != "cpu" and torch.cuda.is_available()
    if wants_gpu:
        torch.cuda.manual_seed_all(seed)


def compute_metrics(intent_predictions, intent_labels, slot_predictions, slot_labels, examples, guids, args):
    """Compute intent accuracy, slot precision/recall/F1 and joint accuracy.

    Args:
        intent_predictions: Predicted intents, one per sentence (list or array).
        intent_labels: Gold intents aligned with ``intent_predictions``.
        slot_predictions: Per-sentence sequences of predicted slot tags.
        slot_labels: Per-sentence sequences of gold slot tags.
        examples, guids, args: Forwarded unchanged to ``get_joint_acc``.

    Returns:
        dict with "Intent_Acc", the entries returned by ``get_slot_metrics``
        and the entries returned by ``get_joint_acc``.
    """
    assert len(intent_predictions) == len(intent_labels) == len(slot_predictions) == len(slot_labels)
    # np.asarray makes the comparison element-wise for plain Python lists too;
    # with two lists, ``==`` would return a single bool with no ``.mean()``.
    intent_hits = np.asarray(intent_predictions) == np.asarray(intent_labels)
    results = {"Intent_Acc": intent_hits.mean()}
    results.update(get_slot_metrics(slot_predictions, slot_labels))
    results.update(get_joint_acc(intent_predictions, intent_labels, slot_predictions, slot_labels, examples, guids, args))
    return results


def get_slot_metrics(predictions, labels):
    """Return entity-level slot precision, recall and F1.

    ``labels`` is passed as the gold sequence (``y_true``) and ``predictions``
    as the predicted one (``y_pred``) to each scorer.
    """
    assert len(predictions) == len(labels)
    scorers = (
        ("Slot_Precision", precision_score),
        ("Slot_Recall", recall_score),
        ("Slot_F1", f1_score),
    )
    return {metric: scorer(labels, predictions) for metric, scorer in scorers}


def get_joint_acc(intent_predictions, intent_labels, slot_predictions, slot_labels, examples, guids, args):
    """Return the fraction of sentences whose intent and all slot tags are correct.

    Args:
        intent_predictions, intent_labels: Predicted / gold intent per sentence.
        slot_predictions, slot_labels: Per-sentence sequences of predicted / gold
            slot tags; each aligned pair must have equal length.
        examples, args: Not used; kept for interface compatibility.
        guids: Not read, but still zipped with the other inputs, so iteration
            stops at the shortest input exactly as before.

    Returns:
        ``{"Joint_Accuracy": mean of the per-sentence correctness flags}``.
        With no sentences this is NaN, and NumPy emits a RuntimeWarning.
    """
    joint_result = []
    for s_preds, s_labels, i_pred, i_label, _guid in zip(slot_predictions, slot_labels, intent_predictions, intent_labels, guids):
        assert len(s_preds) == len(s_labels)
        # A sentence counts only if the intent matches and every slot tag matches.
        joint_result.append(i_pred == i_label and all(p == l for p, l in zip(s_preds, s_labels)))

    return {"Joint_Accuracy": np.array(joint_result).mean()}

# ------------- Metrics below copied from 'seqeval' pip package -----------------


def f1_score(y_true, y_pred, suffix=False):
    """Return the entity-level F1 score (harmonic mean of precision and recall).

    Precision and recall are each taken as 0 when their denominator is empty,
    and the score is 0 when both are 0.
    """
    gold = set(get_entities(y_true, suffix))
    guessed = set(get_entities(y_pred, suffix))
    hits = len(gold & guessed)

    precision = hits / len(guessed) if guessed else 0
    recall = hits / len(gold) if gold else 0
    if precision + recall > 0:
        return 2 * precision * recall / (precision + recall)
    return 0


def precision_score(y_true, y_pred, suffix=False):
    """Return entity-level precision: correctly predicted entities / predicted entities.

    Returns 0 when no entity is predicted.
    """
    gold = set(get_entities(y_true, suffix))
    guessed = set(get_entities(y_pred, suffix))
    if not guessed:
        return 0
    return len(gold & guessed) / len(guessed)


def recall_score(y_true, y_pred, suffix=False):
    """Return entity-level recall: correctly predicted entities / gold entities.

    Returns 0 when there are no gold entities.
    """
    gold = set(get_entities(y_true, suffix))
    guessed = set(get_entities(y_pred, suffix))
    if not gold:
        return 0
    return len(gold & guessed) / len(gold)


def get_entities(seq, suffix=False):
    """Extract labelled chunks from a sequence of IOB/IOBES-style tags.

    Args:
        seq: Iterable of tag strings such as ``'B-PER'`` (or ``'PER-B'`` when
            ``suffix`` is True), or a list of such lists. Nested lists are
            flattened with an ``'O'`` appended after each sublist, so a chunk
            never spans two sublists.
        suffix: If True, the tag letter is the last character (``'PER-B'``);
            otherwise it is the first character (``'B-PER'``).

    Returns:
        List of ``(type, start, end)`` tuples, where ``start`` and ``end`` are
        inclusive offsets into the (flattened) sequence. A tag with no type
        part yields type ``'_'``.
    """
    # Materialize once so tuples, generators and arrays are accepted: the
    # nested-list check below would consume an iterator, and ``+ ['O']``
    # requires a list.
    seq = list(seq)

    # Flatten nested lists, separating sublists with an 'O'.
    if any(isinstance(s, list) for s in seq):
        seq = [item for sublist in seq for item in sublist + ['O']]

    prev_tag = 'O'
    prev_type = ''
    begin_offset = 0
    chunks = []
    # The trailing 'O' sentinel flushes a chunk that runs to the end of seq.
    for i, chunk in enumerate(seq + ['O']):

        if suffix:
            tag = chunk[-1]
            type_ = chunk[:-1].rsplit('-', maxsplit=1)[0] or '_'
        else:
            tag = chunk[0]
            type_ = chunk[1:].split('-', maxsplit=1)[-1] or '_'

        # End check must use the previous tag/type before they are updated.
        if end_of_chunk(prev_tag, tag, prev_type, type_):
            chunks.append((prev_type, begin_offset, i-1))
        if start_of_chunk(prev_tag, tag, prev_type, type_):
            begin_offset = i
        prev_tag = tag
        prev_type = type_

    return chunks


def end_of_chunk(prev_tag, tag, prev_type, type_):
    """Return True if a chunk ended between the previous and the current word.

    A chunk ends when:
      * the previous tag closes a chunk by itself (``E`` or ``S``);
      * an open ``B``/``I`` chunk is followed by ``B``, ``S`` or ``O``;
      * the entity type changes after any tag other than ``O`` or ``.``.
    """
    if prev_tag in ('E', 'S'):
        return True
    if prev_tag in ('B', 'I') and tag in ('B', 'S', 'O'):
        return True
    return prev_tag not in ('O', '.') and prev_type != type_


def start_of_chunk(prev_tag, tag, prev_type, type_):
    """Return True if a chunk started between the previous and the current word.

    A chunk starts when:
      * the current tag opens one by itself (``B`` or ``S``);
      * an ``E`` or ``I`` follows a closed position (``E``, ``S`` or ``O``);
      * the entity type changes at any current tag other than ``O`` or ``.``.
    """
    if tag in ('B', 'S'):
        return True
    if prev_tag in ('E', 'S', 'O') and tag in ('E', 'I'):
        return True
    return tag not in ('O', '.') and prev_type != type_
