# PYTHONPATH='.' python playground/routing_bart/load_routing_bart.py 

import torch
import numpy as np

from modules.routing_bart_config import RoutingBartConfig
from modules.routing_bart_v2 import MyRoutingBart
from modules.bart import MyBart
from modules.utils import initialize_weights, squeeze_weights

from transformers import BartForConditionalGeneration, BartTokenizer

def main():
    """Run one soft-routing forward/backward pass on a tiny dummy batch.

    Builds a ``MyRoutingBart`` from the ``facebook/bart-base`` config with soft,
    non-Gumbel router selection, feeds two 3-token sequences through it together
    with an all-ones task embedding, prints the value the model returns
    (treated here as a loss) and backpropagates through it.
    """
    config = RoutingBartConfig.from_pretrained("facebook/bart-base")
    # Soft selection without Gumbel noise.
    config.router_soft_select, config.router_gumbel = True, False

    model = MyRoutingBart(config)

    # A single task-embedding row of ones: shape (1, router_input_dim).
    task_embed = torch.ones(1, config.router_input_dim)

    input_ids = torch.tensor([[1, 2, 3], [2, 3, 4]])
    decoder_input_ids = torch.arange(4, 10).view(2, 3)  # [[4, 5, 6], [7, 8, 9]]
    batch = dict(
        input_ids=input_ids,
        attention_mask=torch.ones_like(input_ids),  # no padding anywhere
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=torch.ones_like(decoder_input_ids),
    )

    loss = model(**batch, is_training=True, task_embed=task_embed)
    print(loss)
    loss.backward()

    # To inspect router gradients after backward, iterate
    # model.named_parameters() and print p.grad for names containing "router".

def main2():
    """Initialise a routing BART from bart-base and beam-search decode a fixed batch.

    Copies the pretrained ``facebook/bart-base`` weights into a fresh
    ``MyRoutingBart`` via ``initialize_weights``, computes encoder/decoder routes
    from one random task embedding with ``model.get_routes``, and prints the
    decoded output of ``model.generate`` for four pre-tokenized inputs that are
    right-padded with id 1 (the same id used to build the attention mask).
    """
    # Use the GPU when present, but keep the script runnable on CPU-only hosts
    # (unconditional .cuda() calls crash there).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config = RoutingBartConfig.from_pretrained("facebook/bart-base")
    model = MyRoutingBart(config)
    model_old = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
    initialize_weights(config, model, model_old)
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

    input_ids = [[    0,   646,  1043,  2839,  8307,  1215,  8009,  5000,   742,    96,
             5,  3034,  2866, 13144,  2156,   215,    10,  6108,    16,   373,
            10,  1950,  4202,  7436, 48866, 36861, 46194,  3552,    36,   274,
          4454,  2336,  4839,    13,  1546, 33572, 34337,   479,   646,  3388,
           510,   742, 29703,    35,   274,  4454,  2336,     2,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1],
        [    0,   646,  1043,  2839,  8307,  1215,  8009,  5000,   742,   166,
         15393,    10, 11445, 33344,    36,  3957,  4839,   121,   111,  5008,
             7,  1477, 10272, 19673, 41722,   479,   646,  3388,   510,   742,
         29703,    35,  3957,     2,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1],
        [    0,   646,  1043,  2839,  8307,  1215,  8009,  5000,   742,    20,
          3854,     9,     5,   674,  1451, 38369,    36,  6266,   495,  4839,
          3266,    13,    70, 15029,    16, 38152,    13,   349,     9,     5,
           292, 15775,   710, 28470, 12712,    11,     5,  2233,  2911,  5992,
          1640,   170,   304,  2526,  2233,  2911,  5992,  4832,     5,  1275,
          2003,  8711,     5,  9640,  2156,     5,  2440,  2233,  2029,     5,
             8,   135,  4755,    36,     8,  4839,  2156,     5,   909,  5692,
         20719,     7,     8,  2156,     8,     5,  1275, 20238,    32,   373,
         31187,  4733,   479,  4839,   646,  3388,   510,   742, 29703,    35,
          6266,   495,     2],
        [    0,   646,  1043,  2839,  8307,  1215,  8009,  5000,   742, 31967,
          4230,  4412,  1639,   130,  6134,     9,  6363,  4832, 14929,   111,
           716,  2156, 44796,   111,   716,     8, 14224, 21218,   479,   646,
          3388,   510,   742, 29703,    35, 21218,     2,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1]]

    # Only needed by the disabled teacher-forced check below.
    decoder_input_ids = [[    0,  1950,  4202,  7436, 48866, 36861, 46194,  3552,     2],
        [    0, 11445, 33344,     2,     1,     1,     1,     1,     1],
        [    0,   674,  1451, 38369,     2,     1,     1,     1,     1],
        [    0, 31967,  4230,     2,     1,     1,     1,     1,     1]]

    input_ids = torch.tensor(input_ids).to(device)
    decoder_input_ids = torch.tensor(decoder_input_ids).to(device)

    model.to(device)
    model.eval()

    # Optional teacher-forced forward pass (disabled):
    # loss = model(input_ids=input_ids, attention_mask=input_ids.ne(1),
    #             decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_input_ids.ne(1),
    #             task_embed=torch.rand(input_ids.size(0), 768, device=device), is_training=True)

    # Inference only: no autograd graph is needed for routing or decoding.
    with torch.no_grad():
        # NOTE(review): 768 presumably matches the router input size for
        # bart-base (main() uses config.router_input_dim) — confirm.
        task_embed = torch.rand(1, 768).to(device)
        enc_routes0, dec_routes0 = model.get_routes(task_embed, separate=True)
        print(enc_routes0.shape)

        # NOTE(review): the literal 4 equals both the batch size and num_beams
        # here; which of the two the routing model expects along this axis is
        # not visible from this file — confirm before changing either.
        enc_routes0 = enc_routes0.squeeze(1).expand(4, -1, -1).transpose(0, 1)
        dec_routes0 = dec_routes0.squeeze(1).expand(4, -1, -1).transpose(0, 1)

        tokens = model.generate(input_ids=input_ids, attention_mask=input_ids.ne(1),
                                block_distribution=enc_routes0, decoder_block_distribution=dec_routes0,
                                num_beams=4, use_cache=True, use_sparse=True)

    for text in tokenizer.batch_decode(tokens, skip_special_tokens=True,
                                       clean_up_tokenization_spaces=True):
        print(text)


def main3(task_vecs_filename="data/taskvec/dummy/task_vecs.npy"):
    """Build a ``torch.nn.Embedding`` table from task vectors stored in a ``.npy`` file.

    Args:
        task_vecs_filename: path to a 2-D array saved with ``np.save``; row i
            becomes the embedding of task i. Defaults to the dummy task-vector
            file this script has always used.

    Returns:
        The populated ``torch.nn.Embedding``, which is also printed.

    Raises:
        ValueError: if the stored array is not 2-D (the shape unpacking fails).
    """
    task_vecs = torch.from_numpy(np.load(task_vecs_filename))

    # Requires a 2-D (n_tasks, dim) array.
    n_tasks, dim = task_vecs.shape

    task_embed = torch.nn.Embedding(n_tasks, dim)
    # load_state_dict copies into the existing parameter, so e.g. float64
    # arrays are cast to the parameter's (default) floating dtype.
    task_embed.load_state_dict({"weight": task_vecs})
    print(task_embed)
    return task_embed

def main4():
    """Smoke-test ``squeeze_weights`` with random one-hot encoder/decoder routes.

    A freshly initialised ``MyRoutingBart`` (``model_old``) and a pretrained
    ``facebook/bart-base`` ``BartForConditionalGeneration`` (``model``) are
    handed to ``squeeze_weights`` together with two random (6, 3) one-hot
    route matrices. Only the encoder routes are printed.
    """
    config = RoutingBartConfig.from_pretrained("facebook/bart-base")
    model_old = MyRoutingBart(config)
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    def random_one_hot_routes():
        # NOTE(review): presumably one row per layer and one column per
        # routing block — confirm against squeeze_weights.
        scores = torch.rand(6, 3)
        return torch.nn.functional.one_hot(scores.argmax(dim=1), 3).float()

    enc_routes = random_one_hot_routes()
    dec_routes = random_one_hot_routes()
    print(enc_routes)
    squeeze_weights(config, model, model_old, enc_routes, dec_routes)


if __name__ == "__main__":
    import sys

    # Select an experiment by name, e.g. `... load_routing_bart.py main2`.
    # With no argument this runs main4, matching the previous hard-wired call.
    experiments = {"main": main, "main2": main2, "main3": main3, "main4": main4}
    choice = sys.argv[1] if len(sys.argv) > 1 else "main4"
    if choice not in experiments:
        sys.exit("unknown experiment %r; choose from %s" % (choice, sorted(experiments)))
    experiments[choice]()