import logging
import math
import os
import sys
import time

import torch
import torch.nn as nn

import pykp.utils.io as io
from inference.evaluate import evaluate_loss
from pykp.utils.label_assign import hungarian_assign, optimal_transport_assign
from pykp.utils.masked_loss import masked_cross_entropy
from utils.functions import time_since
from utils.report import export_train_and_valid_loss
from utils.statistics import LossStatistics, cal_trg_coverage

# Minimum learning-rate reduction applied on a non-improving checkpoint in
# train_model: if decaying a param group's lr would lower it by no more than
# this amount, the lr is left unchanged.
EPS = 1e-8


def train_model(model, optimizer, train_data_loader, valid_data_loader, opt):
    """Train ``model`` with periodic validation, checkpointing and early stopping.

    Every ``opt.report_every`` batches the running training ppl/loss (accumulated
    since the last checkpoint) is logged. From epoch ``opt.start_checkpoint_at``
    onward a checkpoint is taken either on the last batch of each epoch
    (``opt.checkpoint_interval == -1``) or every ``opt.checkpoint_interval``
    batches. At a checkpoint the validation loss is computed with
    ``evaluate_loss``:

    * if it improves, the weights are saved to ``<opt.model_path>/best_model.pt``
      and the patience counter is reset;
    * otherwise the patience counter grows and the ``lr`` of every optimizer
      param group is multiplied by ``opt.learning_rate_decay`` (skipped when the
      reduction would not exceed ``EPS``).

    Training stops once the patience counter reaches ``opt.early_stop_tolerance``.
    The collected train/valid curves are exported to
    ``<opt.exp_path>/train_valid_curve`` at the end.

    Args:
        model: model trained in place; must provide ``train()`` and ``state_dict()``.
        optimizer: torch optimizer whose ``param_groups`` carry an ``'lr'`` entry.
        train_data_loader: iterable of training batches; must support ``len()``
            when ``opt.checkpoint_interval == -1``.
        valid_data_loader: batches passed to ``evaluate_loss``.
        opt: options namespace.

    Exits the process with status 1 if a checkpoint sees a NaN train or
    validation loss.
    """
    logging.info('======================  Start Training  =========================')

    total_batch = -1
    early_stop_flag = False

    # Running statistics since the last checkpoint (cleared after each one).
    report_train_loss_statistics = LossStatistics()
    report_train_ppl = []
    report_valid_ppl = []
    report_train_loss = []
    report_valid_loss = []
    best_valid_ppl = float('inf')
    best_valid_loss = float('inf')
    num_stop_dropping = 0

    model.train()
    for epoch in range(opt.start_epoch, opt.epochs + 1):
        if early_stop_flag:
            break
        for batch_i, batch in enumerate(train_data_loader):
            total_batch += 1

            batch_loss_stat = train_one_batch(batch, model, optimizer, opt)
            report_train_loss_statistics.update(batch_loss_stat)

            if total_batch % opt.report_every == 0:
                logging.info(
                    "Epoch %d; batch: %d; total batch: %d, avg training ppl: %.3f, loss: %.3f",
                    epoch, batch_i, total_batch,
                    report_train_loss_statistics.ppl(), report_train_loss_statistics.xent())

            if epoch < opt.start_checkpoint_at:
                continue
            # len() is only queried in the end-of-epoch mode (short-circuit `and`).
            at_epoch_end = opt.checkpoint_interval == -1 and batch_i == len(train_data_loader) - 1
            at_interval = (opt.checkpoint_interval > -1 and total_batch > 1
                           and total_batch % opt.checkpoint_interval == 0)
            if not (at_epoch_end or at_interval):
                continue

            valid_loss_stat = evaluate_loss(valid_data_loader, model, opt)
            model.train()  # evaluate_loss may leave the model in eval mode

            current_valid_loss = valid_loss_stat.xent()
            current_valid_ppl = valid_loss_stat.ppl()
            logging.info("Enter check point!")

            current_train_ppl = report_train_loss_statistics.ppl()
            current_train_loss = report_train_loss_statistics.xent()

            if math.isnan(current_valid_loss) or math.isnan(current_train_loss):
                logging.error(
                    "NaN loss (train: %s, valid: %s). Epoch: %d; batch_i: %d, total_batch: %d",
                    current_train_loss, current_valid_loss, epoch, batch_i, total_batch)
                # sys.exit (not the site-only builtin exit) with a failure status.
                sys.exit(1)

            if current_valid_loss < best_valid_loss:
                # New best: record it and persist the parameters.
                logging.info("Valid loss drops")
                sys.stdout.flush()
                best_valid_loss = current_valid_loss
                best_valid_ppl = current_valid_ppl
                num_stop_dropping = 0

                check_pt_model_path = os.path.join(opt.model_path, 'best_model.pt')
                # Passing the path lets torch.save open *and close* the file.
                torch.save(model.state_dict(), check_pt_model_path)
                logging.info('Saving checkpoint to %s', check_pt_model_path)
            else:
                num_stop_dropping += 1
                logging.info("Valid loss does not drop, patience: %d/%d",
                             num_stop_dropping, opt.early_stop_tolerance)

                # Decay the learning rate, ignoring negligible reductions.
                for param_group in optimizer.param_groups:
                    old_lr = float(param_group['lr'])
                    new_lr = old_lr * opt.learning_rate_decay
                    if old_lr - new_lr > EPS:
                        param_group['lr'] = new_lr

            logging.info('Epoch: %d; batch idx: %d; total batches: %d', epoch, batch_i, total_batch)
            logging.info(
                ' * avg training ppl: %.3f; avg validation ppl: %.3f; best validation ppl: %.3f',
                current_train_ppl, current_valid_ppl, best_valid_ppl)
            logging.info(
                ' * avg training loss: %.3f; avg validation loss: %.3f; best validation loss: %.3f',
                current_train_loss, current_valid_loss, best_valid_loss)

            report_train_ppl.append(current_train_ppl)
            report_valid_ppl.append(current_valid_ppl)
            report_train_loss.append(current_train_loss)
            report_valid_loss.append(current_valid_loss)

            if num_stop_dropping >= opt.early_stop_tolerance:
                logging.info('Have not increased for %d check points, early stop training',
                             num_stop_dropping)
                early_stop_flag = True
                break  # the outer loop checks early_stop_flag
            report_train_loss_statistics.clear()

    # Export the training curve.
    train_valid_curve_path = os.path.join(opt.exp_path, 'train_valid_curve')
    export_train_and_valid_loss(report_train_loss, report_valid_loss, report_train_ppl, report_valid_ppl,
                                opt.checkpoint_interval, train_valid_curve_path)


def train_one_batch(batch, model, optimizer, opt):
    """Run a single optimisation step on ``batch`` and return its statistics.

    ``batch`` is unpacked as the 12-tuple ``(src, src_lens, src_mask, src_oov,
    oov_lists, src_str_list, trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _)``.
    The decoder target is ``trg_oov`` when ``opt.copy_attention`` is set, else ``trg``.

    With ``opt.fix_kp_num_len`` the target is handled as
    ``(batch, opt.max_kp_num, opt.max_kp_len)`` keyphrase slots. If ``opt.set_loss``
    is also set, the slots are first re-assigned to the model's own greedy
    ``opt.assign_steps``-token predictions (see ``_assign_targets``).

    NOTE: with ``opt.seperate_pre_ab`` the reassignment is written back into the
    batch's ``trg``/``trg_oov`` and ``trg_mask`` tensors in place.

    Args:
        batch: 12-tuple produced by the data loader (see above).
        model: model with ``encoder`` and ``decoder`` sub-modules.
        optimizer: torch optimizer, zeroed and stepped once per call.
        opt: options namespace.

    Returns:
        LossStatistics: un-normalised loss, number of target tokens, timings and,
        when optimal-transport assignment ran, present/absent target coverage.

    Raises:
        ValueError: if ``opt.loss_normalization`` is not ``"tokens"`` or
            ``"batches"``, or if ``opt.loss_matching_scale`` is requested but no
            matching scale was produced by target assignment.
    """
    src, src_lens, src_mask, src_oov, oov_lists, src_str_list, \
        trg_str_2dlist, trg, trg_oov, trg_lens, trg_mask, _ = batch

    max_num_oov = max(len(oov) for oov in oov_lists)  # max number of oov words in the batch
    batch_size = src.size(0)
    word2idx = opt.vocab['word2idx']
    target = trg_oov if opt.copy_attention else trg

    # Only the optimal-transport assignment produces matching scales and coverage.
    # Previously these names were unbound on every other path (UnboundLocalError).
    pre_matching_scale = ab_matching_scale = None
    coverage_stats = {}

    optimizer.zero_grad()
    start_time = time.time()
    if opt.fix_kp_num_len:
        y_t_init = target.new_ones(batch_size, opt.max_kp_num, 1) * word2idx[io.BOS_WORD]
        if opt.set_loss:  # K-step target assignment
            target, trg_mask, pre_matching_scale, ab_matching_scale, coverage_stats = _assign_targets(
                model, src, src_lens, src_mask, src_oov, max_num_oov,
                target, trg_mask, batch_size, word2idx, opt)

        memory_bank = model.encoder(src, src_lens, src_mask)
        state = model.decoder.init_state(memory_bank, src_mask)
        control_embed = model.decoder.forward_seg(state)

        # Teacher forcing: BOS followed by the target shifted right by one token.
        input_tgt = torch.cat([y_t_init, target[:, :, :-1]], dim=-1)
        input_tgt = input_tgt.masked_fill(input_tgt.gt(opt.vocab_size - 1), word2idx[io.UNK_WORD])
        decoder_dist, _ = model.decoder(input_tgt, state, src_oov, max_num_oov, control_embed)
    else:
        y_t_init = trg.new_ones(batch_size, 1) * word2idx[io.BOS_WORD]  # [batch_size, 1]
        input_tgt = torch.cat([y_t_init, trg[:, :-1]], dim=-1)
        memory_bank = model.encoder(src, src_lens, src_mask)
        state = model.decoder.init_state(memory_bank, src_mask)
        decoder_dist, _ = model.decoder(input_tgt, state, src_oov, max_num_oov)

    forward_time = time_since(start_time)
    start_time = time.time()
    if opt.fix_kp_num_len:
        if opt.seperate_pre_ab:
            # First half of the slots holds present keyphrases, second half absent ones.
            mid_idx = opt.max_kp_num // 2
            per_slot_dist = decoder_dist.reshape(batch_size, opt.max_kp_num, opt.max_kp_len, -1)

            pre_scale = _expand_matching_scale(pre_matching_scale, opt.max_kp_len, batch_size) \
                if opt.loss_matching_scale else None
            pre_loss = masked_cross_entropy(
                per_slot_dist[:, :mid_idx].reshape(batch_size, opt.max_kp_len * mid_idx, -1),
                target[:, :mid_idx].reshape(batch_size, -1),
                trg_mask[:, :mid_idx].reshape(batch_size, -1),
                loss_scales=[opt.loss_scale_pre],
                scale_indices=[word2idx[io.NULL_WORD]],
                matching_scale=pre_scale)

            ab_scale = _expand_matching_scale(ab_matching_scale, opt.max_kp_len, batch_size) \
                if opt.loss_matching_scale else None
            ab_loss = masked_cross_entropy(
                per_slot_dist[:, mid_idx:].reshape(batch_size, opt.max_kp_len * mid_idx, -1),
                target[:, mid_idx:].reshape(batch_size, -1),
                trg_mask[:, mid_idx:].reshape(batch_size, -1),
                loss_scales=[opt.loss_scale_ab],
                scale_indices=[word2idx[io.NULL_WORD]],
                matching_scale=ab_scale)

            loss = pre_loss + ab_loss
        else:
            loss = masked_cross_entropy(decoder_dist, target.reshape(batch_size, -1), trg_mask.reshape(batch_size, -1),
                                        loss_scales=[opt.loss_scale], scale_indices=[word2idx[io.NULL_WORD]])
    else:
        loss = masked_cross_entropy(decoder_dist, target, trg_mask)
    loss_compute_time = time_since(start_time)

    total_trg_tokens = trg_mask.sum().item()
    total_trg_sents = src.size(0)
    if opt.loss_normalization == "tokens":  # use number of target tokens to normalize the loss
        normalization = total_trg_tokens
    elif opt.loss_normalization == 'batches':  # use batch_size to normalize the loss
        normalization = total_trg_sents
    else:
        raise ValueError('The type of loss normalization is invalid.')
    assert normalization > 0, 'normalization should be a positive number'

    start_time = time.time()
    total_loss = loss.div(normalization)

    total_loss.backward()
    backward_time = time_since(start_time)

    if opt.max_grad_norm > 0:
        nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)

    optimizer.step()
    # Coverage fields are only passed when they were computed; this assumes
    # LossStatistics provides defaults for them — TODO confirm in utils.statistics.
    return LossStatistics(loss.item(), total_trg_tokens, n_batch=1, forward_time=forward_time,
                          loss_compute_time=loss_compute_time, backward_time=backward_time,
                          **coverage_stats)


def _assign_targets(model, src, src_lens, src_mask, src_oov, max_num_oov,
                    target, trg_mask, batch_size, word2idx, opt):
    """Reorder target keyphrase slots to match the model's greedy predictions.

    Switches ``model`` to eval mode, decodes ``opt.assign_steps`` tokens per slot
    without gradients, matches slots to targets (Hungarian or optimal transport),
    then switches ``model`` back to train mode.

    Returns:
        tuple ``(target, trg_mask, pre_matching_scale, ab_matching_scale,
        coverage_stats)``; scales are ``None`` and ``coverage_stats`` is empty
        unless the separate-halves optimal-transport path ran.
    """
    pre_matching_scale = ab_matching_scale = None
    coverage_stats = {}

    model.eval()
    with torch.no_grad():
        decoder_dists = _greedy_decode_dists(model, src, src_lens, src_mask, src_oov,
                                             max_num_oov, batch_size, word2idx, opt)

        if opt.seperate_pre_ab:
            mid_idx = opt.max_kp_num // 2

            if opt.use_optimal_transport:
                # The "null" keyphrase used to pad unused slots: NULL then PADs.
                background = torch.tensor(
                    [word2idx[io.NULL_WORD]] + [word2idx[io.PAD_WORD]] * (opt.max_kp_len - 1),
                    dtype=target.dtype, device=target.device)
                bg_mask = torch.tensor([1] + [0] * (opt.max_kp_len - 1),
                                       dtype=trg_mask.dtype, device=trg_mask.device)

                new_pre_tgt, new_pre_mask, pre_matching_scale, pre_coverage = _ot_reorder_half(
                    decoder_dists[:, :mid_idx], target, trg_mask, range(mid_idx), mid_idx,
                    background, bg_mask, batch_size, opt)
                new_ab_tgt, new_ab_mask, ab_matching_scale, ab_coverage = _ot_reorder_half(
                    decoder_dists[:, mid_idx:], target, trg_mask, range(mid_idx, opt.max_kp_num), mid_idx,
                    background, bg_mask, batch_size, opt)

                # In-place write-back into the batch tensors.
                target[:, :mid_idx] = new_pre_tgt
                trg_mask[:, :mid_idx] = new_pre_mask
                target[:, mid_idx:] = new_ab_tgt
                trg_mask[:, mid_idx:] = new_ab_mask

                coverage_stats = {
                    'pre_precise_trg_coverage': pre_coverage[0],
                    'pre_rough_trg_coverage': pre_coverage[1],
                    'ab_precise_trg_coverage': ab_coverage[0],
                    'ab_rough_trg_coverage': ab_coverage[1],
                }
            else:
                pre_reorder_index = hungarian_assign(decoder_dists[:, :mid_idx],
                                                     target[:, :mid_idx, :opt.assign_steps],
                                                     ignore_indices=[word2idx[io.NULL_WORD],
                                                                     word2idx[io.PAD_WORD]])
                target[:, :mid_idx] = target[:, :mid_idx][pre_reorder_index]
                trg_mask[:, :mid_idx] = trg_mask[:, :mid_idx][pre_reorder_index]

                ab_reorder_index = hungarian_assign(decoder_dists[:, mid_idx:],
                                                    target[:, mid_idx:, :opt.assign_steps],
                                                    ignore_indices=[word2idx[io.NULL_WORD],
                                                                    word2idx[io.PAD_WORD]])
                target[:, mid_idx:] = target[:, mid_idx:][ab_reorder_index]
                trg_mask[:, mid_idx:] = trg_mask[:, mid_idx:][ab_reorder_index]
        else:
            if opt.use_optimal_transport:
                # NOTE(review): unlike the separate-halves call, this passes the ignore
                # list positionally and uses the return value directly as an index,
                # whereas that call unpacks three values. Confirm against
                # optimal_transport_assign's signature.
                reorder_index = optimal_transport_assign(decoder_dists, target[:, :, :opt.assign_steps],
                                                         [word2idx[io.NULL_WORD],
                                                          word2idx[io.PAD_WORD]])
            else:
                reorder_index = hungarian_assign(decoder_dists, target[:, :, :opt.assign_steps],
                                                 [word2idx[io.NULL_WORD],
                                                  word2idx[io.PAD_WORD]])
            target = target[reorder_index]
            trg_mask = trg_mask[reorder_index]

    model.train()
    return target, trg_mask, pre_matching_scale, ab_matching_scale, coverage_stats


def _greedy_decode_dists(model, src, src_lens, src_mask, src_oov, max_num_oov,
                         batch_size, word2idx, opt):
    """Greedily decode ``opt.assign_steps`` tokens for every keyphrase slot.

    Returns the per-step decoder distributions concatenated along dim -2, each
    step reshaped to ``(batch_size, opt.max_kp_num, 1, -1)``.
    """
    memory_bank = model.encoder(src, src_lens, src_mask)
    state = model.decoder.init_state(memory_bank, src_mask)
    control_embed = model.decoder.forward_seg(state)

    input_tokens = src.new_zeros(batch_size, opt.max_kp_num, opt.assign_steps + 1)
    input_tokens[:, :, 0] = word2idx[io.BOS_WORD]
    step_dists = []
    for t in range(1, opt.assign_steps + 1):
        decoder_inputs = input_tokens[:, :, :t]
        # Copied OOV ids are outside the embedding table; feed them as UNK.
        decoder_inputs = decoder_inputs.masked_fill(decoder_inputs.gt(opt.vocab_size - 1),
                                                    word2idx[io.UNK_WORD])
        decoder_dist, _ = model.decoder(decoder_inputs, state, src_oov, max_num_oov, control_embed)
        input_tokens[:, :, t] = decoder_dist.argmax(-1)
        step_dists.append(decoder_dist.reshape(batch_size, opt.max_kp_num, 1, -1))
    return torch.cat(step_dists, -2)


def _collect_real_targets(target, trg_mask, b, slots, num_expected, background, bg_mask):
    """Gather the non-null keyphrase rows of example ``b`` from ``slots``.

    A slot is kept when any of its tokens differs from ``background``. If fewer
    than ``num_expected`` rows are kept, one ``background`` row (with ``bg_mask``)
    is appended as an explicit null learning target.

    Returns:
        ``(targets, masks, has_null)`` with the rows stacked along dim 0.
    """
    kept = [t for t in slots if (target[b, t] != background).any()]
    tgt_rows = [target[b, t] for t in kept]
    mask_rows = [trg_mask[b, t] for t in kept]
    has_null = len(kept) != num_expected
    if has_null:
        tgt_rows.append(background)
        mask_rows.append(bg_mask)
    return torch.stack(tgt_rows), torch.stack(mask_rows), has_null


def _ot_reorder_half(half_dists, target, trg_mask, slots, num_expected,
                     background, bg_mask, batch_size, opt):
    """Optimal-transport assignment of one half (present or absent) of the slots.

    Per-example target counts differ, so real targets are kept as lists of
    tensors rather than a single padded tensor.

    Returns:
        ``(new_targets, new_masks, matching_scale, (precise_cov, rough_cov))``
        where ``new_targets``/``new_masks`` are stacked over the batch.
    """
    targets, masks, has_null = [], [], []
    for b in range(batch_size):
        tgt_b, mask_b, null_b = _collect_real_targets(target, trg_mask, b, slots, num_expected,
                                                      background, bg_mask)
        targets.append(tgt_b)
        masks.append(mask_b)
        has_null.append(null_b)

    _, reorder_cols, matching_scale = optimal_transport_assign(half_dists,
                                                               targets,
                                                               assign_steps=opt.assign_steps,
                                                               has_null=has_null,
                                                               k_strategy=opt.k_strategy,
                                                               top_candidates=opt.top_candidates,
                                                               temperature=opt.assign_temperature)
    coverage = cal_trg_coverage([t.shape[0] for t in targets], reorder_cols)

    new_targets = torch.stack([targets[b][reorder_cols[b]] for b in range(batch_size)], dim=0)
    new_masks = torch.stack([masks[b][reorder_cols[b]] for b in range(batch_size)], dim=0)
    return new_targets, new_masks, matching_scale, coverage


def _expand_matching_scale(scale, max_kp_len, batch_size):
    """Repeat a per-slot matching scale over every token of its slot.

    ``scale`` is unsqueezed at dim 2 and tiled ``max_kp_len`` times, then
    flattened to ``(batch_size, -1)``. Assumes ``scale`` is 2-D
    ``(batch, slots)`` — TODO confirm against optimal_transport_assign.

    Raises:
        ValueError: if no matching scale was computed (``scale is None``).
    """
    if scale is None:
        raise ValueError('opt.loss_matching_scale requires optimal-transport target assignment '
                         '(opt.set_loss and opt.use_optimal_transport), but no matching scale '
                         'was computed.')
    return torch.cat([scale.unsqueeze(2)] * max_kp_len, dim=2).reshape(batch_size, -1)
