import torch

import transformers
from transformers import AutoModelForMaskedLM, AutoTokenizer

from collections import OrderedDict

def id_to_token(tokenizer: AutoTokenizer, id: int):
    """Return the vocabulary token string that corresponds to a token id.

    The lookup goes through ``tokenizer`` itself, so the function works
    wherever a tokenizer is available and does not depend on any
    module-level lookup table.

    Args:
        tokenizer: Tokenizer whose vocabulary is consulted; it must provide
            ``convert_ids_to_tokens``.
        id: Token id. Anything ``int()`` accepts is allowed, including the
            0-dim integer tensors obtained by iterating ``torch.topk``
            indices.

    Returns:
        The token string for ``id``.

    Raises:
        KeyError: If the tokenizer reports no token for ``id``.
    """
    # int() unwraps 0-dim tensors, which would otherwise never match an int key.
    token = tokenizer.convert_ids_to_tokens(int(id))
    if token is None:
        # Keep the dict-lookup contract: an unknown id is an error, not None.
        raise KeyError(id)
    return token

def token_to_id(tokenizer: AutoTokenizer, token: str):
    """Return the vocabulary id of an exact token string.

    The lookup goes through ``tokenizer`` itself, so the function works
    wherever a tokenizer is available and does not depend on any
    module-level lookup table.

    Args:
        tokenizer: Tokenizer whose vocabulary is consulted; it must provide
            ``convert_tokens_to_ids``, ``unk_token`` and ``unk_token_id``.
        token: Token exactly as stored in the vocabulary (for byte-level BPE
            vocabularies that includes any leading marker such as 'Ġ').

    Returns:
        The integer id of ``token``.

    Raises:
        KeyError: If ``token`` is not part of the vocabulary.
    """
    token_id = tokenizer.convert_tokens_to_ids(token)
    # Transformers tokenizers answer with the unknown-token id (or None when
    # no unknown token is configured) for strings outside the vocabulary.
    # Turn that into KeyError so a typo cannot silently become <unk>.
    is_unknown = token_id is None or (
        token_id == tokenizer.unk_token_id and token != tokenizer.unk_token
    )
    if is_unknown:
        raise KeyError(token)
    return token_id

def embedding_to_embedding(embedding_matrix, output_bias=None):
    """Return token-to-token probabilities over the full vocabulary.

    Each row of ``embedding_matrix`` is scored against every row, giving
    ``logits[i, j] = dot(row_i, row_j) + output_bias[i]``. A softmax over
    ``i`` (dim 0) then turns every column ``j`` into a probability
    distribution over rows.

    Args:
        embedding_matrix: 2-D tensor with one row per token.
        output_bias: 1-D tensor with one entry per row of
            ``embedding_matrix``. When omitted, the module-level ``bias``
            assigned by the ``__main__`` section is used.

    Returns:
        A square tensor with as many rows and columns as
        ``embedding_matrix`` has rows; each column sums to 1.
    """
    if output_bias is None:
        # Fallback for callers that pass only the matrix: use the global that
        # the script section assigns before calling this function.
        output_bias = bias
    similarity = torch.matmul(embedding_matrix, embedding_matrix.transpose(0, 1))
    # unsqueeze(-1) makes the bias a column, so entry i is added to row i.
    logits = similarity + output_bias.unsqueeze(-1)
    return torch.softmax(logits, dim=0)

def embedding_to_embedding_w_rank(u, s, rank):
    """Return token-to-token probabilities from a rank-truncated factorisation.

    Computes ``softmax(u[:, :rank] @ (s @ u.T)[:rank, :], dim=0)``. For
    example, if ``u`` holds the left singular vectors of an embedding matrix
    and ``s`` is the diagonal matrix of its squared singular values, the
    product is the similarity matrix rebuilt from the leading ``rank``
    components only. No bias term is added here.

    Args:
        u: 2-D square tensor.
        s: 2-D tensor that can be right-multiplied by ``u.T``.
        rank: Number of leading columns of ``u`` and leading rows of
            ``s @ u.T`` to keep.

    Returns:
        A tensor with one row and one column per row of ``u``; each column
        sums to 1.
    """
    # (s @ u.T)[:rank] equals s[:rank] @ u.T; slicing first avoids computing
    # all the rows that would be discarded straight away.
    projected = torch.matmul(s[:rank, :], u.transpose(0, 1))
    logits = torch.matmul(u[:, :rank], projected)
    return torch.softmax(logits, dim=0)

def _strip_word_start_marker(token):
    """Return ``token`` without a leading 'Ġ' word-start marker, if it has one."""
    return token[1:] if token.startswith('Ġ') else token


if __name__ == "__main__":
    # Script entry point: for every token of the model's vocabulary, find the
    # 10 tokens with the highest probability in the token-to-token matrix and
    # write them to "vocab_to_vocab_<model_name>.txt", one token per line.
    model_name = "roberta-large"

    # tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Module-level lookup tables in both directions (token -> id, id -> token).
    token2id = tokenizer.get_vocab()
    id2token = {token_id: token for token, token_id in token2id.items()}

    # obtain module
    model_with_mlm = AutoModelForMaskedLM.from_pretrained(model_name, return_dict=True)
    if "roberta" in model_name:
        mlm_head = model_with_mlm.lm_head.decoder
    else:
        mlm_head = model_with_mlm.cls.predictions.decoder
    embedding_matrix = mlm_head.weight.data
    # Must stay a module-level name and be assigned before the call below:
    # embedding_to_embedding is called with the matrix only.
    bias = mlm_head.bias.data

    # NOTE: only the full-vocabulary matrix is built. The rank-reduced variant
    # (embedding_to_embedding_w_rank) needs a full SVD with a vocab x vocab U
    # and is not needed for the output file.
    vocab_to_vocab = embedding_to_embedding(embedding_matrix=embedding_matrix)

    # One (token, neighbours) row per token id. Rows are kept in a list rather
    # than a dict keyed by the stripped token, because stripping 'Ġ' makes
    # distinct vocabulary entries (e.g. 'Ġcat' and 'cat') share a key and the
    # later one would overwrite the earlier one.
    rows = []
    for token_id in range(tokenizer.vocab_size):
        # Column token_id is the distribution associated with that token.
        _, indices = torch.topk(vocab_to_vocab[:, token_id], k=10)
        neighbours = [
            _strip_word_start_marker(id_to_token(tokenizer, index))
            for index in indices.tolist()
        ]
        token = _strip_word_start_marker(id_to_token(tokenizer, token_id))
        rows.append((token, neighbours))

    # Tokens contain non-ASCII characters, so the encoding is fixed explicitly
    # instead of depending on the platform default.
    with open(f"vocab_to_vocab_{model_name}.txt", 'w', encoding="utf-8") as f:
        f.writelines("{}: {}\n".format(token, neighbours) for token, neighbours in rows)
    print('done')