import torch, time
from models.model_builder import build_model
from utils.data_loader import get_DataLoader, get_vocab, get_citation_DataLoader
from tqdm import tqdm
import torch.optim as optim
from utils.loss import build_loss_compute
from utils.data_utils import make_tgt, make_src, parse_data_path_cfg, parse_model_path_cfg, load_pretrained_embedding

def epoch_time(start_time, end_time):
    """Split the interval between two timestamps into whole minutes and seconds.

    Args:
        start_time: Start timestamp in seconds (e.g. from ``time.time()``).
        end_time: End timestamp in seconds.

    Returns:
        A ``(minutes, seconds)`` tuple of ints; fractional parts are truncated.
    """
    total_seconds = end_time - start_time
    whole_minutes = int(total_seconds / 60)
    # Whatever is left after removing the whole minutes, truncated to an int.
    leftover_seconds = int(total_seconds - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

def train_generation(model, train_iterator, test_iterator, optimizer, loss_compute, epoch, src_max_len,
                     tgt_max_len, abs_max_sent_num, node_max_neighbor, gpu, is_citation_func):
    """Train the generation model for ``epoch`` epochs, checkpointing on improvement.

    Every epoch runs ``train_step`` on ``train_iterator`` followed by
    ``test_step`` on ``test_iterator``. Whenever the test loss is lower than the
    best one seen so far, ``model.state_dict()`` is saved to
    ``<model_opt["base_path"]>_step_<epoch index>.pt`` (``model_opt`` is the
    module-level configuration dict). Timing and both losses are printed after
    each epoch.

    Args:
        model: Module being trained.
        train_iterator: Iterable of training batches.
        test_iterator: Iterable of evaluation batches.
        optimizer: Optimizer handed to ``train_step``.
        loss_compute: Callable computing the generation loss.
        epoch: Number of epochs to run.
        src_max_len, tgt_max_len, abs_max_sent_num, node_max_neighbor:
            Size limits forwarded unchanged to the model call.
        gpu: When true, the citation class weights are moved to CUDA; also
            forwarded to the step functions.
        is_citation_func: When true, a class-weighted cross-entropy criterion
            is built and passed to the step functions; otherwise ``None`` is.
    """
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean")

    criterion = None
    if is_citation_func:
        # Fixed per-class weights for the four citation-function classes.
        class_weights = torch.FloatTensor([0.70, 0.95, 0.37, 0.98])
        if gpu:
            class_weights = class_weights.cuda()
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

    best_test_loss = float('inf')
    # best_test_loss = 27.557
    for step_ in range(epoch):
        start_time = time.time()
        train_loss = train_step(
            model, train_iterator, optimizer, loss_compute, kl_loss, criterion,
            src_max_len, tgt_max_len, abs_max_sent_num, node_max_neighbor, gpu)
        test_loss = test_step(
            model, test_iterator, loss_compute, kl_loss, criterion,
            src_max_len, tgt_max_len, abs_max_sent_num, node_max_neighbor, gpu)
        epoch_mins, epoch_secs = epoch_time(start_time, time.time())

        improved = test_loss < best_test_loss
        if improved:
            best_test_loss = test_loss
            checkpoint_path = '%s_step_%d.pt' % (model_opt["base_path"], step_)
            print("Saving checkpoint %s_step_%d.pt" % (model_opt["base_path"], step_))
            torch.save(model.state_dict(), checkpoint_path)

        print(f'Epoch: {step_ + 1:02} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:7.3f}')
        print(f'\tTest Loss: {test_loss:7.3f}')

def train_step(model, iterator, optimizer, loss_compute, kl_loss, criterion, src_max_len, tgt_max_len, abs_max_sent_num,
               node_max_neighbor, gpu):
    """Run one training pass over ``iterator`` and return the summed batch loss.

    For each batch the total loss is the ``loss_compute`` result plus the KL
    term between ``log(salience_scores)`` and ``tgt_salience`` and, when
    ``criterion`` is not ``None``, the citation-function classification loss.
    Gradients are clipped to a norm of 2 before the optimizer step.

    Returns:
        Sum (not mean) of the per-batch scalar losses as a Python float.
    """
    # Weights of the auxiliary loss terms relative to the generation loss.
    salience_weight = 1
    citation_weight = 1

    model.train()
    running_loss = 0.
    for _, data in tqdm(enumerate(iterator)):
        optimizer.zero_grad()

        # Auxiliary tensors are pulled from the batch before the forward pass.
        abstract_src_map = make_src(data, 'abs_src_map', gpu)
        context_src_map = make_src(data, 'context_src_map', gpu)
        abs_align = make_tgt(data, 'abs_align', gpu)
        context_align = make_tgt(data, 'context_align', gpu)

        (outputs, attns, tgt, salience_scores, tgt_salience,
         citation_func_pred, citation_function) = model(
            data, src_max_len, tgt_max_len, abs_max_sent_num, node_max_neighbor)

        n_examples = len(data)
        loss_data = {
            "tgt": tgt,
            "abs_align": abs_align,
            "context_align": context_align,
            "abs_src_map": abstract_src_map,
            "context_src_map": context_src_map,
            "batch_size": n_examples,
        }
        del data  # the raw batch is no longer needed once the loss inputs are built

        loss = loss_compute(loss_data, outputs, attns, normalization=n_examples,
                            shard_size=0, trunc_start=0)
        loss = loss + salience_weight * kl_loss(salience_scores.log(), tgt_salience)
        if criterion is not None:
            loss = loss + citation_weight * criterion(citation_func_pred, citation_function)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
        optimizer.step()

        running_loss += float(loss.item())
        # Drop references to this batch's tensors before the next iteration.
        del loss, outputs, attns, tgt, salience_scores, tgt_salience, citation_func_pred, citation_function
    return running_loss

def test_step(model, iterator, loss_compute, kl_loss, criterion, src_max_len, tgt_max_len, abs_max_sent_num,
              node_max_neighbor, gpu):
    """Evaluate ``model`` on ``iterator`` and return the summed batch loss.

    The model is switched to eval mode and the whole pass runs under
    ``torch.no_grad()``, so no autograd graph is built. No optimizer is
    touched: evaluation never calls ``backward``, so there are no gradients to
    reset, and this function does not receive an optimizer.

    For each batch the total loss is the ``loss_compute`` result plus
    ``0.9 *`` the KL term between ``log(salience_scores)`` and ``tgt_salience``
    and, when ``criterion`` is not ``None``, ``0.5 *`` the citation-function
    classification loss.

    Args:
        model: Module to evaluate.
        iterator: Iterable of evaluation batches.
        loss_compute: Callable computing the generation loss.
        kl_loss: Criterion applied to the salience distribution.
        criterion: Citation-function criterion, or ``None`` to skip that term.
        src_max_len, tgt_max_len, abs_max_sent_num, node_max_neighbor:
            Size limits forwarded unchanged to the model call.
        gpu: Forwarded to ``make_src`` / ``make_tgt``.

    Returns:
        Sum (not mean) of the per-batch scalar losses as a Python float.
    """
    model.eval()
    batch_loss = 0.
    with torch.no_grad():
        for _, data in tqdm(enumerate(iterator)):
            abstract_src_map = make_src(data, 'abs_src_map', gpu)
            context_src_map = make_src(data, 'context_src_map', gpu)
            abs_align = make_tgt(data, 'abs_align', gpu)
            context_align = make_tgt(data, 'context_align', gpu)
            outputs, attns, tgt, salience_scores, tgt_salience, citation_func_pred, citation_function\
                = model(data, src_max_len, tgt_max_len, abs_max_sent_num, node_max_neighbor)
            loss_data = {
                "tgt": tgt,
                "abs_align": abs_align,
                "context_align": context_align,
                "abs_src_map": abstract_src_map,
                "context_src_map": context_src_map,
                "batch_size": len(data),
            }
            del data
            loss = loss_compute(
                loss_data,
                outputs,
                attns,
                normalization=loss_data["batch_size"],
                shard_size=0,
                trunc_start=0)
            # NOTE(review): these weights (0.9 / 0.5) differ from the 1 / 1 used in
            # train_step, so train and test totals are not directly comparable —
            # confirm this is intended.
            loss = loss + 0.9*kl_loss(salience_scores.log(), tgt_salience)
            if criterion is not None:
                loss = loss + 0.5*criterion(citation_func_pred, citation_function)
            batch_loss += float(loss.item())
            # Drop references to this batch's tensors before the next iteration.
            del loss, outputs, attns, tgt, salience_scores, tgt_salience, citation_func_pred, citation_function
    return batch_loss

def train_citation(model, iterator, optimizer, epoch, node_max_neighbor):
    """Train the citation (link prediction) model for ``epoch`` epochs.

    Each batch is passed to ``model(data, node_max_neighbor, is_train=True)``,
    which returns a triple whose last two items (predicted and ground-truth
    links) feed a ``BCEWithLogitsLoss``. There is no evaluation set here: the
    checkpoint decision uses the summed *training* loss of the epoch. When it
    is lower than the best seen so far, ``model.state_dict()`` is saved to
    ``<model_opt["citation_path"]>_step_<epoch index>.pt`` (``model_opt`` is
    the module-level configuration dict).
    """
    link_criterion = torch.nn.BCEWithLogitsLoss()
    model.train()
    lowest_loss = float('inf')

    for step_ in range(epoch):
        start_time = time.time()

        batch_loss = 0.
        for _, data in tqdm(enumerate(iterator), disable=False):
            optimizer.zero_grad()
            # The first returned item is not used by the training loop.
            _unused, link_hat, link_gt = model(data, node_max_neighbor, is_train=True)
            loss = link_criterion(link_hat, link_gt)
            loss.backward()
            optimizer.step()
            batch_loss += float(loss.item())
            del loss, link_hat, link_gt

        epoch_mins, epoch_secs = epoch_time(start_time, time.time())

        if batch_loss < lowest_loss:
            lowest_loss = batch_loss
            checkpoint_path = '%s_step_%d.pt' % (model_opt["citation_path"], step_)
            print("Saving checkpoint %s_step_%d.pt" % (model_opt["citation_path"], step_))
            torch.save(model.state_dict(), checkpoint_path)

        print(f'Epoch: {step_ + 1:02} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {batch_loss:7.3f}')

def train_salience(model, train_iterator, test_iterator, optimizer, epoch):
    """Train the salience model for ``epoch`` epochs, checkpointing on improvement.

    Every epoch runs ``train_salience_step`` on ``train_iterator`` and
    ``test_salience_step`` on ``test_iterator`` with a batch-mean KL divergence
    criterion. Whenever the test loss is lower than the best one seen so far,
    ``model.state_dict()`` is saved to
    ``<model_opt["salience_path"]>_step_<epoch index>.pt`` (``model_opt`` is
    the module-level configuration dict). Timing and both losses are printed
    after each epoch.
    """
    kl_loss = torch.nn.KLDivLoss(reduction="batchmean")
    lowest_test_loss = float('inf')

    for step_ in range(epoch):
        start_time = time.time()
        train_loss = train_salience_step(model, train_iterator, optimizer, kl_loss)
        test_loss = test_salience_step(model, test_iterator, kl_loss)
        epoch_mins, epoch_secs = epoch_time(start_time, time.time())

        if test_loss < lowest_test_loss:
            lowest_test_loss = test_loss
            checkpoint_path = '%s_step_%d.pt' % (model_opt["salience_path"], step_)
            print("Saving checkpoint %s_step_%d.pt" % (model_opt["salience_path"], step_))
            torch.save(model.state_dict(), checkpoint_path)

        print(f'Epoch: {step_ + 1:02} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:7.3f}')
        print(f'\tTest Loss: {test_loss:7.3f}')

def test_salience_step(model, iterator, kl_loss):
    """Evaluate the salience model on ``iterator`` and return the summed loss.

    The model is switched to eval mode and the whole pass runs under
    ``torch.no_grad()``, so no autograd graph is built. No optimizer is
    touched: evaluation never calls ``backward``, so there are no gradients to
    reset, and this function does not receive an optimizer.

    ``src_max_len`` and ``abs_max_sent_num`` are read from module-level
    globals, not from parameters.

    Args:
        model: Callable returning ``(salience_hat, salience_gt)`` for a batch.
        iterator: Iterable of evaluation batches.
        kl_loss: Criterion applied to ``log(salience_hat)`` and ``salience_gt``.

    Returns:
        Sum (not mean) of the per-batch scalar losses as a Python float.
    """
    model.eval()
    batch_loss = 0.
    with torch.no_grad():
        for _, data in tqdm(enumerate(iterator), disable=False):
            salience_hat, salience_gt = model(data, src_max_len, abs_max_sent_num)
            # The criterion receives log(salience_hat) as its first argument.
            loss = kl_loss(salience_hat.log(), salience_gt)
            batch_loss += float(loss.item())
            del loss, salience_hat, salience_gt
    return batch_loss

def train_salience_step(model, iterator, optimizer, kl_loss):
    """Run one training pass of the salience model and return the summed loss.

    ``src_max_len`` and ``abs_max_sent_num`` are read from module-level
    globals, not from parameters.

    Returns:
        Sum (not mean) of the per-batch scalar losses as a Python float.
    """
    model.train()
    running_loss = 0.
    for _, data in tqdm(enumerate(iterator), disable=False):
        optimizer.zero_grad()
        salience_hat, salience_gt = model(data, src_max_len, abs_max_sent_num)

        # The criterion receives log(salience_hat) as its first argument.
        log_salience_hat = salience_hat.log()
        loss = kl_loss(log_salience_hat, salience_gt)
        loss.backward()
        optimizer.step()

        running_loss += float(loss.item())
        del loss, log_salience_hat, salience_hat, salience_gt
    return running_loss

def build_optimizer(model, model_opt):
    """Create an Adagrad optimizer for ``model`` with a preset accumulator.

    Args:
        model: Module whose ``parameters()`` are optimised.
        model_opt: Mapping providing ``"learning_rate"`` and
            ``"adagrad_accum"``; both are converted with ``float``.

    Returns:
        The ``optim.Adagrad`` instance, with every parameter's ``'sum'`` state
        tensor filled with the ``"adagrad_accum"`` value.
    """
    learning_rate = float(model_opt["learning_rate"])
    optimizer = optim.Adagrad(model.parameters(), lr=learning_rate)
    # Adam alternative (disabled):
    # optimizer = optim.Adam(model.parameters(), lr=float(model_opt["learning_rate"]),
    #                        weight_decay=float(model_opt["weight_decay"]))

    accum_init = float(model_opt["adagrad_accum"])
    all_params = (param for group in optimizer.param_groups for param in group['params'])
    for param in all_params:
        # fill_ mutates the existing state tensor in place.
        optimizer.state[param]['sum'].fill_(accum_init)

    return optimizer

# ---- Configuration ---------------------------------------------------------
path = parse_data_path_cfg()
model_opt = parse_model_path_cfg()

# Boolean options are compared against the literal string "True"; any other
# value disables the option.
gpu = model_opt["gpu"] == 'True'
is_citation_func = model_opt["citation_function"] == "True"
load_word_embedding = model_opt["load_word_embedding"] == "True"

# Fixed size limits. src_max_len and abs_max_sent_num are also read as globals
# by the salience step functions above.
shard_size = 0
src_max_len = 200
tgt_max_len = 52
abs_max_sent_num = 49
node_max_neighbor = 736

batch_size = int(model_opt["batch_size"])
train_mode = int(model_opt["train_mode"])

# ---- Dispatch on train_mode ------------------------------------------------
# 1 = salience model, 2 = citation model, 3 = generation model.
# Any other value still loads the data and builds the model, but trains nothing.
if train_mode == 2:
    # The citation model uses its own data loader and needs no vocabulary.
    iterator = get_citation_DataLoader(path, batch_size)
    model = build_model(model_opt, None, gpu, None, checkpoint=True)
    optimizer = build_optimizer(model, model_opt)
    train_citation(model, iterator, optimizer, int(model_opt["epoch"]), node_max_neighbor)
else:
    word2index = get_vocab(path)
    word_embeddings = load_pretrained_embedding(path, word2index) if load_word_embedding else None
    train_iter, test_iter = get_DataLoader(path, word2index, batch_size, True)
    model = build_model(model_opt, len(word2index), gpu, word_embeddings, checkpoint=True)
    optimizer = build_optimizer(model, model_opt)
    if train_mode == 1:
        train_salience(model, train_iter, test_iter, optimizer, int(model_opt["epoch"]))
    elif train_mode == 3:
        loss_compute = build_loss_compute(model, word2index, gpu)
        train_generation(
            model, train_iter, test_iter, optimizer, loss_compute, int(model_opt["epoch"]),
            src_max_len, tgt_max_len, abs_max_sent_num, node_max_neighbor, gpu, is_citation_func)
