import os
import torch

from transformers import BartTokenizer, BartConfig

from modules.routing_bart_config import RoutingBartConfig
from modules.routing_bart_v2 import MyRoutingBart
from modules.task2vec import Task2Vec

from run_singletask import load_moe_model, load_predict_data
from utils import trim_batch

def run(args, logger):
    """Load the data and models, run prediction, and print the result.

    Args:
        args: Run configuration; ``args.model`` names the tokenizer to load.
            It is also forwarded to the data/model loaders and ``predict``.
        logger: Logger forwarded to the loaders and to ``predict``.
    """
    bart_tokenizer = BartTokenizer.from_pretrained(args.model)
    predict_data = load_predict_data(args, logger, bart_tokenizer)
    moe_config, moe_model, task_encoder = load_moe_model(args, logger)

    # Switch both modules to eval mode (main model first, then task model).
    for module in (moe_model, task_encoder):
        module.eval()

    # Router settings: fixed Gumbel temperature, mode taken from the config.
    moe_model.set_gumbel_temperature(0.1)
    moe_model.set_router_mode(moe_config.router_mode)

    print(predict(args, logger, moe_model, task_encoder, predict_data))

def string2mask(disabling_string):
    """Convert a 36-character digit string into two route-mask tensors.

    The characters are parsed as integers, laid out row-major as a
    ``(12, 3)`` grid, given a singleton middle axis, and split into the
    first six and last six rows. ``predict`` passes the two halves as the
    encoder and decoder overrides respectively.
    NOTE(review): presumably 12 = routed layers and 3 = experts per layer --
    confirm against the model config.

    Args:
        disabling_string: String of exactly 36 decimal digits.

    Returns:
        A tuple ``(first_half, second_half)`` of integer tensors, each of
        shape ``(6, 1, 3)``. Both live on the CUDA device when one is
        available, otherwise on the CPU.

    Raises:
        ValueError: If the string is not 36 characters long, or if any
            character cannot be parsed by ``int``.
    """
    num_rows, num_cols = 12, 3
    expected_len = num_rows * num_cols

    # Explicit check instead of `assert`: asserts vanish under `python -O`.
    if len(disabling_string) != expected_len:
        raise ValueError(
            "expert disabling string must have {} characters, got {}".format(
                expected_len, len(disabling_string)
            )
        )

    mask = torch.tensor([int(ch) for ch in disabling_string])
    mask = mask.reshape(num_rows, num_cols).unsqueeze(1)
    if torch.cuda.is_available():
        mask = mask.to(torch.device("cuda"))

    half = num_rows // 2
    return mask[:half], mask[half:]

def load_moe_model(args, logger):
    """Load the routing-BART config, model weights and task model from disk.

    Note: this definition shadows the ``load_moe_model`` imported from
    ``run_singletask`` at the top of this file.

    Files read:
        * ``<args.init_dir>/config.json`` -- the model configuration.
        * ``<args.init_dir>/<args.checkpoint_name>/model.pt`` -- state dict.
        * ``<args.init_dir>/<args.checkpoint_name>`` -- path handed to
          ``Task2Vec``.

    Args:
        args: Must provide ``init_dir`` and ``checkpoint_name``.
        logger: Unused here; kept for a uniform loader signature.

    Returns:
        A tuple ``(config, model, task_model)``. ``model`` and ``task_model``
        are moved to the CUDA device when one is available.
    """
    config_path = os.path.join(args.init_dir, "config.json")
    config = RoutingBartConfig.from_pretrained(config_path)

    checkpoint_dir = os.path.join(args.init_dir, args.checkpoint_name)

    model = MyRoutingBart(config)
    model_path = os.path.join(checkpoint_dir, "model.pt")
    # Deserialize onto the CPU so that a checkpoint saved from CUDA tensors
    # still loads on a CPU-only host; the model is moved to the GPU below.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)

    task_model = Task2Vec(checkpoint_dir)

    if torch.cuda.is_available():
        cuda = torch.device("cuda")
        model.to(cuda)
        task_model.to(cuda)

    return config, model, task_model
    
def predict(args, logger, model, task_model, data):
    """Generate predictions for every batch in ``data`` and evaluate them.

    Routes are computed once from the task embedding, with the masks parsed
    from ``args.expert_disabling`` passed as ``override``. The same routes
    are then expanded to each batch's size and used for generation.

    Args:
        args: Must provide ``expert_disabling`` (see ``string2mask``).
        logger: Unused here; kept for a uniform signature.
        model: Model exposing ``get_routes``, ``generate`` and
            ``config.bos_token_id``.
        task_model: Callable module exposing ``taskname2id``; maps a task id
            to the task embedding fed to ``model.get_routes``.
        data: Dataset wrapper providing ``task_name``, ``tokenizer``,
            ``dataloader``, ``args.num_beams``, ``args.max_output_length``,
            ``gen_early_stop``, ``decode``, ``save_predictions`` and
            ``evaluate``.

    Returns:
        Whatever ``data.evaluate(predictions)`` returns.
    """
    enc_route_mask, dec_route_mask = string2mask(args.expert_disabling)

    # Resolve the device once instead of re-querying CUDA on every batch.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else None

    with torch.no_grad():
        task_id = task_model.taskname2id(data.task_name)
        if use_cuda:
            task_id = task_id.to(device)
        task_emb = task_model(task_id)
        enc_routes0, dec_routes0 = model.get_routes(
            task_emb, separate=True,
            override=(enc_route_mask, dec_route_mask),
        )

    # Loop-invariant: the padding id does not change between batches.
    pad_token_id = data.tokenizer.pad_token_id

    predictions = []
    for batch in data.dataloader:
        if use_cuda:
            batch = [b.to(device) for b in batch]

        # Bind the trimmed tensors to locals rather than assigning back into
        # `batch`, so this also works if the dataloader yields a tuple.
        # NOTE(review): trim_batch presumably drops all-padding columns --
        # confirm in utils.
        input_ids, attention_mask = trim_batch(batch[0], pad_token_id, batch[1])
        bsz = input_ids.shape[0]

        # Repeat the per-task routes along the first axis to the batch size,
        # then swap the first two axes as `generate` is called with here.
        enc_routes = enc_routes0.expand(bsz, -1, -1).transpose(0, 1)
        dec_routes = dec_routes0.expand(bsz, -1, -1).transpose(0, 1)

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            block_distribution=enc_routes,
            decoder_block_distribution=dec_routes,
            num_beams=data.args.num_beams,
            max_length=data.args.max_output_length,
            decoder_start_token_id=model.config.bos_token_id,
            early_stopping=data.gen_early_stop,
            use_cache=True,
            use_sparse=True,
        )

        # Zipping with the inputs caps decoding at one output per input row.
        for _, output in zip(input_ids, outputs):
            predictions.append(data.decode(output))

    data.save_predictions(predictions)
    return data.evaluate(predictions)