import os
import math
import torch

import pandas as pd
import numpy as np

from tqdm import tqdm

from transformers import BartTokenizer, BartConfig, BartForConditionalGeneration
from transformers import AdamW, get_linear_schedule_with_warmup

from my_datasets.dataloader import MyData
from my_datasets.dataloader_multiple_tasks_evaluation import MyMultipleTasksEvaluationData
from modules.bart import MyBart
from utils import trim_batch, get_tasks_list


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` dropped from each key.

    Keys without that prefix are kept unchanged. The prefix is presumably the
    one added when a model is wrapped for multi-GPU training (e.g.
    ``torch.nn.DataParallel``) -- confirm against how checkpoints are produced.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_model(model_name, checkpoint=None):
    """Build a ``MyBart`` model, optionally initialised from a checkpoint file.

    Args:
        model_name: identifier/path passed to ``MyBart.from_pretrained``.
        checkpoint: path to a state dict saved with ``torch.save``, or ``None``
            to use the pretrained weights of ``model_name``.

    Returns:
        The model, moved to CUDA when a CUDA device is available.
    """
    if checkpoint is None:
        model = MyBart.from_pretrained(model_name)
    else:
        # Load onto CPU first so that a checkpoint holding GPU tensors can
        # still be read on a CPU-only host; the model is moved to CUDA below.
        state_dict = torch.load(checkpoint, map_location="cpu")
        model = MyBart.from_pretrained(
            model_name, state_dict=_strip_module_prefix(state_dict)
        )

    if torch.cuda.is_available():
        model.to(torch.device("cuda"))
    return model


def run(args, logger):
    """Entry point: optionally train, then optionally evaluate on dev and test.

    Args:
        args: run configuration. Attributes read here: ``model``,
            ``custom_tasks_splits``, ``do_train``, ``do_predict``,
            ``train_dir``, ``predict_dir``, ``output_dir``, ``checkpoint``,
            ``predict_checkpoint``, ``weight_decay``, ``learning_rate``,
            ``adam_epsilon``, ``gradient_accumulation_steps``,
            ``total_steps`` and ``warmup_ratio``. ``args.warmup_steps`` is
            *written* (``total_steps * warmup_ratio``).
        logger: logger used for progress messages.

    Side effects:
        When ``args.do_predict`` is set, writes ``eval-dev-performance.csv``
        and ``eval-test-performance.csv`` into ``args.output_dir``. Training
        artefacts are written by ``train``.
    """
    tokenizer = BartTokenizer.from_pretrained(args.model)

    all_tasks = get_tasks_list(args.custom_tasks_splits, "train")

    if args.do_train:
        logger.info("Training on the following tasks: {}".format(all_tasks))

        train_data = MyData(logger, args, args.train_dir, tasks=all_tasks, data_type="train", is_training=True)
        train_data.load_dataset(tokenizer, mode="simple")
        train_data.load_dataloader(mode="simple")

        # Dev split in the training format: used for the dev loss.
        dev_data = MyData(logger, args, args.train_dir, tasks=all_tasks, data_type="dev", is_training=False)
        dev_data.load_dataset(tokenizer, mode="simple")
        dev_data.load_dataloader(mode="simple")

        # Dev split in the per-task evaluation format: used for task metrics.
        dev_data2 = MyMultipleTasksEvaluationData(logger, args, args.predict_dir, split="dev", tasks=all_tasks)
        dev_data2.load_dataset(tokenizer)
        dev_data2.load_dataloader()

        model = _load_model(args.model, args.checkpoint)

        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ['bias', 'LayerNorm.weight']
        decay_params = []
        no_decay_params = []
        for name, param in model.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)
        optimizer_grouped_parameters = [
            {'params': decay_params, 'weight_decay': args.weight_decay, 'lr': args.learning_rate},
            {'params': no_decay_params, 'weight_decay': 0.0, 'lr': args.learning_rate},
        ]

        optimizer = AdamW(optimizer_grouped_parameters, eps=args.adam_epsilon)

        steps_per_epoch = math.ceil(len(train_data.dataloader) / args.gradient_accumulation_steps)
        args.warmup_steps = args.total_steps * args.warmup_ratio
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.total_steps,
        )
        logger.info("#Batches per epoch={}, #Steps per epoch={}, #Total steps={}, #Warmup steps={}".format(
            len(train_data.dataloader), steps_per_epoch, args.total_steps, args.warmup_steps
        ))

        train(args, logger, model, train_data, dev_data, dev_data2, optimizer, scheduler)

    if args.do_predict:
        # After training in the same run, evaluate the best checkpoint that
        # ``train`` wrote; otherwise use the user-specified one.
        checkpoint = os.path.join(args.output_dir, "best-model.pt" if args.do_train else args.predict_checkpoint)
        logger.info("Loading checkpoint from {}".format(checkpoint))
        model = _load_model(args.model, checkpoint)

        dev_data = MyMultipleTasksEvaluationData(logger, args, args.predict_dir, split="dev", tasks=all_tasks)
        dev_data.load_dataset(tokenizer)
        dev_data.load_dataloader()

        df_dev_performance, avg_performance = predict(args, logger, model, dev_data)
        logger.info("[Eval] Dev average performance: {}".format(avg_performance))
        df_dev_performance.to_csv(os.path.join(args.output_dir, "eval-dev-performance.csv"))

        test_data = MyMultipleTasksEvaluationData(logger, args, args.predict_dir, split="test", tasks=all_tasks)
        test_data.load_dataset(tokenizer)
        test_data.load_dataloader()

        df_test_performance, avg_performance = predict(args, logger, model, test_data)
        logger.info("[Eval] Test average performance: {}".format(avg_performance))
        df_test_performance.to_csv(os.path.join(args.output_dir, "eval-test-performance.csv"))


def _save_model_state(model, path):
    """Write ``model``'s state dict to ``path`` with every tensor moved to CPU."""
    cpu_state = {k: v.cpu() for (k, v) in model.state_dict().items()}
    torch.save(cpu_state, path)


def train(args, logger, model, train_data, dev_data, dev_data2, optimizer, scheduler):
    """Run the training loop with periodic evaluation and checkpointing.

    Args:
        args: run configuration. Attributes read here: ``output_dir``,
            ``num_train_epochs``, ``gradient_accumulation_steps``,
            ``max_grad_norm``, ``eval_period``, ``save_period`` and
            ``total_steps``.
        logger: logger used for progress messages.
        model: model to optimise; called with ``is_training=True`` and
            expected to return a scalar loss tensor.
        train_data: object exposing ``dataloader`` and ``tokenizer``.
        dev_data: data passed to ``validate`` for the dev loss.
        dev_data2: data passed to ``predict`` for dev task performance.
        optimizer: optimizer stepped every ``gradient_accumulation_steps``
            batches.
        scheduler: learning-rate scheduler stepped together with the optimizer.

    Side effects (all under ``args.output_dir``):
        * ``dev_performance_logs/<step>-steps.csv`` at every evaluation.
        * ``dev_loss.csv`` rewritten at every evaluation.
        * ``best-model.pt`` whenever average dev performance improves.
        * ``<step>-steps/model.pt`` every ``save_period`` optimizer steps.
        * ``last-model.pt`` when the loop ends.

    Training stops early when the loss becomes NaN or once ``global_step``
    reaches ``args.total_steps``.
    """
    model.train()
    global_batch = 0
    global_step = 0
    total_target_tokens = 0
    train_losses = []
    best_avg_performance = -1.0
    stop_training = False

    df = pd.DataFrame(columns=["global_steps", "train_loss", "dev_loss", "dev_performance"])

    os.makedirs(os.path.join(args.output_dir, "dev_performance_logs"), exist_ok=True)

    logger.info("Starting training!")
    for epoch in range(int(args.num_train_epochs)):
        for batch in tqdm(train_data.dataloader, desc="Epoch {}".format(epoch)):

            global_batch += 1

            if torch.cuda.is_available():
                batch = [b.to(torch.device("cuda")) for b in batch]

            pad_token_id = train_data.tokenizer.pad_token_id

            input_ids, attention_mask = trim_batch(batch[0], pad_token_id, batch[1])
            decoder_input_ids, decoder_attention_mask = trim_batch(batch[2], pad_token_id, batch[3])

            loss = model(input_ids=input_ids, attention_mask=attention_mask,
                         decoder_input_ids=decoder_input_ids,
                         decoder_attention_mask=decoder_attention_mask,
                         is_training=True)

            total_target_tokens += decoder_attention_mask.sum().item()

            if torch.isnan(loss).item():
                logger.info("Stop training because loss=%s" % (loss.data))
                stop_training = True
                break

            # Keep plain floats so the running sum does not depend on
            # implicit tensor -> ndarray conversion.
            train_losses.append(loss.item())
            loss.backward()

            if global_batch % args.gradient_accumulation_steps == 0:

                global_step += 1
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()    # We have accumulated enough gradients
                scheduler.step()
                model.zero_grad()

                if global_step % args.eval_period == 0:
                    model.eval()

                    # Loss summed over the batches since the last evaluation,
                    # divided by the number of target tokens in those batches.
                    avg_train_loss = sum(train_losses) / total_target_tokens

                    avg_dev_loss = validate(args, logger, model, dev_data)

                    df_dev_performance, avg_performance = predict(args, logger, model, dev_data2)

                    df_dev_performance.to_csv(os.path.join(args.output_dir, "dev_performance_logs", "{}-steps.csv".format(global_step)))

                    logger.info("Step {}: train loss: {}, dev loss: {}, avg dev performance: {}".format(
                        global_step, avg_train_loss, avg_dev_loss, avg_performance)
                    )

                    df.loc[len(df.index)] = [global_step, avg_train_loss, avg_dev_loss, avg_performance]
                    df.to_csv(os.path.join(args.output_dir, "dev_loss.csv"))

                    if avg_performance > best_avg_performance:
                        logger.info("Saving model with best dev avg performance: %s -> %s at epoch=%d, global_step=%d" % \
                            (best_avg_performance, avg_performance, epoch, global_step))
                        best_avg_performance = avg_performance
                        _save_model_state(model, os.path.join(args.output_dir, "best-model.pt"))

                    # Back to training mode and start a fresh loss window.
                    model.train()
                    total_target_tokens = 0
                    train_losses = []

                if global_step % args.save_period == 0:

                    # Periodic checkpoints go into their own subdirectory.
                    save_path = os.path.join(args.output_dir, "{}-steps".format(global_step))
                    os.makedirs(save_path, exist_ok=True)

                    # NOTE: a former `task_model.save_to_disk(save_path)` call
                    # was removed here; `task_model` is not defined in this
                    # function or module, so the call raised NameError.
                    _save_model_state(model, os.path.join(save_path, "model.pt"))

                    logger.info("Checkpoint at step {} saved to {}".format(global_step, save_path))

            if global_step >= args.total_steps:
                stop_training = True
                break

        if stop_training:
            break

    # Always keep the final weights, whether or not they are the best ones.
    _save_model_state(model, os.path.join(args.output_dir, "last-model.pt"))


def predict(args, logger, model, predict_data):
    """Generate outputs for every task in ``predict_data`` and score them.

    Args:
        args: unused here; generation settings are read from each task's own
            ``args`` attribute.
        logger: unused here.
        model: model exposing ``generate`` and ``config.bos_token_id``.
        predict_data: iterable of per-task data objects. Each one is expected
            to expose ``dataloader``, ``tokenizer``, ``args``,
            ``gen_early_stop``, ``decode``, ``evaluate``, ``task_name`` and
            ``metric``.

    Returns:
        A tuple ``(df, average)`` where ``df`` has one row per task with the
        columns ``task_prefix``, ``metric`` and ``performance``, and
        ``average`` is the mean of the ``performance`` column.
    """
    results = pd.DataFrame(columns=["task_prefix", "metric", "performance"])

    for task in tqdm(predict_data):
        decoded = []

        for batch in task.dataloader:
            if torch.cuda.is_available():
                batch = [tensor.to(torch.device("cuda")) for tensor in batch]

            # Drop columns that are padding for the whole batch.
            input_ids, attention_mask = trim_batch(
                batch[0], task.tokenizer.pad_token_id, batch[1]
            )

            generated = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                num_beams=task.args.num_beams,
                max_length=task.args.max_output_length,
                decoder_start_token_id=model.config.bos_token_id,
                early_stopping=task.gen_early_stop,
                use_cache=True,
            )

            # Pair with the inputs so that at most one output per input row
            # is decoded.
            decoded.extend(task.decode(sequence) for _, sequence in zip(input_ids, generated))

        score = task.evaluate(decoded)
        results.loc[len(results.index)] = [task.task_name, task.metric, score]

    return results, np.mean(results["performance"])


def validate(args, logger, model, eval_data):
    """Compute the loss per target token over ``eval_data``.

    Args:
        args: run configuration (not otherwise used by the computation).
        logger: unused here.
        model: model called with ``is_training=True`` so that it returns a
            loss tensor.
        eval_data: object exposing ``dataloader`` and ``tokenizer``.

    Returns:
        The sum of all batch losses divided by the total number of non-masked
        decoder positions (sum of the decoder attention mask). This presumes
        the model returns a loss summed over tokens -- confirm against the
        model implementation.
    """
    batch_losses = []
    n_target_tokens = 0

    for batch in tqdm(eval_data.dataloader, desc="Eval"):
        if torch.cuda.is_available():
            cuda = torch.device("cuda")
            batch = [tensor.to(cuda) for tensor in batch]

        pad_id = eval_data.tokenizer.pad_token_id

        # Drop columns that are padding for the whole batch.
        src_ids, src_mask = trim_batch(batch[0], pad_id, batch[1])
        tgt_ids, tgt_mask = trim_batch(batch[2], pad_id, batch[3])

        n_target_tokens += tgt_mask.sum().item()

        with torch.no_grad():
            batch_loss = model(
                input_ids=src_ids,
                attention_mask=src_mask,
                decoder_input_ids=tgt_ids,
                decoder_attention_mask=tgt_mask,
                is_training=True,
            )

        batch_losses.append(batch_loss.detach().cpu())

    return np.sum(batch_losses) / n_target_tokens

