import torch
import os
from transformers import AdamW, get_linear_schedule_with_warmup
import torch.nn as nn
import logging
from tqdm import tqdm
from train_utils import *
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
logger = logging.getLogger(__name__)

def prepare_for_training(args, model, train_iter):
    """Build the optimizer and, optionally, the LR scheduler for a training run.

    Args:
        args: run configuration. Reads ``learning_rate``, ``epochs``,
            ``use_scheduler``, ``warmup_steps`` and, when present,
            ``gradient_accumulation_steps`` (treated as 1 if missing).
        model: module whose ``parameters()`` are handed to the optimizer.
        train_iter: sized iterable of training batches (``len()`` is used to
            size the schedule).

    Returns:
        Tuple ``(model, optimizer, scheduler)``. ``scheduler`` is ``None`` when
        ``args.use_scheduler`` is falsy.
    """
    optimizer = AdamW(model.parameters(), lr=args.learning_rate, correct_bias=True)
    if not args.use_scheduler:
        return model, optimizer, None

    # The training loop steps the scheduler once per *optimizer update*, i.e.
    # once every ``gradient_accumulation_steps`` batches. Sizing the schedule
    # in batches would stretch it, so the learning rate would never finish
    # its linear decay. Count optimizer updates instead.
    accumulation = max(1, getattr(args, 'gradient_accumulation_steps', 1) or 1)
    t_total = (len(train_iter) // accumulation) * args.epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total,
    )
    return model, optimizer, scheduler

def compute_loss(logits, target_tokens, kl_loss=None, beta=None, ignore_index=50256):
    """Compute the next-token cross-entropy, optionally combined with a KL term.

    Logits at position ``t`` are scored against the token at position
    ``t + 1``, so the last logit and the first target token are dropped.

    Args:
        logits: tensor whose last dimension is the vocabulary and whose
            second-to-last dimension is the sequence position.
        target_tokens: integer tensor of token ids whose last dimension is the
            sequence position.
        kl_loss: optional KL term added to the cross-entropy.
        beta: weight applied to ``kl_loss``. ``None`` means an unweighted sum
            (weight 1.0).
        ignore_index: label id excluded from the cross-entropy (default 50256,
            presumably the padding / end-of-text id of the tokenizer in use;
            confirm against the caller).

    Returns:
        Tuple ``(loss, ce_loss, kl_loss)``. ``loss`` equals ``ce_loss`` when
        ``kl_loss`` is ``None``; ``kl_loss`` is returned unchanged.
    """
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = target_tokens[..., 1:].contiguous()

    loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)
    ce_loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    if kl_loss is None:
        return ce_loss, ce_loss, kl_loss

    # ``beta`` defaults to None in the signature; multiplying None by a tensor
    # raises TypeError, so fall back to an unweighted KL term.
    kl_weight = 1.0 if beta is None else beta
    loss = ce_loss + kl_weight * kl_loss
    return loss, ce_loss, kl_loss

def train(model, train_iter, valid_iter, args):
    """Run the full training loop, validating and checkpointing after every epoch.

    Args:
        model: network to train. For ``args.model_type`` 'bartbase' or 'gpt2'
            its output must expose ``.loss``; for any other type the call
            ``model(**inputs)`` must return a 4-tuple whose first two items
            are the cross-entropy and KL losses.
        train_iter: sized iterable of training batches (dict-like).
        valid_iter: batches passed to ``valid`` at the end of each epoch.
        args: run configuration. Reads ``cycle_annealing``,
            ``linear_annealing``, ``cycle_iters``,
            ``gradient_accumulation_steps``, ``begin_epoch``, ``epochs``,
            ``model_type`` and ``log_step``, plus whatever
            ``prepare_for_training``, ``valid`` and ``save`` read.

    Returns:
        None. Progress is reported through ``logging`` and checkpoints are
        written by ``save``.
    """
    logging.info('begin trainging...')
    model, optimizer, scheduler = prepare_for_training(args, model, train_iter)
    # ``beta`` is the weight of the KL term. With either annealing flag it
    # starts near zero; otherwise it is fixed at 1.
    # NOTE(review): only the ``cycle_annealing`` branch below ever changes
    # beta. With ``linear_annealing`` alone it stays at 1e-5 for the whole
    # run -- confirm whether a linear schedule is missing.
    if args.cycle_annealing or args.linear_annealing:
        beta = 1e-5
        beta_0 = 1e-5
    else:
        beta = 1
    # Counts optimizer updates, not batches.
    global_step = 0
    
    # Optimizer updates performed per epoch.
    one_epoch_step = len(train_iter) // args.gradient_accumulation_steps
    # Half of each annealing cycle (measured in epochs) holds beta at beta_0,
    # the other half is used to increase it.
    beta_zero = beta_increase = args.cycle_iters // 2
    # Running sums since the last log line.
    running_loss = 0
    running_ce_loss = 0
    running_kl_loss = 0
    # Epoch numbers are 1-based and offset by ``begin_epoch``.
    for epoch in range(1 + args.begin_epoch, args.epochs + args.begin_epoch + 1):
        model.train()
        for i, inputs in enumerate(train_iter):
            if args.model_type == 'bartbase':
                output = model(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    decoder_input_ids=inputs['decoder_input_ids'],
                    labels=inputs['labels'],
                )
                loss = output.loss
            elif args.model_type == 'gpt2':
                # Fall back to the inputs themselves as labels when the batch
                # does not provide any.
                if 'labels' in inputs:
                    labels = inputs['labels']
                else:
                    labels = inputs['input_ids']
                output = model(
                    input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    labels=labels,
                )
                loss = output.loss
            else:
                # Any other model type returns the two loss terms separately.
                ce_loss, kl_loss, _, _ = model(**inputs)
                loss = ce_loss + beta * kl_loss
            # Scale so the gradients summed over the accumulated micro-batches
            # correspond to one averaged batch.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            # ``mean()`` reduces a per-replica loss vector to a scalar
            # (presumably for multi-GPU wrappers; a no-op on a scalar).
            loss = loss.mean()
            loss.backward()
            running_loss += loss.item()
            # ce/kl are only defined in the last branch above, so for
            # 'gpt2' / 'bartbase' these running sums stay at 0.
            if args.model_type not in ['gpt2', 'bartbase']:
                running_ce_loss += ce_loss.mean().item() / args.gradient_accumulation_steps
                running_kl_loss += kl_loss.mean().item() / args.gradient_accumulation_steps

            # ``i`` restarts at 0 every epoch, so micro-batches left over at
            # the end of an epoch do not trigger an update here; their
            # gradients remain accumulated because ``zero_grad`` only runs
            # after a step.
            if (i + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                if scheduler is not None:
                    scheduler.step()

                global_step += 1
                if args.cycle_annealing:
                    # Position inside the current cycle, measured in epochs.
                    one_period = epoch % args.cycle_iters
                    if one_period < beta_zero:
                        beta = beta_0
                    else:
                        # Fixed increment per optimizer update, capped at 1.0.
                        beta = min(1.0, beta + (1 - beta_0) / (beta_increase * one_epoch_step / 2))

                if global_step % args.log_step == 0:
                    # Log the averages over the last ``log_step`` updates,
                    # then reset the running sums.
                    logging.info('training loss: step [{}~{}], loss {}, ce_loss {}, kl_loss {}, lr {}, beta {}'.
                        format(global_step - args.log_step, global_step, running_loss / args.log_step, running_ce_loss / args.log_step, running_kl_loss / args.log_step,
                                optimizer.param_groups[0]['lr'], beta))
                    running_loss = 0
                    running_kl_loss = 0
                    running_ce_loss = 0

        # End of epoch: evaluate with the current beta, then checkpoint.
        valid(model, valid_iter, epoch, args, beta)
        save(model, args, epoch)
    logging.info('training finished')

def valid(model, valid_iter, epoch, args, beta=1):
    """Evaluate ``model`` on ``valid_iter`` and log the per-batch average losses.

    Args:
        model: network to evaluate; it is switched to eval mode.
        valid_iter: sized iterable of dict-like batches.
        epoch: epoch number, used only in the log messages.
        args: run configuration. Reads ``model_type``, ``eval_metrics`` and,
            when metrics are enabled, ``sample_times``.
        beta: weight of the KL term for models that return separate
            cross-entropy and KL losses.

    Returns:
        None. Results are reported through ``logging``.
    """
    model.eval()
    # Only models other than 'gpt2' / 'bartbase' report separate ce/kl terms.
    reports_ce_and_kl = args.model_type not in ['gpt2', 'bartbase']
    total_loss = 0
    total_ce = 0
    total_kl = 0
    with torch.no_grad():
        for batch in valid_iter:
            if args.model_type == 'bartbase':
                batch_loss = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    decoder_input_ids=batch['decoder_input_ids'],
                    labels=batch['labels'],
                ).loss
            elif args.model_type == 'gpt2':
                # Use the inputs as labels when the batch carries none.
                gold = batch['labels'] if 'labels' in batch else batch['input_ids']
                batch_loss = model(
                    input_ids=batch['input_ids'],
                    attention_mask=batch['attention_mask'],
                    labels=gold,
                ).loss
            else:
                ce_loss, kl_loss, _, _ = model(**batch)
                batch_loss = ce_loss + beta * kl_loss

            total_loss += batch_loss.mean().item()
            if reports_ce_and_kl:
                total_ce += ce_loss.mean().item()
                total_kl += kl_loss.mean().item()

        num_batches = len(valid_iter)
        valid_loss = total_loss / num_batches
        valid_ce_loss = total_ce / num_batches
        valid_kl_loss = total_kl / num_batches
        logging.info('valid result: epoch {}, loss {}, ce_loss {}, kl {}'.format(epoch, valid_loss, valid_ce_loss, valid_kl_loss))

        if not args.eval_metrics:
            return

        # Optional, more expensive metrics computed by the train_utils helpers.
        ppl, elbo, nll, kl = calc_iwnll(model, valid_iter, ns=args.sample_times)
        mi = calc_mi(model, valid_iter)
        au = calc_au(model, valid_iter)
        logging.info('valid result: epoch {}, ppl {}, elbo {}, nll {}, kl {}'.format(epoch, ppl, elbo, nll, kl))
        logging.info('valid result: epoch {}, mi {}, au {}'.format(epoch, mi, au))

def save(model, args, epoch):
    """Write the model's ``state_dict`` to ``<output_dir>/<model_name>/model_epoch_<epoch>.pt``.

    Args:
        model: module to checkpoint. If it exposes a ``module`` attribute
            (as wrappers such as ``torch.nn.DataParallel`` do), the wrapped
            module's weights are saved instead.
        args: run configuration. Reads ``output_dir`` and ``model_name``.
        epoch: epoch number embedded in the file name.

    Returns:
        None.
    """
    save_dir = os.path.join(args.output_dir, args.model_name)
    # ``exist_ok=True`` already covers the "directory exists" case, so no
    # separate (race-prone) existence check is needed.
    os.makedirs(save_dir, exist_ok=True)

    # Explicit attribute lookup instead of a bare ``except`` so unrelated
    # errors are not silently swallowed.
    model_to_save = getattr(model, 'module', model)
    save_path = os.path.join(save_dir, 'model_epoch_{}.pt'.format(epoch))
    torch.save(model_to_save.state_dict(), save_path)

# Marker strings removed from decoded text, in the order they are stripped.
_GENERATION_SPECIAL_TOKENS = ('<sep>', '<s>', '</s>', '<pad>', '<|endoftext|>', '<eos>')


def _strip_special_tokens(sen):
    """Remove special-token markers from ``sen`` and collapse all whitespace runs."""
    for token in _GENERATION_SPECIAL_TOKENS:
        sen = sen.replace(token, '')
    return ' '.join(sen.split())


def _write_lines(path, lines):
    """Write ``lines`` to ``path`` as UTF-8, one entry per line, no trailing newline."""
    # Explicit UTF-8 so decoded text with non-ASCII characters does not depend
    # on the platform's default encoding.
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines))


def generate(model, test_iter, tokenizer, args):
    """Generate text for every batch in ``test_iter`` and write it to disk.

    Sampling (top-k / top-p) is used when ``args.top_k > 0``; otherwise beam
    search with ``args.num_beams`` beams. For the conditional dataset types
    ('cvae', 'seq2seq', 'wp') the decoded targets and sources are written
    alongside the outputs. Generations that are empty after special-token
    stripping are dropped together with their target/source so the three
    files stay line-aligned.

    Args:
        model: network to generate with; it is switched to eval mode. For
            ``args.model_type != 'gpt2'`` generation runs on ``model.decoder``
            with a latent obtained from ``model.get_prior``.
        test_iter: iterable of dict-like batches.
        tokenizer: exposes ``bos_id``, ``eos_id``, ``pad_id`` and ``decode``.
        args: run configuration. Reads ``dataset_type``, ``model_type``,
            ``top_k``, ``top_p``, ``num_beams``, ``max_length``,
            ``repetition_penalty``, ``generation_output_dir``, ``model_name``
            and ``begin_epoch``.

    Returns:
        None. Files are written under
        ``<generation_output_dir>/<model_name>/``.
    """
    is_conditional = args.dataset_type in ['cvae', 'seq2seq', 'wp']
    has_condition = "conditional" if is_conditional else "unconditional"
    if args.top_k > 0:
        generate_param = "topk_{}".format(args.top_k)
    else:
        generate_param = "beamsearch_{}".format(args.num_beams)
    logging.info('{} generate with {}'.format(has_condition, generate_param))

    model.eval()
    try:
        model.decoder.config.is_encoder_decoder = False
    except AttributeError:
        # Best effort: models without ``decoder.config`` have nothing to
        # patch. Any other error is unexpected and should surface.
        pass

    output_list = []
    target_list = []
    source_list = []
    with torch.no_grad():
        for inputs in tqdm(test_iter):
            if args.dataset_type == 'seq2seq':
                source = inputs['input_ids']
                target = inputs['decoder_input_ids']
                ans = model.generate(
                    inputs['input_ids'],
                    num_beams=args.num_beams,
                    max_length=args.max_length,
                    repetition_penalty=args.repetition_penalty
                )
            else:
                target = inputs['input_ids']
                if is_conditional:
                    source = inputs['condition']
                batch_size = target.size(0)
                device = target.device
                # By default decoding starts from the first target token only.
                input_ids = target[:, 0].unsqueeze(1)
                model_kwargs = {}
                if args.dataset_type == 'cvae':
                    encoder_attention_mask = inputs['condition_mask']
                    prior_latent = model.get_prior(batch_size, device, condition=source, condition_mask=encoder_attention_mask)
                    encoder_hidden_states = model.get_encoder_hidden_states(source, encoder_attention_mask)
                    if args.top_k <= 0:
                        # Beam search: repeat each row once per beam.
                        encoder_hidden_states = encoder_hidden_states.repeat_interleave(args.num_beams, dim=0)
                        encoder_attention_mask = encoder_attention_mask.repeat_interleave(args.num_beams, dim=0)
                    model_kwargs['encoder_hidden_states'] = encoder_hidden_states
                    model_kwargs['attention_mask'] = encoder_attention_mask
                elif args.dataset_type == 'wp':
                    if args.model_type != 'gpt2':
                        prior_latent = model.get_prior(batch_size, device, condition=inputs['condition'], condition_mask=inputs['condition_mask'])
                        model_kwargs['attention_mask'] = inputs['condition_mask']
                    else:
                        prior_latent = None
                    # For 'wp' the condition itself is the decoding prefix.
                    input_ids = inputs['condition']
                else:
                    if args.model_type != 'gpt2':
                        prior_latent = model.get_prior(batch_size, device)
                    else:
                        prior_latent = None

                gen_model = model.decoder if args.model_type != 'gpt2' else model

                if args.top_k > 0:
                    latent = prior_latent
                    decoding_kwargs = {
                        'do_sample': True,
                        'top_k': args.top_k,
                        'top_p': args.top_p,
                    }
                else:
                    # Beam search: the latent must be repeated per beam too.
                    if prior_latent is None:
                        latent = None
                    elif isinstance(prior_latent, tuple):
                        latent = [item.repeat_interleave(args.num_beams, dim=0) for item in prior_latent]
                    else:
                        latent = prior_latent.repeat_interleave(args.num_beams, dim=0)
                    decoding_kwargs = {'num_beams': args.num_beams}

                ans = gen_model.generate(
                    input_ids,
                    latent=latent,
                    bos_token_id=tokenizer.bos_id,
                    eos_token_id=tokenizer.eos_id,
                    pad_token_id=tokenizer.pad_id,
                    min_length=input_ids.size(-1) + 3,
                    max_length=min(args.max_length, 1024),
                    repetition_penalty=args.repetition_penalty,
                    **decoding_kwargs,
                    **model_kwargs,
                )

            ans = ans.cpu().numpy()
            if is_conditional:
                target = target.cpu().numpy()
                source = source.cpu().numpy()
            for i, generated in enumerate(ans):
                text_ans = _strip_special_tokens(
                    tokenizer.decode(generated, clean_up_tokenization_spaces=False))
                if not text_ans:
                    # Skip empty generations (and their target/source) so the
                    # output files stay aligned line by line.
                    continue
                output_list.append(text_ans)
                if is_conditional:
                    target_list.append(_strip_special_tokens(
                        tokenizer.decode(target[i], clean_up_tokenization_spaces=False)))
                    source_list.append(_strip_special_tokens(
                        tokenizer.decode(source[i], clean_up_tokenization_spaces=False)))

    save_dir = os.path.join(args.generation_output_dir, args.model_name)
    name_template = '{}_output_{}_epoch_{}_{{}}.txt'.format(has_condition, generate_param, args.begin_epoch)
    logging.info('generation output save at {}'.format(save_dir))
    os.makedirs(save_dir, exist_ok=True)
    _write_lines(os.path.join(save_dir, name_template.format('outputs')), output_list)
    if is_conditional:
        _write_lines(os.path.join(save_dir, name_template.format('targets')), target_list)
        _write_lines(os.path.join(save_dir, name_template.format('sources')), source_list)

def interpolating(model, tokenizer, sen1, sen2, args):
    """Decode and print sentences along a line between the latents of two sentences.

    Both sentences are encoded, one latent is sampled for each, and for
    ``tau`` in 0.1, 0.2, ..., 0.9 the mixture
    ``tau * latent1 + (1 - tau) * latent2`` is decoded with beam search. The
    first beam-search result for each ``tau`` is printed.

    Args:
        model: network exposing ``get_encode_states`` and ``decoder.generate``;
            it is switched to eval mode.
        tokenizer: exposes ``bos_id``, ``eos_id``, ``pad_id``, ``encode`` and
            ``decode``.
        sen1: first sentence (weighted by ``tau``).
        sen2: second sentence (weighted by ``1 - tau``).
        args: run configuration. Reads ``device``, ``num_beams``,
            ``max_length`` and ``repetition_penalty``.

    Returns:
        None. Results are printed to stdout.
    """
    # Match ``generate`` / ``visual_attention``: run in eval mode so dropout
    # does not perturb the encodings or the decoding.
    model.eval()

    def _encode_and_sample(sentence):
        """Return the input tensor for ``sentence`` and one sampled latent."""
        token_ids = [tokenizer.bos_id] + tokenizer.encode(sentence)
        tensor = torch.LongTensor(token_ids).unsqueeze(0).to(args.device)
        mean, sigma = model.get_encode_states(tensor, None)
        sampled, _ = Normal(mean, sigma).sample()
        return tensor, sampled

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        sen1_tensor, latent1 = _encode_and_sample(sen1)
        _, latent2 = _encode_and_sample(sen2)

        # Decoding always starts from the first token of ``sen1``; this does
        # not depend on ``tau``.
        input_ids = sen1_tensor[:, 0].unsqueeze(1)

        for i in range(1, 10):
            tau = i / 10
            latent = tau * latent1 + (1 - tau) * latent2
            # NOTE(review): 12 is hard-coded; presumably the number of decoder
            # layers that each receive one slice of the latent -- confirm.
            latent = torch.chunk(latent, 12, dim=-1)
            # Beam search: repeat each slice once per beam.
            latent = [item.repeat_interleave(args.num_beams, dim=0) for item in latent]

            ans = model.decoder.generate(
                input_ids,
                latent=latent,
                bos_token_id=tokenizer.bos_id,
                eos_token_id=tokenizer.eos_id,
                pad_token_id=tokenizer.pad_id,
                num_beams=args.num_beams,
                min_length=input_ids.size(-1) + 3,
                max_length=min(args.max_length, 1024),
                repetition_penalty=args.repetition_penalty,
            )
            ans = ans.cpu().numpy()
            text_ans = tokenizer.decode(ans[0], clean_up_tokenization_spaces=False)
            print('tau ====== ', tau)
            print(text_ans)

def _save_attention_heatmap(attn_weight, text, layer, tag):
    """Draw one attention heatmap and save it as EPS and PNG under ``./draw``.

    Args:
        attn_weight: 2-D array of weights; rows are labelled ``['z'] + text``
            and columns ``text``.
        text: list of column labels.
        layer: layer index, used in the title and file name.
        tag: 'average', 'max' or 'min'; used in the title and file name.
    """
    df = pd.DataFrame(attn_weight, index=['z'] + text, columns=text)
    f, ax = plt.subplots()
    try:
        sns.heatmap(data=df, ax=ax, cmap='Blues')
        ax.invert_yaxis()
        plt.setp(ax.get_yticklabels(), rotation=360)
        ax.set_title('Attention Layer: {}, {}'.format(layer, tag))

        stem = './draw/attention_layer{}_{}_head'.format(layer, tag)
        f.savefig(stem + '.eps', bbox_inches='tight', dpi=600, format='eps')
        f.savefig(stem + '.png', bbox_inches='tight')
    finally:
        # One figure is created per layer and per tag; close it so figures do
        # not pile up in memory.
        plt.close(f)


def visual_attention(model, tokenizer, sen1, args):
    """Plot per-layer attention heatmaps for ``sen1`` and save them under ``./draw``.

    For every layer three heatmaps are written (each as ``.eps`` and ``.png``):
    the average over heads, the head with the largest mean weight at index 0
    of the last attention dimension, and the head with the smallest.

    Args:
        model: network exposing ``get_attention``; it is switched to eval mode.
        tokenizer: exposes ``bos_id`` and ``encode``.
        sen1: sentence to visualise. Axis labels come from ``sen1.split()``,
            so the tokenizer is assumed to yield one token per
            whitespace-separated word -- TODO confirm.
        args: run configuration. Reads ``device``.

    Returns:
        None.
    """
    model.eval()
    sen1_token = [tokenizer.bos_id] + tokenizer.encode(sen1)
    sen1_tensor = torch.LongTensor(sen1_token).unsqueeze(0).to(args.device)
    attention = model.get_attention(sen1_tensor, attention_mask=None)

    text = ['<s>'] + sen1.split()
    # ``savefig`` does not create missing directories.
    os.makedirs('./draw', exist_ok=True)

    def _average_head(attn):
        """Average the attention over the head dimension (dim 0)."""
        return torch.mean(attn, dim=0)

    def _head_by_z_weight(attn, pick):
        """Return the head chosen by ``pick`` from the mean index-0 weights."""
        # Index 0 of the last dimension is presumably the latent 'z' position
        # (the heatmap labels the extra row 'z') -- confirm against the model.
        z_weight = attn[:, :, 0]
        return attn[pick(torch.mean(z_weight, dim=-1)).item()]

    selectors = (
        ('average', _average_head),
        ('max', lambda attn: _head_by_z_weight(attn, torch.argmax)),
        ('min', lambda attn: _head_by_z_weight(attn, torch.argmin)),
    )
    for tag, select in selectors:
        for layer in range(len(attention)):
            # First (and only) item of the batch.
            attn = attention[layer][0]
            attn_weight = select(attn).transpose(1, 0).cpu().detach().numpy()
            _save_attention_heatmap(attn_weight, text, layer, tag)