import os
import numpy as np
import torch
import pandas as pd

from torch.optim import SGD

from transformers import BartTokenizer, BartConfig
from transformers import AdamW, get_linear_schedule_with_warmup

from my_datasets.dataloader_single_task import MySingleTaskData

from modules.routing_bart_config import RoutingBartConfig
from modules.routing_bart_v2 import MyRoutingBart
from modules.task2vec import Task2Vec
from modules.utils import initialize_weights

from utils import trim_batch, get_gumbel_temperature, freeze_params, load_saved_checkpoint, get_tasks_list, get_gumbel_temperature
from tqdm import tqdm

def run(args, logger):
    """Load the data and a saved checkpoint, then run the route analysis.

    Args:
        args: run configuration; this function reads ``model``, ``train_file``,
            ``dev_file``, ``checkpoint``, ``router_mode`` and ``minimum_tau``.
        logger: logger handed to the datasets, the checkpoint loader and the
            analysis.

    Returns:
        Whatever ``route_analysis`` returns (a pandas DataFrame of trials).
    """
    tokenizer = BartTokenizer.from_pretrained(args.model)

    train_data = MySingleTaskData(logger, args, args.train_file, data_type="train", is_training=True)
    dev_data = MySingleTaskData(logger, args, args.dev_file, data_type="dev", is_training=False)

    # Train split first, then dev: tokenise, then build the dataloader.
    for split in (train_data, dev_data):
        split.load_dataset(tokenizer)
        split.load_dataloader()

    task_model, model = load_saved_checkpoint(args.checkpoint, logger)
    model.set_router_mode(args.router_mode)
    model.set_gumbel_temperature(args.minimum_tau)

    # The analysis starts from a model in evaluation mode.
    task_model.eval()
    model.eval()

    task_id = task_model.taskname2id(train_data.task_name)

    if torch.cuda.is_available():
        device = torch.device("cuda")
        model.to(device)
        task_model.to(device)
        task_id = task_id.to(device)

    return route_analysis(
        args, logger, task_model, model, task_id, train_data, dev_data
    )

def get_loss(model, enc_routes0, dec_routes0, data):
    """Evaluate the model on ``data`` with fixed routes, without gradients.

    Args:
        model: routing model; called with ``is_training=True`` so that it
            returns a loss.
        enc_routes0: encoder routes, broadcast to every example of a batch.
        dec_routes0: decoder routes, broadcast to every example of a batch.
        data: dataset wrapper exposing ``dataloader`` and ``tokenizer``.

    Returns:
        The accumulated batch losses divided by the accumulated sum of the
        decoder attention masks (presumably the number of target tokens).

    Raises:
        ZeroDivisionError: if the dataloader yields no target tokens.
    """
    use_cuda = torch.cuda.is_available()
    pad_id = data.tokenizer.pad_token_id

    loss_sum = 0.0
    n_target_tokens = 0

    for batch in data.dataloader:
        batch_size = batch[0].shape[0]
        # Repeat the routes for each example, then swap the first two axes.
        enc_routes = enc_routes0.expand(batch_size, -1, -1).transpose(0, 1)
        dec_routes = dec_routes0.expand(batch_size, -1, -1).transpose(0, 1)

        if use_cuda:
            cuda = torch.device("cuda")
            batch = [tensor.to(cuda) for tensor in batch]

        # Drop padding-only columns from source and target.
        input_ids, input_mask = trim_batch(batch[0], pad_id, batch[1])
        target_ids, target_mask = trim_batch(batch[2], pad_id, batch[3])

        with torch.no_grad():
            batch_loss = model(input_ids=input_ids, attention_mask=input_mask,
                               decoder_input_ids=target_ids, decoder_attention_mask=target_mask,
                               block_distribution=enc_routes,
                               decoder_block_distribution=dec_routes,
                               is_training=True)

        n_target_tokens += target_mask.sum().item()
        loss_sum += batch_loss.item()

    return loss_sum / n_target_tokens

def train_task_emb(args, logger, task_emb, model, train_data):
    """Optimise a copy of ``task_emb`` with SGD while the model weights stay fixed.

    Only the task embedding is handed to the optimizer, so the model's own
    parameters are never updated. The Gumbel temperature starts at
    ``args.initial_tau``, is refreshed from ``get_gumbel_temperature`` after
    every step and is set to ``args.minimum_tau`` before returning.

    Args:
        args: run configuration; reads ``learning_rate``, ``initial_tau``,
            ``minimum_tau``, ``eval_period`` and ``total_steps``.
        logger: logger used for the periodic training-loss report.
        task_emb: starting task embedding; it is copied, not modified.
        model: routing model that accepts ``task_embed`` and returns a loss.
        train_data: dataset wrapper exposing ``dataloader`` and ``tokenizer``.

    Returns:
        The trained task embedding as a ``torch.nn.Parameter``.
    """
    # Fresh leaf parameter, detached from whatever graph produced task_emb.
    task_emb = torch.nn.Parameter(task_emb.detach().clone())
    optimizer = SGD([task_emb], lr=args.learning_rate)

    # Remember the caller's mode so it can be restored on the way out.
    was_training = model.training
    model.set_gumbel_temperature(args.initial_tau)
    model.train()

    pad_token_id = train_data.tokenizer.pad_token_id
    use_cuda = torch.cuda.is_available()

    global_step = 0
    stop_training = False

    for _ in range(1000000):
        for batch in train_data.dataloader:
            global_step += 1

            if use_cuda:
                batch = [b.to(torch.device("cuda")) for b in batch]

            input_ids, input_mask = trim_batch(batch[0], pad_token_id, batch[1])
            target_ids, target_mask = trim_batch(batch[2], pad_token_id, batch[3])

            bsz = input_ids.shape[0]
            task_embeds = task_emb.expand(bsz, -1)

            loss = model(input_ids=input_ids, attention_mask=input_mask,
                         decoder_input_ids=target_ids, decoder_attention_mask=target_mask,
                         task_embed=task_embeds,
                         is_training=True)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # backward() also fills .grad on the model's parameters, which the
            # optimizer does not own; clear them so they neither accumulate
            # across steps nor leak into a later optimisation of the model.
            model.zero_grad()

            model.set_gumbel_temperature(get_gumbel_temperature(args, global_step))

            if global_step % args.eval_period == 0:
                model.eval()
                with torch.no_grad():
                    enc_routes0, dec_routes0 = model.get_routes(task_emb, separate=True)
                eval_loss = get_loss(model, enc_routes0, dec_routes0, train_data)
                logger.info("Loss at step {}: {}".format(global_step, eval_loss))
                model.train()

            # ">=" also terminates when total_steps is zero or negative.
            if global_step >= args.total_steps:
                stop_training = True
                break

        if stop_training:
            break

    model.set_gumbel_temperature(args.minimum_tau)
    model.train(was_training)
    return task_emb

def inference(model, dev_data, enc_route, dec_route, save_predictions=False, verbose=False):
    """Generate predictions for ``dev_data`` with fixed routes and score them.

    Args:
        model: routing model exposing ``generate`` and ``config.bos_token_id``.
        dev_data: dataset wrapper exposing ``dataloader``, ``tokenizer``,
            ``args`` (``num_beams``, ``max_output_length``), ``gen_early_stop``,
            ``decode``, ``save_predictions`` and ``evaluate``.
        enc_route: encoder routes, broadcast to every example of a batch.
        dec_route: decoder routes, broadcast to every example of a batch.
        save_predictions: when true, pass the predictions to
            ``dev_data.save_predictions``.
        verbose: forwarded to ``dev_data.evaluate``.

    Returns:
        The result of ``dev_data.evaluate`` on the decoded predictions.
    """
    predictions = []
    pad_token_id = dev_data.tokenizer.pad_token_id
    use_cuda = torch.cuda.is_available()

    for batch in dev_data.dataloader:
        if use_cuda:
            batch = [b.to(torch.device("cuda")) for b in batch]

        input_ids, attention_mask = trim_batch(batch[0], pad_token_id, batch[1])
        bsz = input_ids.shape[0]
        enc_routes = enc_route.expand(bsz, -1, -1).transpose(0, 1)
        dec_routes = dec_route.expand(bsz, -1, -1).transpose(0, 1)

        # Decoding never needs gradients, whatever mode the caller is in.
        with torch.no_grad():
            outputs = model.generate(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     block_distribution=enc_routes,
                                     decoder_block_distribution=dec_routes,
                                     num_beams=dev_data.args.num_beams,
                                     max_length=dev_data.args.max_output_length,
                                     decoder_start_token_id=model.config.bos_token_id,
                                     early_stopping=dev_data.gen_early_stop,
                                     use_cache=False,)

        # zip caps the decoded rows at one per input example.
        predictions.extend(dev_data.decode(output) for _, output in zip(input_ids, outputs))

    if save_predictions:
        dev_data.save_predictions(predictions)
    return dev_data.evaluate(predictions, verbose=verbose)

def finetune_and_get_loss(args, logger, model, enc_routes0, dec_routes0, train_data, dev_data,
                          num_steps=10, learning_rate=3e-5):
    """Briefly fine-tune the model with fixed routes, evaluate it, then undo the fine-tuning.

    The model's weights are snapshotted first and restored before returning
    (also when an error occurs), so the caller gets back the model it passed
    in, in the same train/eval mode.

    Args:
        args: run configuration; reads ``weight_decay``, ``adam_epsilon`` and
            ``max_grad_norm``.
        logger: logger used to report an empty training dataloader.
        model: routing model to fine-tune temporarily.
        enc_routes0: encoder routes held fixed during training and evaluation.
        dec_routes0: decoder routes held fixed during training and evaluation.
        train_data: dataset wrapper used for the updates and the train loss.
        dev_data: dataset wrapper used for the dev loss and dev performance.
        num_steps: number of optimizer steps to take.
        learning_rate: learning rate for both parameter groups.

    Returns:
        Tuple ``(train_loss, dev_loss, dev_perf)`` measured on the fine-tuned
        model in evaluation mode.
    """
    enc_routes0 = enc_routes0.detach()
    dec_routes0 = dec_routes0.detach()

    # clone() is essential: for a model that already lives on the CPU, .cpu()
    # returns the live tensors and the snapshot would follow the updates.
    saved_state = {name: tensor.detach().cpu().clone()
                   for name, tensor in model.state_dict().items()}
    was_training = model.training

    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ('bias', 'LayerNorm.weight')
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if any(marker in name for marker in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer = AdamW(
        [
            {'params': decay_params, 'weight_decay': args.weight_decay, 'lr': learning_rate},
            {'params': no_decay_params, 'weight_decay': 0.0, 'lr': learning_rate},
        ],
        eps=args.adam_epsilon,
    )

    pad_token_id = train_data.tokenizer.pad_token_id
    use_cuda = torch.cuda.is_available()

    try:
        model.train()
        # Start from clean gradients in case earlier code left some behind.
        optimizer.zero_grad()

        global_step = 0
        while global_step < num_steps:
            steps_before_epoch = global_step

            for batch in train_data.dataloader:
                if use_cuda:
                    batch = [b.to(torch.device("cuda")) for b in batch]

                input_ids, input_mask = trim_batch(batch[0], pad_token_id, batch[1])
                target_ids, target_mask = trim_batch(batch[2], pad_token_id, batch[3])

                bsz = input_ids.shape[0]
                enc_routes = enc_routes0.expand(bsz, -1, -1).transpose(0, 1)
                dec_routes = dec_routes0.expand(bsz, -1, -1).transpose(0, 1)

                loss = model(input_ids=input_ids, attention_mask=input_mask,
                             decoder_input_ids=target_ids, decoder_attention_mask=target_mask,
                             block_distribution=enc_routes,
                             decoder_block_distribution=dec_routes,
                             is_training=True)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()

                global_step += 1
                if global_step >= num_steps:
                    break

            if global_step == steps_before_epoch:
                # An empty dataloader would otherwise loop forever.
                logger.info("finetune_and_get_loss: training dataloader is empty, skipping fine-tuning")
                break

        # Measure with dropout disabled.
        model.eval()
        train_loss = get_loss(model, enc_routes0, dec_routes0, train_data)
        dev_loss = get_loss(model, enc_routes0, dec_routes0, dev_data)
        dev_perf = inference(model, dev_data, enc_route=enc_routes0, dec_route=dec_routes0)
    finally:
        # Undo the fine-tuning and give the model back in the caller's mode.
        model.load_state_dict(saved_state)
        model.train(was_training)

    return train_loss, dev_loss, dev_perf

def _evaluate_routes(model, enc_routes0, dec_routes0, train_data, dev_data):
    """Return (train_loss, dev_loss, dev_perf) for one pair of routes, measured in eval mode."""
    # Force eval mode: an earlier fine-tuning call may have left the model in
    # train mode, which would put dropout noise into the measurements.
    model.eval()
    train_loss = get_loss(model, enc_routes0, dec_routes0, train_data)
    dev_loss = get_loss(model, enc_routes0, dec_routes0, dev_data)
    dev_perf = inference(model, dev_data, enc_route=enc_routes0, dec_route=dec_routes0)
    return train_loss, dev_loss, dev_perf


def route_analysis(args, logger, task_model, model, task_id, train_data, dev_data,
                   num_random_trials=64):
    """Compare routes from several sources on one task and tabulate the results.

    Three groups of trials are run:

    * Part 1: routes derived from the task's own embedding (``task_id``).
    * Part 3: routes of every task listed by
      ``get_tasks_list(args.custom_tasks_splits, "train")``, each evaluated
      as-is and again after ``finetune_and_get_loss``.
    * Part 4: ``num_random_trials`` random one-hot routes, each evaluated
      as-is and again after ``finetune_and_get_loss``.

    Part 2 (optimising the task embedding with ``train_task_emb`` and
    re-evaluating) is currently disabled.

    Args:
        args: run configuration; reads ``custom_tasks_splits`` here and is
            forwarded to ``finetune_and_get_loss``.
        logger: logger used for progress and per-trial results.
        task_model: maps task names to ids and ids to task embeddings.
        model: routing model exposing ``get_routes``.
        task_id: id of the task under analysis.
        train_data: dataset wrapper for the train split.
        dev_data: dataset wrapper for the dev split.
        num_random_trials: number of random routes tried in Part 4.

    Returns:
        pandas DataFrame with columns ``trial_id``, ``train_loss``,
        ``dev_loss``, ``dev_perf`` and ``note``; one row per trial.
    """
    df = pd.DataFrame(columns=["trial_id", "train_loss", "dev_loss", "dev_perf", "note"])
    use_cuda = torch.cuda.is_available()

    def record(trial_id, results, note):
        # Append one row: results is (train_loss, dev_loss, dev_perf).
        df.loc[len(df.index)] = [trial_id, *results, note]

    logger.info("running route analysis")

    logger.info("Part 1: use untrained task representation")
    task_emb = task_model(task_id)
    enc_routes0, dec_routes0 = model.get_routes(task_emb, separate=True)

    results = _evaluate_routes(model, enc_routes0, dec_routes0, train_data, dev_data)
    record("1", results, "untrained")
    logger.info("Part 1: Loss = {} | {} | {}".format(*results))

    # Part 2 (disabled): train the task embedding with train_task_emb, derive
    # routes from it, evaluate them and record the trial as "2" / "trained".

    logger.info("Part 3: use routes of seen tasks")
    all_train_tasks = get_tasks_list(args.custom_tasks_splits, "train")

    for i, taskname in enumerate(all_train_tasks):
        seen_task_id = task_model.taskname2id(taskname)
        if use_cuda:
            seen_task_id = seen_task_id.to(torch.device("cuda"))
        seen_task_emb = task_model(seen_task_id)
        enc_routes0, dec_routes0 = model.get_routes(seen_task_emb, separate=True)

        results = _evaluate_routes(model, enc_routes0, dec_routes0, train_data, dev_data)
        record("3_{}_wo".format(i), results, taskname)
        logger.info("Part 3 Trial {} (no ft): Loss = {} | {} | {}".format(i, *results))

        results = finetune_and_get_loss(
            args, logger, model, enc_routes0, dec_routes0, train_data, dev_data
        )
        logger.info("Part 3 Trial {} (w/ ft): Loss = {} | {} | {}".format(i, *results))
        record("3_{}_w".format(i), results, taskname)

    logger.info("Part 4: use random routes")
    for i in range(num_random_trials):
        # One random one-hot choice among 3 options for each of 6 rows.
        # NOTE(review): (6, 3) presumably mirrors the model's routed layers and
        # blocks per layer — confirm against the routing config.
        enc_routes0 = torch.nn.functional.one_hot(torch.rand(6, 3).argmax(dim=1), 3).float()
        dec_routes0 = torch.nn.functional.one_hot(torch.rand(6, 3).argmax(dim=1), 3).float()

        if use_cuda:
            enc_routes0 = enc_routes0.to(torch.device("cuda"))
            dec_routes0 = dec_routes0.to(torch.device("cuda"))

        results = _evaluate_routes(model, enc_routes0, dec_routes0, train_data, dev_data)
        record("4_{}_wo".format(i), results, "random_routes")
        logger.info("Part 4 Trial {} (no ft): Loss = {} | {} | {}".format(i, *results))

        results = finetune_and_get_loss(
            args, logger, model, enc_routes0, dec_routes0, train_data, dev_data
        )
        logger.info("Part 4 Trial {} (w/ ft): Loss = {} | {} | {}".format(i, *results))
        record("4_{}_w".format(i), results, "random_routes")

    return df