import math
import numpy as np
import matplotlib.pyplot as plt

def cal_trg_coverage(trg_lengths, reorder_cols):
    """Measure how much of each target is referenced by ``reorder_cols``.

    For target ``i`` with length ``trg_lengths[i]``, every value
    ``reorder_cols[i, j]`` is treated as a position inside that target and
    marked as covered. Two ratios are then derived per target:

    * rough: covered positions / target length.
    * precise: same ratio computed without the last position
      (presumably the end-of-sequence slot -- TODO confirm with the caller).
      A target of length 1 is defined to have a precise coverage of 1.0.

    Args:
        trg_lengths: iterable of integer target lengths, one per row of
            ``reorder_cols``.
        reorder_cols: 2-D array-like exposing ``.shape`` and ``[i, j]``
            indexing whose entries are valid indices into the matching target.

    Returns:
        A tuple ``(mean_precise, mean_rough, precise, rough)`` where the last
        two items are the per-target lists the means were computed from.

    Raises:
        ZeroDivisionError: if ``trg_lengths`` is empty or contains a zero.
        IndexError: if an entry of ``reorder_cols`` is out of range for its
            target.
    """
    rough_scores = []
    precise_scores = []

    for row, length in enumerate(trg_lengths):
        # One flag per target position; 1 means "referenced at least once".
        covered = [0] * length
        for col in range(reorder_cols.shape[1]):
            covered[reorder_cols[row, col]] = 1

        rough_scores.append(sum(covered) / length)
        # Dropping the last slot would leave nothing to measure for a
        # single-position target, so that case counts as fully covered.
        precise_scores.append(
            1.0 if length == 1 else sum(covered[:-1]) / (length - 1)
        )

    mean_precise = sum(precise_scores) / len(precise_scores)
    mean_rough = sum(rough_scores) / len(rough_scores)
    return mean_precise, mean_rough, precise_scores, rough_scores


def dataset_statistics():
    """Plot hard-coded per-dataset statistics and save them as PNG files.

    Two images are written via ``plt.savefig`` using relative paths:

    * ``dataset_numbers.png``: stacked bars of the "Pre" and "Ab" numbers
      (presumably present / absent keyphrase counts -- TODO confirm).
    * ``dataset_doc_lens.png``: bars of the "Doc" lengths.

    Drawing happens on pyplot's current figure, which is cleared between
    the two charts.
    """
    names = ["Inspec", "Nus", "Krapivin", "Semeval", "KP20K"]
    pre_values = [7.23, 6.34, 3.26, 6.25, 3.24]
    ab_values = [2.59, 5.31, 2.59, 8.41, 2.84]
    doc_lengths = [134.10, 230.13, 189.32, 245.89, 179.02]

    def finish_chart(y_label, out_path):
        # Shared tail of both charts: legend, axis labels, then write to disk.
        plt.legend()
        plt.xlabel("Dataset")
        plt.ylabel(y_label)
        plt.savefig(out_path)

    # Stacked bar chart: dataset names on the x axis, "Ab" values drawn on
    # top of the "Pre" values.
    plt.bar(names, pre_values, label="Pre", fc="g")
    plt.bar(names, ab_values, bottom=pre_values, label="Ab", fc="r")
    finish_chart("Number", "dataset_numbers.png")

    # Wipe the figure (bars and legend) so the next chart starts empty.
    plt.clf()

    # Plain bar chart: dataset names on the x axis, document lengths on y.
    plt.bar(names, doc_lengths, label="Doc", fc="g")
    finish_chart("Length", "dataset_doc_lens.png")


class LossStatistics:
    """
    Accumulator for loss statistics. Modified from OpenNMT.

    Holds running sums of the loss, token and batch counts, three timing
    buckets and four target-coverage totals. Instances can be merged with
    `update`, queried through the helper methods and reset with `clear`.
    """

    def __init__(
        self,
        loss=0.0,
        n_tokens=0,
        n_batch=0,
        forward_time=0.0,
        loss_compute_time=0.0,
        backward_time=0.0,
        pre_precise_trg_coverage=0.0,
        pre_rough_trg_coverage=0.0,
        ab_precise_trg_coverage=0.0,
        ab_rough_trg_coverage=0.0,
    ):
        """
        Create an accumulator seeded with the given values.

        Args:
            loss: summed loss value.
            n_tokens: number of tokens the loss was summed over.
            n_batch: number of batches represented by this object.
            forward_time: accumulated forward-pass time.
            loss_compute_time: accumulated loss-computation time.
            backward_time: accumulated backward-pass time.
            pre_precise_trg_coverage: summed "pre" precise target coverage.
            pre_rough_trg_coverage: summed "pre" rough target coverage.
            ab_precise_trg_coverage: summed "ab" precise target coverage.
            ab_rough_trg_coverage: summed "ab" rough target coverage.

        Raises:
            ValueError: if `loss` is NaN.
        """
        # Validate first so no attribute is ever set from a NaN loss.
        if math.isnan(loss):
            raise ValueError("Loss is NaN")
        self.loss = loss
        self.n_tokens = n_tokens
        self.n_batch = n_batch
        self.forward_time = forward_time
        self.loss_compute_time = loss_compute_time
        self.backward_time = backward_time
        self.pre_precise_trg_coverage = pre_precise_trg_coverage
        self.pre_rough_trg_coverage = pre_rough_trg_coverage
        self.ab_precise_trg_coverage = ab_precise_trg_coverage
        self.ab_rough_trg_coverage = ab_rough_trg_coverage

    def update(self, stat):
        """
        Update statistics by summing values with another `LossStatistics` object

        Args:
            stat: another statistic object

        Raises:
            ValueError: if `stat.loss` is NaN; in that case this object is
                left untouched.
        """
        # Check before mutating: adding a NaN first would poison the running
        # total even when the caller catches the error and skips the batch.
        if math.isnan(stat.loss):
            raise ValueError("Loss is NaN")
        self.loss += stat.loss
        self.n_tokens += stat.n_tokens
        self.n_batch += stat.n_batch
        self.forward_time += stat.forward_time
        self.loss_compute_time += stat.loss_compute_time
        self.backward_time += stat.backward_time
        self.pre_precise_trg_coverage += stat.pre_precise_trg_coverage
        self.pre_rough_trg_coverage += stat.pre_rough_trg_coverage
        self.ab_precise_trg_coverage += stat.ab_precise_trg_coverage
        self.ab_rough_trg_coverage += stat.ab_rough_trg_coverage

    def xent(self):
        """ compute normalized cross entropy (loss per token) """
        assert self.n_tokens > 0, "n_tokens must be larger than 0"
        return self.loss / self.n_tokens

    def ppl(self):
        """ compute normalized perplexity """
        assert self.n_tokens > 0, "n_tokens must be larger than 0"
        # The exponent is capped at 100 to keep math.exp from overflowing.
        return math.exp(min(self.loss / self.n_tokens, 100))

    def total_time(self):
        """ return (forward_time, loss_compute_time, backward_time) """
        return self.forward_time, self.loss_compute_time, self.backward_time

    def pre_trg_coverage(self):
        """
        Return the per-batch mean "pre" coverage as (precise, rough).

        Raises:
            ZeroDivisionError: if `n_batch` is 0.
        """
        return self.pre_precise_trg_coverage / self.n_batch, self.pre_rough_trg_coverage / self.n_batch

    def ab_trg_coverage(self):
        """
        Return the per-batch mean "ab" coverage as (precise, rough).

        Raises:
            ZeroDivisionError: if `n_batch` is 0.
        """
        return self.ab_precise_trg_coverage / self.n_batch, self.ab_rough_trg_coverage / self.n_batch

    def clear(self):
        """ reset every accumulated value to zero """
        self.loss = 0.0
        self.n_tokens = 0
        self.n_batch = 0
        self.forward_time = 0.0
        self.loss_compute_time = 0.0
        self.backward_time = 0.0
        self.pre_precise_trg_coverage = 0.0
        self.pre_rough_trg_coverage = 0.0
        self.ab_precise_trg_coverage = 0.0
        self.ab_rough_trg_coverage = 0.0

if __name__ == "__main__":
    # Disabled manual example for cal_trg_coverage (two targets of length
    # 5 and 4); uncomment to print its result:
    # trg_lengths = [5, 4]
    # reorder_cols = np.array([[1, 4, 1, 3, 2, 0, 4, 0, 4, 4],
    #                          [1, 3, 1, 0, 3, 3, 3, 1, 3, 0]])
    # print(cal_trg_coverage(trg_lengths, reorder_cols))
    # Running this file as a script writes the dataset statistics plots.
    dataset_statistics()
