#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: source/utils/metrics.py
"""
import os
import re
import subprocess
import tempfile
import numpy as np
import torch
import torch.nn.functional as F
from collections import Counter


def accuracy(logits, targets, padding_idx=None):
    """
    Token accuracy computed per sequence, then averaged over the batch.

    logits: (batch_size, max_len, vocab_size)
    targets: (batch_size, max_len)
    padding_idx: when given, positions whose target equals this id are
        left out of both the hit count and the per-sequence length.
        A sequence made only of padding divides by zero.
    """
    # Index of the highest score along the vocabulary axis.
    hits = logits.max(dim=2)[1].eq(targets).float()

    if padding_idx is None:
        return hits.mean(dim=1).mean()

    mask = (targets != padding_idx).float()
    per_sequence = (mask * hits).sum(dim=1) / mask.sum(dim=1)
    return per_sequence.mean()


def attn_accuracy(logits, targets):
    """
    Fraction of examples whose highest-scoring class equals the target.

    logits: (batch_size, vocab_size); a size-1 axis at position 1 is
        squeezed away first, so (batch_size, 1, vocab_size) works too
    targets: (batch_size)
    """
    scores = logits.squeeze(1)
    best = scores.max(dim=-1)[1]
    return best.eq(targets).float().mean()


def perplexity(logits, targets, weight=None, padding_idx=None):
    """
    Per-sequence perplexity averaged over the batch, plus the token count.

    logits: (batch_size, max_len, vocab_size); handed straight to
        F.nll_loss, which expects log-probabilities
    targets: (batch_size, max_len)
    weight: optional per-class weight tensor forwarded to F.nll_loss
    padding_idx: optional target id excluded from the token counts; when
        no `weight` is supplied, its class weight is also set to zero so
        those positions add nothing to the loss

    Returns a tuple (ppl, tot_word_cnt):
        ppl: mean over the batch of exp(summed NLL / counted tokens)
        tot_word_cnt: number of counted tokens in the whole batch, as a
            float tensor (every position counts when padding_idx is None)
    """
    batch_size = logits.size(0)
    vocab_size = logits.size(-1)

    if weight is None and padding_idx is not None:
        # Create the weight next to the input instead of forcing CUDA, so
        # the metric also works on CPU and on non-default devices.
        weight = torch.ones(vocab_size, dtype=logits.dtype,
                            device=logits.device)
        weight[padding_idx] = 0

    nll = F.nll_loss(input=logits.reshape(-1, vocab_size),
                     target=targets.reshape(-1),
                     weight=weight,
                     reduction='none')
    nll = nll.view(batch_size, -1).sum(dim=1)

    # Positions that count as real tokens. Without a padding id every
    # position counts, which also keeps tot_word_cnt defined on that path.
    if padding_idx is not None:
        counted = targets.ne(padding_idx).float()
    else:
        counted = torch.ones_like(targets).float()
    word_cnt = counted.sum(dim=1)
    tot_word_cnt = counted.sum()

    ppl = (nll / word_cnt).exp().mean()
    return ppl, tot_word_cnt


def compute_prf(gold_entity_list, pred_sent, global_entity_list, kb_plain):
    """
    compute entity precision/recall/F1 score

    gold_entity_list: entities expected in the response
    pred_sent: predicted tokens; membership is tested with `in`
    global_entity_list: entities that count as false positives when
        predicted but not gold
    kb_plain: KB entry strings; the last whitespace-separated token of
        each entry is treated as an entity as well

    Returns (TP, FP, FN, F1, count). count is 1 when there is at least one
    gold entity and 0 otherwise; with no gold entities every value is 0.
    """
    if len(gold_entity_list) == 0:
        return 0, 0, 0, 0, 0

    # Last token of every KB entry. Empty and whitespace-only entries map
    # to "" instead of raising IndexError on an empty split.
    local_kb_word = []
    for entry in kb_plain:
        fields = entry.strip().split()
        local_kb_word.append(fields[-1] if fields else "")

    TP = sum(1 for gold in gold_entity_list if gold in pred_sent)
    FN = len(gold_entity_list) - TP

    # A predicted token is a false positive when it is a known entity
    # (global list or KB) that is not one of the gold entities.
    FP = 0
    for token in set(pred_sent):
        is_entity = token in global_entity_list or token in local_kb_word
        if is_entity and token not in gold_entity_list:
            FP += 1

    F1 = compute_f1(TP, FP, FN)
    return TP, FP, FN, F1, 1


def compute_f1(tp_count, fp_count, fn_count):
    """
    F1 score from true-positive, false-positive and false-negative counts.

    Precision, recall and the final score each fall back to 0 when their
    denominator is zero.
    """
    predicted = tp_count + fp_count
    relevant = tp_count + fn_count

    precision = 0
    if predicted != 0:
        precision = tp_count / float(predicted)

    recall = 0
    if relevant != 0:
        recall = tp_count / float(relevant)

    if (precision + recall) == 0:
        return 0
    return 2 * precision * recall / float(precision + recall)


def distinct(seqs):
    """
    Distinct-1 / distinct-2 ratios of a collection of token sequences.

    Returns (intra_dist1, intra_dist2, inter_dist1, inter_dist2):
    the intra values average the per-sequence ratio of unique n-grams to
    n-gram slots; the inter values are the same ratio taken over all
    sequences pooled together. Small epsilons keep the divisions defined
    for empty sequences.
    """
    per_seq_uni, per_seq_bi = [], []
    pooled_uni, pooled_bi = Counter(), Counter()

    for tokens in seqs:
        uni = Counter(tokens)
        bi = Counter(zip(tokens, tokens[1:]))
        pooled_uni.update(uni)
        pooled_bi.update(bi)

        n_tokens = len(tokens)
        n_pairs = max(0, n_tokens - 1)
        per_seq_uni.append((len(uni) + 1e-12) / (n_tokens + 1e-5))
        per_seq_bi.append((len(bi) + 1e-12) / (n_pairs + 1e-5))

    total_uni = sum(pooled_uni.values())
    total_bi = sum(pooled_bi.values())
    return (
        np.average(per_seq_uni),
        np.average(per_seq_bi),
        (len(pooled_uni) + 1e-12) / (total_uni + 1e-5),
        (len(pooled_bi) + 1e-12) / (total_bi + 1e-5),
    )


def calc_distinct_ngram(seqs, ngram):
    """
    Ratio of distinct n-grams to total n-grams over all sequences.

    seqs: iterable of token sequences
    ngram: n-gram length forwarded to get_dict

    Returns a float. When no n-gram is collected (no sequences, or all of
    them shorter than `ngram`) the result is 0.0 rather than a
    ZeroDivisionError.
    """
    pred_dict = {}
    for seq in seqs:
        # get_dict accumulates counts into pred_dict in place.
        get_dict(seq, ngram, pred_dict)

    ngram_total = float(sum(pred_dict.values()))
    if ngram_total == 0:
        return 0.0
    return len(pred_dict) / ngram_total


def get_dict(tokens, ngram, gdict=None):
    """
    Count the n-grams of `tokens`.

    tokens: sequence of strings
    ngram: window length
    gdict: optional dict to accumulate into; it is updated in place and
        returned. A fresh dict is used when omitted.

    Keys are the window's tokens concatenated with no separator, so
    different token splits that spell the same string share a key.
    """
    counts = {} if gdict is None else gdict
    last_start = len(tokens) - ngram
    for start in range(0, last_start + 1):
        key = "".join(tokens[start:start + ngram])
        seen = counts.get(key)
        counts[key] = 1 if seen is None else seen + 1
    return counts


def moses_multi_bleu(hypotheses, references, lowercase=False):
    """Calculate the bleu score using the MOSES multi-bleu.perl script.

    The script is looked up next to this file. Hypotheses are fed to it on
    stdin and the references are passed as a temporary file path.

    Args:
    hypotheses: A numpy array of strings where each string is a single example.
    references: A numpy array of strings where each string is a single example.
    lowercase: If true, pass the "-lc" flag to the multi-bleu script
    Returns:
    The BLEU score parsed from the script output as a float. 0.0 is
    returned when there are no hypotheses, when the script exits with a
    non-zero code, or when its output contains no BLEU line.
    """

    if np.size(hypotheses) == 0:
        return np.float32(0.0)

    # Set MOSES multi-bleu script path
    metrics_dir = os.path.dirname(os.path.realpath(__file__))
    multi_bleu_path = os.path.join(metrics_dir, "multi-bleu.perl")

    bleu_cmd = [multi_bleu_path]
    if lowercase:
        bleu_cmd.append("-lc")

    bleu_score = 0.0
    # Context managers make sure both temp files are closed (and removed)
    # even if writing or running the script raises.
    with tempfile.NamedTemporaryFile() as hypothesis_file, \
            tempfile.NamedTemporaryFile() as reference_file:
        # Dump hypotheses and references to tempfiles
        hypothesis_file.write("\n".join(hypotheses).encode("utf-8"))
        hypothesis_file.write(b"\n")
        hypothesis_file.flush()
        reference_file.write("\n".join(references).encode("utf-8"))
        reference_file.write(b"\n")
        reference_file.flush()

        bleu_cmd.append(reference_file.name)

        # Calculate BLEU using multi-bleu script
        with open(hypothesis_file.name, "r") as read_pred:
            try:
                bleu_out = subprocess.check_output(
                    bleu_cmd, stdin=read_pred, stderr=subprocess.STDOUT)
            except subprocess.CalledProcessError as error:
                if error.output is not None:
                    print("multi-bleu.perl script returned non-zero exit code")
                    print(error.output)
                    bleu_score = np.float32(0.0)
            else:
                bleu_out = bleu_out.decode("utf-8")
                match = re.search(r"BLEU = (.+?),", bleu_out)
                if match is None:
                    # Exit code was 0 but no score line was printed; report
                    # it instead of failing on a None match object.
                    print("multi-bleu.perl output did not contain a BLEU score")
                    print(bleu_out)
                    bleu_score = np.float32(0.0)
                else:
                    bleu_score = float(match.group(1))

    return bleu_score
