import os
import time
import torch
import numpy as np
from torch import nn, optim
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam

from evaluator import Evaluator
from utils import word_drop, tensor2text


def get_lengths(tokens, eos_idx):
    """Return, per row of ``tokens``, the length up to and including the first ``eos_idx``.

    Positions before the first ``eos_idx`` along dim 1 are counted and one is
    added for the ``<eos>`` itself. A row containing no ``eos_idx`` therefore
    gets its full width plus one.
    """
    eos_seen = (tokens == eos_idx).cumsum(dim=1)
    n_before_eos = eos_seen.eq(0).long().sum(dim=-1)
    return n_before_eos + 1  # count the <eos> token itself


def batch_preprocess(batch, pad_idx, eos_idx, reverse=False):
    """Merge a (positive, negative) batch pair into one labelled token batch.

    The narrower of the two token tensors is right-padded with ``pad_idx``
    along dim 1 so both share the same width; they are then stacked along
    dim 0.

    Args:
        batch: pair ``(batch_pos, batch_neg)`` of token tensors indexed as
            ``[row, position]`` (assumed ``(batch, seq)`` — TODO confirm).
        pad_idx: token id used for padding.
        eos_idx: token id of ``<eos>``, forwarded to ``get_lengths``.
        reverse: if True, the negative rows come first and the style labels
            are swapped accordingly.

    Returns:
        ``(tokens, lengths, styles)``: the concatenated tokens, their lengths
        from ``get_lengths``, and per-row style labels (1 for rows that came
        from ``batch_pos``, 0 for rows from ``batch_neg``).
    """
    batch_pos, batch_neg = batch
    diff = batch_pos.size(1) - batch_neg.size(1)
    # The pad block must have as many rows as the tensor it is appended to:
    # the two batches may hold different numbers of sentences, so taking the
    # shape from the *other* batch (as full_like did) breaks torch.cat.
    if diff < 0:
        pad = batch_pos.new_full((batch_pos.size(0), -diff), pad_idx)
        batch_pos = torch.cat((batch_pos, pad), 1)
    elif diff > 0:
        pad = batch_neg.new_full((batch_neg.size(0), diff), pad_idx)
        batch_neg = torch.cat((batch_neg, pad), 1)

    pos_styles = torch.ones_like(batch_pos[:, 0])
    neg_styles = torch.zeros_like(batch_neg[:, 0])

    if reverse:
        batch_pos, batch_neg = batch_neg, batch_pos
        pos_styles, neg_styles = neg_styles, pos_styles

    tokens = torch.cat((batch_pos, batch_neg), 0)
    lengths = get_lengths(tokens, eos_idx)
    styles = torch.cat((pos_styles, neg_styles), 0)

    return tokens, lengths, styles


def f_step(config, vocab, model_F, model_D, optimizer_F, batch, temperature, drop_decay):
    """Run one optimisation step of the generator ``model_F``.

    Losses (each multiplied by its ``config`` factor):
      * self-reconstruction: the word-dropped input is passed through
        ``model_F`` with ``reverse=False`` then ``reverse=True`` using
        ``raw_styles``; masked per-token cross-entropy against the input tokens.
      * cycle reconstruction: the input is mapped with ``rev_styles``, the
        resulting soft tokens are mapped back with ``raw_styles``; masked
        per-token cross-entropy against the input tokens.
      * content: cross-entropy between ``cyc_z`` and ``gen_z``.
      * style: ``model_D``'s NLL of ``rev_styles + 1`` on the transferred soft
        tokens.

    The self-reconstruction loss is back-propagated first and the other three
    together afterwards. ``model_F`` gradients are then clipped to norm 5 and
    ``optimizer_F`` steps (``model_D`` is not stepped here).

    Returns:
        ``(slf_rec_loss, cyc_rec_loss, content_loss, style_loss)`` as floats.
    """
    pad_idx = vocab.stoi['<pad>']
    eos_idx = vocab.stoi['<eos>']

    # Per-token cross-entropy so padding positions can be masked out before
    # summing. With the default 'mean' reduction the mask multiplied a scalar,
    # so padding was averaged into the loss and the mask had no effect.
    token_ce = nn.CrossEntropyLoss(reduction='none')
    # Mean-reduced cross-entropy, used only for the content loss.
    ce = nn.CrossEntropyLoss()
    nll = nn.NLLLoss(reduction='none')

    inp_tokens, inp_lengths, _ = batch_preprocess(batch, pad_idx, eos_idx)
    inp_embeddings = model_F.embeddings(inp_tokens)

    raw_styles, rev_styles, _, _ = model_F.AaCL.AaDL.style_attention(inp_embeddings, inp_lengths)

    batch_size = inp_tokens.size(0)
    token_mask = (inp_tokens != pad_idx).float()  # 1 -> real token, 0 -> padding

    optimizer_F.zero_grad()

    # --- self-reconstruction loss ---
    # NOTE(review): word_drop receives embeddings here, not token ids; confirm
    # this matches utils.word_drop's expected input.
    noise_inp_tokens = word_drop(inp_embeddings, inp_lengths, config.inp_drop_prob * drop_decay, vocab)

    slf_z = model_F(noise_inp_tokens, inp_tokens, inp_lengths, raw_styles, reverse=False, temperature=temperature)
    slf_log_probs = model_F(slf_z, inp_tokens, inp_lengths, raw_styles, reverse=True, temperature=temperature)

    slf_rec_loss = token_ce(slf_log_probs.transpose(1, 2), inp_tokens) * token_mask
    slf_rec_loss = slf_rec_loss.sum() / batch_size
    slf_rec_loss *= config.slf_factor
    slf_rec_loss.backward()

    # --- cycle-consistency loss ---
    # NOTE(review): the backward() above frees the graph built so far. Reusing
    # raw_styles / rev_styles below is only safe if they carry no autograd
    # history (rev_styles is used as an integer NLL target, which suggests so).
    gen_z = model_F(inp_tokens, inp_tokens, inp_lengths, rev_styles, reverse=False, temperature=temperature)
    gen_log_probs = model_F(gen_z, inp_tokens, inp_lengths, rev_styles, reverse=True, temperature=temperature)

    gen_soft_tokens = gen_log_probs.exp()
    gen_lengths = get_lengths(gen_soft_tokens.argmax(-1), eos_idx)

    cyc_z = model_F(gen_soft_tokens, inp_tokens, gen_lengths, raw_styles,
                    reverse=False, temperature=temperature)
    cyc_log_probs = model_F(cyc_z, inp_tokens, gen_lengths, raw_styles,
                            reverse=True, temperature=temperature)

    cyc_rec_loss = token_ce(cyc_log_probs.transpose(1, 2), inp_tokens) * token_mask
    cyc_rec_loss = cyc_rec_loss.sum() / batch_size
    cyc_rec_loss *= config.cyc_factor

    # --- content loss ---
    content_loss = ce(cyc_z, gen_z)
    content_loss *= config.content_factor

    # --- style consistency loss ---
    style_log_probs = model_D(gen_soft_tokens, gen_lengths, rev_styles)
    style_labels = rev_styles + 1

    style_loss = nll(style_log_probs, style_labels)
    style_loss = style_loss.sum() / batch_size
    style_loss *= config.adv_factor

    (content_loss + cyc_rec_loss + style_loss).backward()

    clip_grad_norm_(model_F.parameters(), 5)
    optimizer_F.step()
    return slf_rec_loss.item(), cyc_rec_loss.item(), content_loss.item(), style_loss.item()


def train(config, vocab, model_F, model_D, input_emb, train_iters, test_iters):
    """Train ``model_F`` in an endless loop with periodic logging and evaluation.

    Each iteration computes the current word-drop decay and temperature from
    their schedules, runs one ``f_step`` on the next batch, and records the
    four returned losses. Every ``config.log_steps`` steps the mean of the
    recorded losses is printed. Every ``config.eval_steps`` steps the
    recorded losses are cleared and ``auto_eval`` runs.

    Only ``model_F``'s parameters are optimised (Adam, ``config.lr_F``,
    weight decay ``config.L2``).

    Args:
        input_emb: currently unused; kept for call-site compatibility.
        train_iters: iterable of training batches; it is re-iterated when
            exhausted.
        test_iters: object with ``pos_iter`` / ``neg_iter``, passed to
            ``auto_eval``.

    Side effects:
        Sets ``config.save_folder`` to a timestamped directory under
        ``config.save_path`` and creates it together with a ``ckpts`` subfolder.
    """
    optimizer_F: Adam = optim.Adam(model_F.parameters(), lr=config.lr_F, weight_decay=config.L2)

    f_slf_loss_list = []
    f_cyc_loss_list = []
    f_content_loss_list = []
    f_style_loss_list = []

    global_step = 0
    model_F.train()

    config.save_folder = config.save_path + '/' + str(time.strftime('%b%d%H%M%S', time.localtime()))
    os.makedirs(config.save_folder)
    os.makedirs(config.save_folder + '/ckpts')
    print('Save Path:', config.save_folder)
    print('Training start......')

    def calc_temperature(temperature_config, step):
        """Piecewise-linear schedule over ``[(value, start_step), ...]`` breakpoints.

        For ``s_i <= step < s_{i+1}`` the value is linearly interpolated between
        the two breakpoints' values. Any step not inside such an interval
        (before the first breakpoint, or at/after the last) yields the last
        breakpoint's value.
        """
        num = len(temperature_config)
        for i in range(num):
            t_a, s_a = temperature_config[i]
            if i == num - 1:
                return t_a
            t_b, s_b = temperature_config[i + 1]
            if s_a <= step < s_b:
                k = (step - s_a) / (s_b - s_a)
                temperature = (1 - k) * t_a + k * t_b
                return temperature

    batch_iters = iter(train_iters)

    # training
    while True:
        drop_decay = calc_temperature(config.drop_rate_config, global_step)
        temperature = calc_temperature(config.temperature_config, global_step)

        try:
            batch = next(batch_iters)
        except StopIteration:
            # A finite training iterable was exhausted: start a new pass over it.
            batch_iters = iter(train_iters)
            batch = next(batch_iters)

        # f_step has no embedding parameter. Passing input_emb here shifted
        # every later positional argument and raised TypeError (9 args for 8).
        f_slf_loss, f_cyc_loss, f_content_loss, f_style_loss = f_step(
            config, vocab, model_F, model_D, optimizer_F, batch, temperature, drop_decay)
        f_slf_loss_list.append(f_slf_loss)
        f_cyc_loss_list.append(f_cyc_loss)
        f_content_loss_list.append(f_content_loss)
        f_style_loss_list.append(f_style_loss)

        global_step += 1

        if global_step % config.log_steps == 0:
            avrg_f_slf_loss = np.mean(f_slf_loss_list)
            avrg_f_cyc_loss = np.mean(f_cyc_loss_list)
            avrg_f_content_loss = np.mean(f_content_loss_list)
            avrg_f_style_loss = np.mean(f_style_loss_list)
            log_str = '[iter {}] ' + \
                      'f_slf_loss: {:.4f}  f_cyc_loss: {:.4f}  ' + \
                      'f_content_loss: {:.4f}  f_style_loss: {:.4f}  temp: {:.4f}  drop: {:.4f}'
            print(log_str.format(
                global_step,
                avrg_f_slf_loss, avrg_f_cyc_loss, avrg_f_content_loss, avrg_f_style_loss,
                temperature, config.inp_drop_prob * drop_decay
            ))

        if global_step % config.eval_steps == 0:
            # Loss histories are cleared only here, so the log averages above
            # cover every step since the previous evaluation.
            f_slf_loss_list = []
            f_cyc_loss_list = []
            f_content_loss_list = []
            f_style_loss_list = []

            # save model
            # torch.save(model_F.state_dict(), config.save_folder + '/ckpts/' + str(global_step) + '_F.pth')
            auto_eval(config, vocab, model_F, test_iters, global_step, temperature)


def _print_samples(title, gold, raw, rev, ref, indices, file=None):
    """Print a gold/raw/rev/ref block for each index, then a closing divider.

    ``file=None`` makes ``print`` write to ``sys.stdout``.
    """
    for idx in indices:
        print('*' * 20, title, '*' * 20, file=file)
        print('[gold]', gold[idx], file=file)
        print('[raw ]', raw[idx], file=file)
        print('[rev ]', rev[idx], file=file)
        print('[ref ]', ref[idx], file=file)
    print('*' * 20, '********', '*' * 20, file=file)


def auto_eval(config, vocab, model_F, test_iters, global_step, temperature):
    """Evaluate ``model_F`` on the neg/pos test sets and record the results.

    For every test sentence ``model_F`` produces a "raw" output with the
    sentence's own style label and a "rev" output with the opposite label.
    The rev outputs are scored with ``Evaluator`` (the ``yelp_acc_*``,
    ``yelp_ref_bleu_*`` and ``yelp_ppl`` methods).

    Results are written to three places:
      * stdout: five random samples per style, then the metrics;
      * ``<save_folder>/eval_log.txt``: the metrics, appended;
      * ``<save_folder>/<global_step>.txt``: the metrics and every sample.

    ``model_F`` is switched to eval mode for the duration and back to train
    mode afterwards, also when evaluation raises.
    """
    model_F.eval()
    try:
        eos_idx = vocab.stoi['<eos>']

        def inference(data_iter, raw_style):
            """Decode every batch of ``data_iter`` with its own style and the opposite one."""
            gold_text = []
            raw_output = []
            rev_output = []
            for batch in data_iter:
                inp_tokens = batch.text
                inp_lengths = get_lengths(inp_tokens, eos_idx)
                raw_styles = torch.full_like(inp_tokens[:, 0], raw_style)
                rev_styles = 1 - raw_styles

                with torch.no_grad():
                    raw_z = model_F(inp_tokens, inp_tokens, inp_lengths, raw_styles,
                                    reverse=False, temperature=temperature)
                    raw_log_probs = model_F(raw_z, inp_tokens, inp_lengths, raw_styles,
                                            reverse=True, temperature=temperature)
                    rev_z = model_F(inp_tokens, inp_tokens, inp_lengths, rev_styles,
                                    reverse=False, temperature=temperature)
                    rev_log_probs = model_F(rev_z, inp_tokens, inp_lengths, rev_styles,
                                            reverse=True, temperature=temperature)

                gold_text += tensor2text(vocab, inp_tokens.cpu())
                raw_output += tensor2text(vocab, raw_log_probs.argmax(-1).cpu())
                rev_output += tensor2text(vocab, rev_log_probs.argmax(-1).cpu())

            return gold_text, raw_output, rev_output

        pos_iter = test_iters.pos_iter
        neg_iter = test_iters.neg_iter

        # Each result is a pair indexed by source style: [0] = neg inputs, [1] = pos inputs.
        gold_text, raw_output, rev_output = zip(inference(neg_iter, 0), inference(pos_iter, 1))

        evaluator = Evaluator()
        # NOTE(review): assumed to be [neg references, pos references], aligned
        # with the test sentences — confirm against Evaluator.yelp_ref.
        ref_text = evaluator.yelp_ref

        acc_neg = evaluator.yelp_acc_0(rev_output[0])  # neg inputs transferred: outputs should be style 1
        acc_pos = evaluator.yelp_acc_1(rev_output[1])  # pos inputs transferred: outputs should be style 0
        bleu_neg = evaluator.yelp_ref_bleu_0(rev_output[0])
        bleu_pos = evaluator.yelp_ref_bleu_1(rev_output[1])
        ppl_neg = evaluator.yelp_ppl(rev_output[0])
        ppl_pos = evaluator.yelp_ppl(rev_output[1])

        # Five random samples per style on stdout.
        _print_samples('neg sample', gold_text[0], raw_output[0], rev_output[0], ref_text[0],
                       (np.random.randint(len(rev_output[0])) for _ in range(5)))
        _print_samples('pos sample', gold_text[1], raw_output[1], rev_output[1], ref_text[1],
                       (np.random.randint(len(rev_output[1])) for _ in range(5)))

        metrics = ('acc_pos: {:.4f} acc_neg: {:.4f} '
                   'bleu_pos: {:.4f} bleu_neg: {:.4f} '
                   'ppl_pos: {:.4f} ppl_neg: {:.4f}\n').format(
            acc_pos, acc_neg, bleu_pos, bleu_neg, ppl_pos, ppl_neg,
        )
        print('[auto_eval] ' + metrics)

        # save output
        save_file = config.save_folder + '/' + str(global_step) + '.txt'
        eval_log_file = config.save_folder + '/eval_log.txt'
        with open(eval_log_file, 'a') as fl:
            print('iter{:5d}:  '.format(global_step) + metrics, file=fl)
        with open(save_file, 'w') as fw:
            print('[auto_eval] ' + metrics, file=fw)
            _print_samples('neg sample', gold_text[0], raw_output[0], rev_output[0], ref_text[0],
                           range(len(rev_output[0])), file=fw)
            _print_samples('pos sample', gold_text[1], raw_output[1], rev_output[1], ref_text[1],
                           range(len(rev_output[1])), file=fw)
    finally:
        # Restore training mode even if evaluation fails part-way.
        model_F.train()
