import logging
import os
import time
import json
import torch
import pykp.utils.io as io
from pykp.utils.masked_loss import masked_cross_entropy
from utils.statistics import LossStatistics
from utils.string_helper import *
from utils.functions import time_since
from pykp.utils.label_assign import hungarian_assign, optimal_transport_assign
import matplotlib.pyplot as plt
from tqdm import tqdm

def evaluate_loss(data_loader, model, opt):
    """Compute the teacher-forced evaluation loss of ``model`` over ``data_loader``.

    The model is put in eval mode and every batch is processed under
    ``torch.no_grad()``. Two decoding layouts are supported:

    * ``opt.fix_kp_num_len`` is truthy: the decoder predicts ``opt.max_kp_num``
      keyphrase slots. When ``opt.set_loss`` is also truthy, the targets (and
      their masks) are first re-assigned to slots by greedily decoding
      ``opt.assign_steps`` tokens and matching them with either
      ``optimal_transport_assign`` or ``hungarian_assign``. With
      ``opt.seperate_pre_ab`` the first ``max_kp_num // 2`` slots and the
      remaining slots are matched and scored separately.
    * otherwise: a single target sequence per example is scored.

    Args:
        data_loader: iterable yielding 12-element batches (unpacked below).
        model: object exposing ``eval()``, ``encoder(...)`` and ``decoder``.
        opt: options object; reads ``vocab``, ``vocab_size``, ``copy_attention``,
            ``fix_kp_num_len``, ``set_loss``, ``seperate_pre_ab``,
            ``use_optimal_transport``, ``max_kp_num``, ``max_kp_len``,
            ``assign_steps``, ``device``, ``k_strategy``, ``top_candidates``,
            ``assign_temperature``, ``loss_matching_scale`` and the
            ``loss_scale*`` values, depending on the branch taken.

    Returns:
        A ``LossStatistics`` built from the summed loss, the summed target
        mask, the number of examples seen, and the accumulated forward and
        loss-computation times.
    """
    model.eval()
    evaluation_loss_sum = 0.0
    total_trg_tokens = 0
    n_batch = 0  # accumulates batch sizes, i.e. counts examples
    loss_compute_time_total = 0.0
    forward_time_total = 0.0

    with torch.no_grad():
        for batch in data_loader:
            src, src_lens, src_mask, src_oov, oov_lists, src_str_list, \
            trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _, = batch

            max_num_oov = max(len(oov) for oov in oov_lists)  # max number of oov for each batch
            batch_size = src.size(0)
            n_batch += batch_size
            word2idx = opt.vocab['word2idx']
            target = trg_oov if opt.copy_attention else trg

            # Only the optimal-transport assignment produces matching scales.
            # Reset them for every batch so the loss code below never reads an
            # unbound (or stale) value on the other assignment paths.
            pre_matching_scale = None
            ab_matching_scale = None

            start_time = time.time()
            if opt.fix_kp_num_len:
                memory_bank = model.encoder(src, src_lens, src_mask)
                state = model.decoder.init_state(memory_bank, src_mask)
                control_embed = model.decoder.forward_seg(state)

                y_t_init = target.new_ones(batch_size, opt.max_kp_num, 1) * word2idx[io.BOS_WORD]
                if opt.set_loss:  # reassign target
                    # Greedily decode `assign_steps` tokens per slot; the resulting
                    # distributions drive the slot <-> target matching.
                    input_tokens = src.new_zeros(batch_size, opt.max_kp_num, opt.assign_steps + 1)
                    decoder_dists = []
                    input_tokens[:, :, 0] = word2idx[io.BOS_WORD]
                    for t in range(1, opt.assign_steps + 1):
                        decoder_inputs = input_tokens[:, :, :t]
                        # Ids beyond the vocabulary are fed back as UNK.
                        decoder_inputs = decoder_inputs.masked_fill(decoder_inputs.gt(opt.vocab_size - 1),
                                                                    word2idx[io.UNK_WORD])

                        decoder_dist, _ = model.decoder(decoder_inputs, state, src_oov, max_num_oov, control_embed)
                        input_tokens[:, :, t] = decoder_dist.argmax(-1)
                        decoder_dists.append(decoder_dist.reshape(batch_size, opt.max_kp_num, 1, -1))
                    # (batch_size, max_kp_num, assign_steps, -1)
                    decoder_dists = torch.cat(decoder_dists, -2)

                    if opt.seperate_pre_ab:
                        mid_idx = opt.max_kp_num // 2

                        if opt.use_optimal_transport:
                            # A "null" keyphrase: NULL_WORD followed by padding, with only
                            # the first position unmasked.
                            background = torch.tensor([word2idx[io.NULL_WORD]] +
                                                      [word2idx[io.PAD_WORD]] * (opt.max_kp_len - 1)).to(opt.device)
                            bg_mask = torch.tensor([1] + [0] * (opt.max_kp_len - 1)).to(opt.device)

                            pre_targets, pre_trg_masks, ab_targets, ab_trg_masks = [], [], [], []
                            pre_has_null, ab_has_null = [False] * batch_size, [False] * batch_size

                            for b in range(batch_size):
                                pre_target, pre_trg_mask, ab_target, ab_trg_mask = [], [], [], []

                                for t in range(mid_idx):
                                    # Drop null keyphrases; keep only the real target keyphrases.
                                    if any(target[b, t] != background):
                                        pre_target.append(list(target[b, t]))
                                        pre_trg_mask.append(list(trg_mask[b, t]))

                                if len(pre_target) != opt.max_kp_num // 2:
                                    pre_has_null[b] = True
                                    # Append a single null keyphrase as a learning target.
                                    pre_target.append(list(background))
                                    pre_trg_mask.append(list(bg_mask))

                                pre_targets.append(torch.tensor(pre_target).to(opt.device))
                                pre_trg_masks.append(torch.tensor(pre_trg_mask).to(opt.device))

                                for t in range(mid_idx, opt.max_kp_num):
                                    if any(target[b, t] != background):
                                        ab_target.append(list(target[b, t]))
                                        ab_trg_mask.append(list(trg_mask[b, t]))

                                if len(ab_target) != opt.max_kp_num // 2:
                                    ab_has_null[b] = True
                                    # Append a single null keyphrase as a learning target.
                                    ab_target.append(list(background))
                                    ab_trg_mask.append(list(bg_mask))

                                ab_targets.append(torch.tensor(ab_target).to(opt.device))
                                ab_trg_masks.append(torch.tensor(ab_trg_mask).to(opt.device))

                            _, pre_reorder_cols, pre_matching_scale = optimal_transport_assign(
                                decoder_dists[:, :mid_idx],
                                pre_targets,
                                assign_steps=opt.assign_steps,
                                has_null=pre_has_null,
                                k_strategy=opt.k_strategy,
                                top_candidates=opt.top_candidates,
                                temperature=opt.assign_temperature)
                            _, ab_reorder_cols, ab_matching_scale = optimal_transport_assign(
                                decoder_dists[:, mid_idx:],
                                ab_targets,
                                assign_steps=opt.assign_steps,
                                has_null=ab_has_null,
                                k_strategy=opt.k_strategy,
                                top_candidates=opt.top_candidates,
                                temperature=opt.assign_temperature)

                            # Gather each example's targets/masks in the assigned slot order.
                            new_pre_targets, new_pre_trg_masks, new_ab_targets, new_ab_trg_masks = [], [], [], []
                            for b in range(batch_size):
                                new_pre_targets.append(pre_targets[b][pre_reorder_cols[b]])
                                new_pre_trg_masks.append(pre_trg_masks[b][pre_reorder_cols[b]])
                                new_ab_targets.append(ab_targets[b][ab_reorder_cols[b]])
                                new_ab_trg_masks.append(ab_trg_masks[b][ab_reorder_cols[b]])

                            target[:, :mid_idx] = torch.stack(new_pre_targets, dim=0)
                            trg_mask[:, :mid_idx] = torch.stack(new_pre_trg_masks, dim=0)

                            target[:, mid_idx:] = torch.stack(new_ab_targets, dim=0)
                            trg_mask[:, mid_idx:] = torch.stack(new_ab_trg_masks, dim=0)

                        else:
                            pre_reorder_index = hungarian_assign(decoder_dists[:, :mid_idx],
                                                                 target[:, :mid_idx, :opt.assign_steps],
                                                                 ignore_indices=[word2idx[io.NULL_WORD],
                                                                                 word2idx[io.PAD_WORD]])
                            target[:, :mid_idx] = target[:, :mid_idx][pre_reorder_index]
                            trg_mask[:, :mid_idx] = trg_mask[:, :mid_idx][pre_reorder_index]

                            ab_reorder_index = hungarian_assign(decoder_dists[:, mid_idx:],
                                                                target[:, mid_idx:, :opt.assign_steps],
                                                                ignore_indices=[word2idx[io.NULL_WORD],
                                                                                word2idx[io.PAD_WORD]])
                            target[:, mid_idx:] = target[:, mid_idx:][ab_reorder_index]
                            trg_mask[:, mid_idx:] = trg_mask[:, mid_idx:][ab_reorder_index]
                    else:
                        reorder_index = hungarian_assign(decoder_dists, target[:, :, :opt.assign_steps],
                                                         [word2idx[io.NULL_WORD],
                                                          word2idx[io.PAD_WORD]])
                        target = target[reorder_index]
                        trg_mask = trg_mask[reorder_index]

                state = model.decoder.init_state(memory_bank, src_mask)  # refresh the state
                # Teacher forcing: shift the (possibly re-assigned) targets right by one.
                input_tgt = torch.cat([y_t_init, target[:, :, :-1]], dim=-1)
                input_tgt = input_tgt.masked_fill(input_tgt.gt(opt.vocab_size - 1), word2idx[io.UNK_WORD])
                decoder_dist, attention_dist = model.decoder(input_tgt, state, src_oov, max_num_oov, control_embed)

            else:
                y_t_init = trg.new_ones(batch_size, 1) * word2idx[io.BOS_WORD]  # [batch_size, 1]
                input_tgt = torch.cat([y_t_init, trg[:, :-1]], dim=-1)
                memory_bank = model.encoder(src, src_lens, src_mask)
                state = model.decoder.init_state(memory_bank, src_mask)
                decoder_dist, attention_dist = model.decoder(input_tgt, state, src_oov, max_num_oov)

            forward_time = time_since(start_time)
            forward_time_total += forward_time

            start_time = time.time()
            if opt.fix_kp_num_len:
                if opt.seperate_pre_ab:
                    mid_idx = opt.max_kp_num // 2
                    dist_per_kp = decoder_dist.reshape(batch_size, opt.max_kp_num, opt.max_kp_len, -1)

                    if opt.loss_matching_scale:
                        pre_scale = _expand_matching_scale(pre_matching_scale, opt.max_kp_len, batch_size)
                        ab_scale = _expand_matching_scale(ab_matching_scale, opt.max_kp_len, batch_size)
                    else:
                        pre_scale = ab_scale = None

                    pre_loss = masked_cross_entropy(
                        dist_per_kp[:, :mid_idx].reshape(batch_size, opt.max_kp_len * mid_idx, -1),
                        target[:, :mid_idx].reshape(batch_size, -1),
                        trg_mask[:, :mid_idx].reshape(batch_size, -1),
                        loss_scales=[opt.loss_scale_pre],
                        scale_indices=[word2idx[io.NULL_WORD]],
                        matching_scale=pre_scale)

                    # The second half must be weighted by its own matching scale,
                    # not by the one computed for the first half.
                    ab_loss = masked_cross_entropy(
                        dist_per_kp[:, mid_idx:].reshape(batch_size, opt.max_kp_len * mid_idx, -1),
                        target[:, mid_idx:].reshape(batch_size, -1),
                        trg_mask[:, mid_idx:].reshape(batch_size, -1),
                        loss_scales=[opt.loss_scale_ab],
                        scale_indices=[word2idx[io.NULL_WORD]],
                        matching_scale=ab_scale)
                    loss = pre_loss + ab_loss
                else:
                    loss = masked_cross_entropy(decoder_dist, target.reshape(batch_size, -1),
                                                trg_mask.reshape(batch_size, -1),
                                                loss_scales=[opt.loss_scale], scale_indices=[word2idx[io.NULL_WORD]])
            else:
                loss = masked_cross_entropy(decoder_dist, target, trg_mask)
            loss_compute_time = time_since(start_time)
            loss_compute_time_total += loss_compute_time

            evaluation_loss_sum += loss.item()
            total_trg_tokens += trg_mask.sum().item()

    eval_loss_stat = LossStatistics(evaluation_loss_sum, total_trg_tokens, n_batch, forward_time=forward_time_total,
                                    loss_compute_time=loss_compute_time_total)
    return eval_loss_stat


def _expand_matching_scale(matching_scale, max_kp_len, batch_size):
    # Repeat a matching scale `max_kp_len` times along a new axis 2 and flatten
    # it to (batch_size, -1); returns None when no scale is available.
    if matching_scale is None:
        return None
    return torch.cat([matching_scale.unsqueeze(2)] * max_kp_len, dim=2).reshape(batch_size, -1)


def evaluate_greedy_generator(data_loader, generator, opt):
    """Write one line of predicted keyphrases per example to ``predictions.txt``.

    The output file is ``<opt.pred_path>/predictions.txt``. Two modes exist:

    * ``opt.load`` is falsy: every batch of ``data_loader`` is decoded with
      ``generator`` and the predictions are written in the dataset's original
      order (restored per batch from ``original_idx_list``).
    * ``opt.load`` is truthy: predictions and scores are re-read from
      ``<opt.pred_path>/all_results.json`` (one JSON object per line with
      ``predictions`` and ``scores`` keys) and re-split using
      ``opt.score_threshold`` and ``opt.prob_range``. This function does not
      write that JSON file itself.

    Args:
        data_loader: iterable (with ``len``) yielding 12-element batches.
        generator: decoder wrapper exposing ``inference``,
            ``inference_with_beam_search`` and ``beam_search``.
        opt: options object (``pred_path``, ``load``, ``vocab``, ``slide``,
            ``batch_size``, ``fix_kp_num_len``, ``beam_size``, ...).

    Returns:
        None. The result is the written predictions file.
    """
    pred_file_path = os.path.join(opt.pred_path, "predictions.txt")
    # `with` guarantees the file is flushed and closed even when decoding raises.
    with open(pred_file_path, "w") as pred_output_file:
        if opt.load:
            _write_loaded_predictions(pred_output_file, data_loader, opt)
        else:
            _write_generated_predictions(pred_output_file, data_loader, generator, opt)


def _write_generated_predictions(pred_output_file, data_loader, generator, opt):
    # Decode every batch with `generator` and write one prediction line per example.
    interval = 1000      # log the elapsed time every `interval` batches
    slide_size = 150000  # number of instances covered by one slide
    with torch.no_grad():
        word2idx = opt.vocab['word2idx']
        idx2word = opt.vocab['idx2word']
        start_time = time.time()
        for batch_i, batch in enumerate(data_loader):
            # Restrict decoding to the requested slide: slide k covers the
            # instances k * slide_size ~ (k + 1) * slide_size.
            if opt.slide is not None:
                first_instance = batch_i * opt.batch_size
                if first_instance >= (opt.slide + 1) * slide_size:
                    break
                if first_instance < opt.slide * slide_size:
                    continue

            if (batch_i + 1) % interval == 0:
                logging.info("Batch %d: Time for running beam search on %d batches : %.1f" % (
                    batch_i + 1, interval, time_since(start_time)))
                start_time = time.time()

            src, src_lens, src_mask, src_oov, oov_lists, src_str_list, \
            trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, original_idx_list = batch

            if opt.fix_kp_num_len:
                if opt.beam_size is None:
                    n_best_result = generator.inference(src, src_lens, src_oov, src_mask, oov_lists, word2idx)
                else:
                    n_best_result = generator.inference_with_beam_search(src, src_lens, src_oov, src_mask,
                                                                         oov_lists, word2idx)

                # NOTE: predictions are deliberately not re-sorted by decoder score
                # here (save_k should not take effect at this point), because null
                # predictions tend to have very high confidence and would rank first.
                pred_list = preprocess_n_best_result(n_best_result, idx2word, opt.vocab_size, oov_lists,
                                                     eos_idx=-1,  # to keep all the keyphrases rather than only the first one
                                                     unk_idx=word2idx[io.UNK_WORD],
                                                     replace_unk=opt.replace_unk,
                                                     src_str_list=src_str_list)

                # recover the original order in the dataset
                seq_pairs = sorted(zip(original_idx_list, src_str_list, trg_str_2dlist, pred_list, oov_lists,
                                       n_best_result['decoder_scores']),
                                   key=lambda p: p[0])
                _, _, _, pred_list, _, decoder_scores = zip(*seq_pairs)

                # Process every src in the batch
                for pred, decoder_score in zip(pred_list, decoder_scores):
                    if opt.beam_size is None:
                        all_keyphrase_list = split_word_list_from_set(pred[-1], decoder_score[-1].cpu().numpy(),
                                                                      opt.max_kp_len,
                                                                      opt.max_kp_num, io.EOS_WORD, io.NULL_WORD)
                    else:
                        all_keyphrase_list = split_word_list_from_set_new(pred,
                                                                          [s.cpu().numpy() for s in decoder_score],
                                                                          opt.max_kp_len,
                                                                          opt.max_kp_num, io.EOS_WORD, io.NULL_WORD)

                    # output the predicted keyphrases to a file
                    write_example_kp(pred_output_file, all_keyphrase_list)
            else:
                n_best_result = generator.beam_search(src, src_lens, src_oov, src_mask, oov_lists, word2idx)
                pred_list = preprocess_n_best_result(n_best_result, idx2word, opt.vocab_size, oov_lists,
                                                     word2idx[io.EOS_WORD],
                                                     word2idx[io.UNK_WORD],
                                                     opt.replace_unk, src_str_list)

                # recover the original order in the dataset
                seq_pairs = sorted(zip(original_idx_list, src_str_list, trg_str_2dlist, pred_list, oov_lists),
                                   key=lambda p: p[0])
                _, _, _, pred_list, _ = zip(*seq_pairs)

                # Process every src in the batch. `pred` is the list of decoded
                # word lists for one example; each is split on the SEP token and
                # all resulting keyphrases are concatenated.
                for pred in pred_list:
                    all_keyphrase_list = []
                    for word_list in pred:
                        all_keyphrase_list += split_word_list_by_delimiter(word_list, io.SEP_WORD)

                    # output the predicted keyphrases to a file
                    write_example_kp(pred_output_file, all_keyphrase_list)


def _write_loaded_predictions(pred_output_file, data_loader, opt):
    # Re-split predictions stored in all_results.json and write one line per example.
    all_results_path = os.path.join(opt.pred_path, "all_results.json")
    with open(all_results_path, "r") as all_result_file:
        logging.info("Loading all_results.json from %s" % all_results_path)
        # One JSON line per batch, hence the progress total of len(data_loader).
        for line in tqdm(all_result_file, total=len(data_loader)):
            temp_json = json.loads(line.strip())
            pred_list = temp_json['predictions']
            decoder_scores = temp_json['scores']
            # Process every src in the batch
            for pred, decoder_score in zip(pred_list, decoder_scores):
                all_keyphrase_list = split_word_list_from_set_new(pred, decoder_score,
                                                                  opt.max_kp_len,
                                                                  opt.max_kp_num, io.EOS_WORD, io.NULL_WORD,
                                                                  opt.score_threshold, opt.prob_range)
                # output the predicted keyphrases to a file
                write_example_kp(pred_output_file, all_keyphrase_list)


def write_example_kp(out_file, kp_list):
    """Write one example's keyphrases to ``out_file`` as a single line.

    Each keyphrase in ``kp_list`` is a list of words; words are joined with a
    space, keyphrases with ``;``, and the line is terminated with a newline.
    An empty ``kp_list`` produces an empty line.
    """
    phrases = [' '.join(words) for words in kp_list]
    out_file.write(';'.join(phrases) + '\n')


def preprocess_n_best_result(n_best_result, idx2word, vocab_size, oov_lists, eos_idx, unk_idx, replace_unk,
                             src_str_list):
    """Convert the n-best predictions of a batch into sentences.

    ``n_best_result['predictions']`` and ``n_best_result['attention']`` are
    walked in lockstep with ``oov_lists`` and ``src_str_list`` (one entry per
    example). Every (prediction, attention) pair of an example is passed to
    ``prediction_to_sentence`` together with that example's oov list and
    source words.

    Returns:
        A list with one entry per example; each entry is the list of values
        returned by ``prediction_to_sentence`` for that example's n-best
        predictions.
    """
    per_example = zip(n_best_result['predictions'], n_best_result['attention'], oov_lists, src_str_list)
    batch_sentences = []
    for n_best_preds, n_best_attns, oov, src_words in per_example:
        # one sentence per n-best candidate of this example
        batch_sentences.append([
            prediction_to_sentence(candidate, idx2word, vocab_size, oov, eos_idx, unk_idx, replace_unk,
                                   src_words, candidate_attn)
            for candidate, candidate_attn in zip(n_best_preds, n_best_attns)
        ])
    return batch_sentences
