import os
import numpy as np
import torch

from transformers import BartTokenizer, BartConfig
from transformers import AdamW, get_linear_schedule_with_warmup

from my_datasets.dataloader_single_task import MySingleTaskData

from modules.bart import MyBart
from modules.routing_bart_config import RoutingBartConfig
from modules.routing_bart_v2 import MyRoutingBart
from modules.task2vec import Task2Vec
from modules.utils import squeeze_weights
from utils import trim_batch

from tqdm import tqdm

def run(args, logger):
    """Train and/or evaluate a single-task model, as selected by ``args``.

    The tokenizer and the train/dev data are always loaded. When
    ``args.do_train`` is set, a model is built (from the routing checkpoint
    if ``args.moe_init`` is set, otherwise via ``load_model``) and trained.
    When ``args.do_predict`` is set, a model is evaluated on the prediction
    data and its predictions are saved.

    Args:
        args: Run configuration; this function reads ``model``, ``do_train``,
            ``moe_init``, ``do_predict``, ``output_dir`` and
            ``predict_checkpoint`` and may overwrite ``args.checkpoint``.
        logger: Logger used for progress messages.

    Returns:
        Tuple ``(best_dev_performance, test_performance)``; each entry is
        ``None`` when the corresponding phase did not run.
    """
    tokenizer = BartTokenizer.from_pretrained(args.model)

    train_data, dev_data = load_data(args, logger, tokenizer)

    best_dev_performance = None
    test_performance = None
    best_model_state_dict = None

    if args.do_train:
        if args.moe_init:
            model = load_moe_model(args, logger)
        else:
            model = load_model(args, logger)
        optimizer, scheduler = get_optimizer_and_scheduler(args, logger, model, train_data)
        best_dev_performance, best_model_state_dict = train(
            args, logger, model, train_data, dev_data, optimizer, scheduler)

    if args.do_predict:
        if args.do_train and best_model_state_dict is not None:
            # Re-use the best weights found during training (kept in CPU memory).
            # NOTE(review): the weights are always reloaded into MyBart; if
            # training used the routing model returned for
            # args.moe_avg_route, the keys presumably will not match - confirm.
            model = MyBart.from_pretrained(args.model,
                                           state_dict=best_model_state_dict)
            logger.info("Loading checkpoint from CPU")
            if torch.cuda.is_available():
                model.to(torch.device("cuda"))
        else:
            # No in-memory weights: fall back to a checkpoint on disk.
            args.checkpoint = os.path.join(args.output_dir, args.predict_checkpoint)
            model = load_model(args, logger)
            logger.info("Loading checkpoint from {}".format(args.checkpoint))
            # load_model may wrap the model in DataParallel, which does not
            # expose the generate()/config that inference() relies on.
            if isinstance(model, torch.nn.DataParallel):
                model = model.module

        model.eval()

        test_data = load_predict_data(args, logger, tokenizer)

        test_performance = inference(args, model, test_data, save_predictions=True, verbose=True)
        logger.info("%s on %s data: %.2f" % (test_data.metric, test_data.data_type, test_performance))

    return best_dev_performance, test_performance

def train(args, logger, model, train_data, dev_data, optimizer, scheduler):
    """Run the training loop with periodic dev evaluation and early stopping.

    Every ``args.eval_period`` steps the model is evaluated on ``dev_data``.
    When the score improves, a CPU copy of the weights is kept in memory
    (nothing is written to disk). Training stops when the loss becomes NaN,
    when ``args.wait_step`` consecutive evaluations bring no improvement,
    when ``args.total_steps`` is reached, or when the epochs are exhausted.

    Args:
        args: Run configuration (reads ``moe_avg_route``, ``num_train_epochs``,
            ``quiet``, ``n_gpu``, ``gradient_accumulation_steps``,
            ``max_grad_norm``, ``eval_period``, ``wait_step``, ``total_steps``).
        logger: Logger used for progress messages.
        model: Model to train; may be wrapped in ``torch.nn.DataParallel``.
        train_data: Object exposing ``dataloader`` and ``tokenizer``.
        dev_data: Evaluation data passed to ``inference``.
        optimizer: Optimizer stepped every ``gradient_accumulation_steps``.
        scheduler: Learning-rate scheduler stepped with the optimizer.

    Returns:
        Tuple ``(best_performance, best_model_state_dict)``. The state dict is
        ``None`` if no evaluation improved on the initial ``-1.0``; otherwise
        its keys are those of the unwrapped model (no ``module.`` prefix).
    """
    model.train()
    global_step = 0
    train_losses = []
    best_performance = -1.0
    # Initialised up front so both are defined even when no evaluation runs
    # or the first evaluation does not beat ``best_performance``.
    best_model_state_dict = None
    wait_step = 0
    stop_training = False

    # Under DataParallel, config / generate() / un-prefixed weights live on
    # the wrapped module; the wrapper itself is still used for the forward pass.
    bare_model = model.module if isinstance(model, torch.nn.DataParallel) else model

    if args.moe_avg_route:
        # Uniform distribution over router blocks for every layer.
        n_blocks = bare_model.config.router_block_num
        enc_routes0 = torch.ones(bare_model.config.encoder_layers, n_blocks) / n_blocks
        dec_routes0 = torch.ones(bare_model.config.decoder_layers, n_blocks) / n_blocks
        if torch.cuda.is_available():
            enc_routes0 = enc_routes0.to(torch.device("cuda"))
            dec_routes0 = dec_routes0.to(torch.device("cuda"))

    logger.info("Starting training!")
    for epoch in range(int(args.num_train_epochs)):
        for batch in tqdm(train_data.dataloader, desc="Epoch {}".format(epoch), disable=args.quiet):
            global_step += 1
            if torch.cuda.is_available():
                batch = [b.to(torch.device("cuda")) for b in batch]

            pad_token_id = train_data.tokenizer.pad_token_id

            batch[0], batch[1] = trim_batch(batch[0], pad_token_id, batch[1])
            batch[2], batch[3] = trim_batch(batch[2], pad_token_id, batch[3])

            if args.moe_avg_route:
                bsz = batch[0].shape[0]
                # (layers, blocks) -> (bsz, layers, blocks) -> (layers, bsz, blocks)
                enc_routes = enc_routes0.expand(bsz, -1, -1).transpose(0, 1)
                dec_routes = dec_routes0.expand(bsz, -1, -1).transpose(0, 1)

                loss = model(input_ids=batch[0], attention_mask=batch[1],
                             decoder_input_ids=batch[2], decoder_attention_mask=batch[3],
                             block_distribution=enc_routes,
                             decoder_block_distribution=dec_routes,
                             is_training=True)
            else:
                loss = model(input_ids=batch[0], attention_mask=batch[1],
                             decoder_input_ids=batch[2], decoder_attention_mask=batch[3],
                             is_training=True)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if torch.isnan(loss).data:
                logger.info("Stop training because loss=%s" % (loss.data))
                stop_training = True
                break
            train_losses.append(loss.detach().cpu())
            loss.backward()

            if global_step % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()    # We have accumulated enough gradients
                scheduler.step()
                model.zero_grad()

            if global_step % args.eval_period == 0:
                model.eval()
                curr_performance = inference(args, bare_model, dev_data)
                logger.info("Step %d Train loss %.2f %s %s on epoch=%d" % (
                        global_step,
                        np.mean(train_losses),
                        dev_data.metric,
                        curr_performance,
                        epoch))
                train_losses = []
                if best_performance < curr_performance:
                    # copy=True forces a real snapshot even when the weights
                    # are already on CPU; otherwise the "best" weights would
                    # alias the live parameters and change with later steps.
                    best_model_state_dict = {
                        k: v.to("cpu", copy=True)
                        for (k, v) in bare_model.state_dict().items()
                    }
                    logger.info("Not saving model with best %s: %s -> %s on epoch=%d, global_step=%d" % \
                            (dev_data.metric, best_performance, curr_performance, epoch, global_step))
                    best_performance = curr_performance
                    wait_step = 0
                else:
                    wait_step += 1
                    if wait_step >= args.wait_step:
                        stop_training = True
                        break

                model.train()

            if global_step >= args.total_steps:
                stop_training = True
                break

        if stop_training:
            break

    return best_performance, best_model_state_dict

def inference(args, model, dev_data, save_predictions=False, verbose=False):
    """Generate predictions for ``dev_data`` and return its evaluation score.

    Args:
        args: Run configuration; only ``moe_avg_route`` is read here.
        model: Model exposing ``generate`` and ``config`` (not a DataParallel
            wrapper).
        dev_data: Object exposing ``dataloader``, ``tokenizer``, ``args``
            (``num_beams``, ``max_output_length``), ``gen_early_stop``,
            ``decode``, ``save_predictions`` and ``evaluate``.
        save_predictions: When true, pass the predictions to
            ``dev_data.save_predictions``.
        verbose: Forwarded to ``dev_data.evaluate``.

    Returns:
        Whatever ``dev_data.evaluate(predictions, verbose=verbose)`` returns.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else None
    pad_token_id = dev_data.tokenizer.pad_token_id

    # Generation settings shared by every batch.
    generate_kwargs = {
        "num_beams": dev_data.args.num_beams,
        "max_length": dev_data.args.max_output_length,
        "decoder_start_token_id": model.config.bos_token_id,
        "early_stopping": dev_data.gen_early_stop,
    }

    if args.moe_avg_route:
        # Uniform distribution over router blocks for every layer.
        n_blocks = model.config.router_block_num
        enc_routes0 = torch.ones(model.config.encoder_layers, n_blocks) / n_blocks
        dec_routes0 = torch.ones(model.config.decoder_layers, n_blocks) / n_blocks
        if device is not None:
            enc_routes0 = enc_routes0.to(device)
            dec_routes0 = dec_routes0.to(device)
        generate_kwargs.update(use_cache=True, use_sparse=False)

    predictions = []
    for batch in dev_data.dataloader:
        # Always work on a list so the in-place trim assignment below is valid.
        batch = [b.to(device) for b in batch] if device is not None else list(batch)
        batch[0], batch[1] = trim_batch(batch[0], pad_token_id, batch[1])

        route_kwargs = {}
        if args.moe_avg_route:
            bsz = batch[0].shape[0]
            # (layers, blocks) -> (bsz, layers, blocks) -> (layers, bsz, blocks)
            route_kwargs = {
                "block_distribution": enc_routes0.expand(bsz, -1, -1).transpose(0, 1),
                "decoder_block_distribution": dec_routes0.expand(bsz, -1, -1).transpose(0, 1),
            }

        outputs = model.generate(input_ids=batch[0],
                                 attention_mask=batch[1],
                                 **route_kwargs,
                                 **generate_kwargs)
        predictions.extend(dev_data.decode(output) for output in outputs)

    if save_predictions:
        dev_data.save_predictions(predictions)
    return dev_data.evaluate(predictions, verbose=verbose)

def load_data(args, logger, tokenizer):
    """Create the train and dev datasets and build their dataloaders.

    Both ``MySingleTaskData`` objects are constructed first (from
    ``args.train_file`` and ``args.dev_file``); each one then loads its
    dataset with ``tokenizer`` and builds its dataloader.

    Returns:
        Tuple ``(train_data, dev_data)``.
    """
    train_split = MySingleTaskData(
        logger, args, args.train_file, data_type="train", is_training=True)
    dev_split = MySingleTaskData(
        logger, args, args.dev_file, data_type="dev", is_training=False)

    # Same preparation for both splits, train first.
    for split in (train_split, dev_split):
        split.load_dataset(tokenizer)
        split.load_dataloader()

    return train_split, dev_split

def load_predict_data(args, logger, tokenizer):
    """Create the prediction dataset from ``args.test_file`` and build its dataloader.

    The split is labelled "test" when the file path contains the substring
    "test", and "dev" otherwise.

    Returns:
        The prepared ``MySingleTaskData`` object.
    """
    if "test" in args.test_file:
        split_name = "test"
    else:
        split_name = "dev"

    predict_data = MySingleTaskData(
        logger,
        args,
        args.test_file,
        data_type=split_name,
        is_training=False,
        task_name=args.dataset,
    )
    predict_data.load_dataset(tokenizer)
    predict_data.load_dataloader()
    return predict_data

def load_model(args, logger):
    """Build a ``MyBart`` model, optionally initialised from a checkpoint.

    When ``args.checkpoint`` is set (neither ``None`` nor the string
    ``"None"``), its state dict is loaded and any ``module.`` key prefix
    (as produced by saving a ``DataParallel``-wrapped model) is stripped
    before the weights are handed to ``MyBart.from_pretrained``.

    Args:
        args: Run configuration (reads ``checkpoint``, ``model``, ``n_gpu``).
        logger: Unused here; kept for a uniform loader signature.

    Returns:
        The model, wrapped in ``torch.nn.DataParallel`` when ``args.n_gpu > 1``
        and moved to CUDA when available.
    """
    if args.checkpoint is not None and args.checkpoint != "None":
        # map_location="cpu" lets a GPU-saved checkpoint load on any host;
        # the model is moved to CUDA below when available.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        raw_state_dict = torch.load(args.checkpoint, map_location="cpu")

        prefix = 'module.'
        state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in raw_state_dict.items()
        }
        model = MyBart.from_pretrained(args.model, state_dict=state_dict)
    else:
        model = MyBart.from_pretrained(args.model)

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if torch.cuda.is_available():
        model.to(torch.device("cuda"))

    return model

def _random_one_hot_routes(n_layers, n_expert):
    """Return an ``(n_layers, n_expert)`` tensor with one random 1.0 per row."""
    routes = np.zeros((n_layers, n_expert))
    for layer in range(n_layers):
        selected = np.random.choice(n_expert, 1, replace=False)
        routes[layer, selected] = 1.0
    # torch.tensor keeps numpy's default float64 dtype here.
    return torch.tensor(routes)


def load_moe_model(args, logger):
    """Build a model initialised from a trained routing (MoE) checkpoint.

    The routing model and its task model are loaded from ``args.init_dir`` /
    ``args.checkpoint_name``. Then, depending on ``args``:

    * ``moe_avg_route``: the full routing model is returned as-is.
    * ``moe_random_route``: one random expert per layer is selected.
    * otherwise: routes are computed from the task embedding of
      ``args.dataset``.

    In the last two cases the selected routes are passed to
    ``squeeze_weights`` together with a fresh ``MyBart`` model, which is then
    returned.

    Args:
        args: Run configuration (reads ``dataset``, ``model``, ``init_dir``,
            ``checkpoint_name``, ``moe_avg_route``, ``moe_random_route``).
        logger: Unused here; kept for a uniform loader signature.

    Returns:
        The routing model (average-route mode) or the initialised ``MyBart``
        model, moved to CUDA when available.
    """
    # The task name is taken directly from the run configuration.
    task_name = args.dataset

    # get a new vanilla bart model
    new_model = MyBart.from_pretrained(args.model)

    # load config, model and task_model
    config_path = os.path.join(args.init_dir, "config.json")
    config = RoutingBartConfig.from_pretrained(config_path)
    # The config stores the vanilla layer indices as a comma-separated string.
    config.encoder_vanilla_layers = [int(item) for item in config.encoder_vanilla_layers.split(",")] if config.encoder_vanilla_layers else []
    config.decoder_vanilla_layers = [int(item) for item in config.decoder_vanilla_layers.split(",")] if config.decoder_vanilla_layers else []

    model = MyRoutingBart(config)
    model_path = os.path.join(args.init_dir, args.checkpoint_name, "model.pt")
    # map_location="cpu" lets a GPU-saved checkpoint load on any host.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))

    task_model_path = os.path.join(args.init_dir, args.checkpoint_name)
    task_model = Task2Vec(task_model_path)

    model.eval()
    task_model.eval()
    model.set_gumbel_temperature(1.0)
    model.set_router_mode(config.router_mode)

    if args.moe_avg_route:
        # return the full model; no need for squeezing
        if torch.cuda.is_available():
            model.to(torch.device("cuda"))
        return model

    if args.moe_random_route:
        # randomly select 1 expert per layer (encoder first, then decoder,
        # so the numpy RNG is consumed in a fixed order)
        n_expert = config.router_block_num
        enc_routes = _random_one_hot_routes(config.encoder_layers, n_expert)
        dec_routes = _random_one_hot_routes(config.decoder_layers, n_expert)
    else:
        # get task emb for the current task, then get the routes
        task_id = task_model.taskname2id(task_name)
        task_embed = task_model(task_id)
        enc_routes, dec_routes = model.get_routes(task_embed, separate=True)

    # use the routes to get the experts and initialize the vanilla model
    squeeze_weights(config, new_model, model, enc_routes, dec_routes)

    # free the memory
    del model
    del task_model

    if torch.cuda.is_available():
        new_model.to(torch.device("cuda"))

    return new_model

def get_optimizer_and_scheduler(args, logger, model, train_data):
    """Create the AdamW optimizer and its warmup schedule for ``model``.

    Parameters whose name contains 'bias' or 'LayerNorm.weight' go into a
    group with weight decay 0.0; all others use ``args.weight_decay``.

    Returns:
        Tuple ``(optimizer, scheduler)``.
    """
    decay_exempt_markers = ('bias', 'LayerNorm.weight')

    with_decay = []
    without_decay = []
    for param_name, param in model.named_parameters():
        is_exempt = any(marker in param_name for marker in decay_exempt_markers)
        if is_exempt:
            without_decay.append(param)
        else:
            with_decay.append(param)

    param_groups = [
        {'params': with_decay, 'weight_decay': args.weight_decay},
        {'params': without_decay, 'weight_decay': 0.0},
    ]

    optimizer = AdamW(param_groups, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.total_steps,
    )
    return optimizer, scheduler