from tkinter import font
from transformers import AutoTokenizer, AutoModel
from hashlib import blake2b
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

import torch
import argparse
import sentencepiece as spm

from tqdm import tqdm
from datasets import load_dataset

def encrypt_token_using_blake(key, token, digest_size=16):
    """Return the keyed BLAKE2b hex digest of ``token``.

    Args:
        key: Secret key string, encoded as UTF-8 before use. BLAKE2b accepts
            keys of at most 64 bytes.
        token: Plaintext token string, encoded as UTF-8 before hashing.
        digest_size: Digest length in bytes (1-64). The returned hex string
            is ``2 * digest_size`` characters long. Defaults to 16.

    Returns:
        The hexadecimal digest string.

    Raises:
        ValueError: If the key is longer than 64 bytes or ``digest_size`` is
            outside 1-64 (raised by ``hashlib.blake2b``).
    """
    key_bytes = bytes(key, "utf-8")
    # Honour the caller's digest_size (previously hard-coded to 16).
    h = blake2b(key=key_bytes, digest_size=digest_size)
    h.update(bytes(token, "utf-8"))
    return h.hexdigest()

def ednn(special_token_ids, pretrained_weights, enc_fine_tuned_wieghts, index_to_vocab, enc_index_to_vocab, key, block_size=4096):
    """Run the EDNN attack: guess each encrypted token as the pretrained token
    with the nearest element-wise-differential embedding.

    Each embedding row ``w`` is turned into its element-wise differential
    ``w - roll(w, 1)`` (cyclic along dim 1). For every encrypted row, the
    pretrained row with the smallest L2 distance between differentials is the
    attack's guess. A guess counts as correct when
    ``encrypt_token_using_blake(key, guess_token)`` equals the encrypted token
    string. Accuracy and privacy (``1 - accuracy``) are printed.

    If the encrypted embeddings were trained L2-normalised, normalise both
    matrices row-wise before calling this function.

    Args:
        special_token_ids: Iterable of token ids to skip. They are compared
            against encrypted row indices, so this assumes special tokens keep
            the same ids in both vocabularies — TODO confirm.
        pretrained_weights: 2-D tensor, one row per pretrained (plain) token.
        enc_fine_tuned_wieghts: 2-D tensor, one row per encrypted token, with
            the same number of columns as ``pretrained_weights``.
        index_to_vocab: Mapping from pretrained row index to plain token.
        enc_index_to_vocab: Mapping from encrypted row index to encrypted token.
        key: Key string passed to ``encrypt_token_using_blake``.
        block_size: Number of encrypted rows whose distances are computed at
            once. Only a ``block_size x num_pretrained_rows`` slice of the
            distance matrix is materialised. ``None`` or ``<= 0`` computes
            everything in a single block.

    Returns:
        ``(enc_to_plain_ids, enc_correct_map)``: ``{enc_id: guessed plain id}``
        and ``{enc_id: bool}`` for every evaluated encrypted id.
    """
    skip_ids = set(special_token_ids)  # O(1) membership in the loop below

    pretrained_diff = pretrained_weights - torch.roll(pretrained_weights, 1, dims=1)
    enc_diff = enc_fine_tuned_wieghts - torch.roll(enc_fine_tuned_wieghts, 1, dims=1)

    # Rows of each distance block are *encrypted* tokens and columns are
    # pretrained tokens, so the arg-min along dim 1 is a pretrained (plain) id.
    # Working block by block avoids materialising the full distance matrix.
    num_enc = enc_diff.shape[0]
    step = block_size if block_size and block_size > 0 else max(num_enc, 1)
    nearest = []
    for start in tqdm(range(0, num_enc, step)):
        block_dist = torch.cdist(enc_diff[start:start + step], pretrained_diff, p=2)
        # torch.argmin returns the first index on ties.
        nearest.extend(torch.argmin(block_dist, dim=1).tolist())

    enc_to_plain_ids = {}  # {enc_token_id: plain_token_id}
    enc_correct_map = {}  # {enc_token_id: bool}
    total = 0
    correct = 0
    print("special_token_ids: ", special_token_ids)
    for enc_id, plain_id in enumerate(nearest):
        enc_token = enc_index_to_vocab.get(enc_id)
        # Skip special tokens and embedding rows without a vocabulary entry
        # (e.g. rows beyond the tokenizer's vocabulary size).
        if enc_id in skip_ids or enc_token is None:
            continue
        total += 1
        enc_to_plain_ids[enc_id] = plain_id
        guess = index_to_vocab.get(plain_id)
        # A neighbour row with no plain token can never be a correct guess.
        hit = guess is not None and enc_token == encrypt_token_using_blake(key, guess)
        enc_correct_map[enc_id] = hit
        correct += hit
    if total:
        print(f"EDNN Accuracy = {correct/ total}, Privacy = {1 - correct/ total} ")
    else:
        print("EDNN: no non-special tokens to evaluate")
    return enc_to_plain_ids, enc_correct_map

def embedding_analysis(special_token_ids, pretrained_weights, enc_fine_tuned_wieghts, index_to_vocab, enc_index_to_vocab, key, block_size=4096, output_path='embedding_analysis.png'):
    """Compare, per encrypted token, the distance to its true plaintext token
    with the distance to the closest *other* pretrained token.

    Distances are L2 distances between element-wise differentials
    ``w - roll(w, 1)`` (cyclic along dim 1) of the embedding rows. The true
    plaintext of an encrypted token is found by encrypting every non-special
    plain token with ``encrypt_token_using_blake(key, token)`` and looking the
    digest up in the encrypted vocabulary. Each distribution's mean is
    printed. Both are plotted as log-scale density histograms, saved to
    ``output_path``, and then shown.

    If the encrypted embeddings were trained L2-normalised, normalise both
    matrices row-wise before calling this function.

    Args:
        special_token_ids: Iterable of token ids to skip. They are used both as
            plain ids and as encrypted row indices, so this assumes special
            tokens keep the same ids in both vocabularies — TODO confirm.
        pretrained_weights: 2-D tensor, one row per pretrained (plain) token.
        enc_fine_tuned_wieghts: 2-D tensor, one row per encrypted token, with
            the same number of columns as ``pretrained_weights``.
        index_to_vocab: Mapping from pretrained row index to plain token.
        enc_index_to_vocab: Mapping from encrypted row index to encrypted token.
        key: Key string passed to ``encrypt_token_using_blake``.
        block_size: Number of encrypted rows whose distances are computed at
            once. ``None`` or ``<= 0`` computes everything in a single block.
        output_path: File the histogram figure is saved to.

    Returns:
        ``(dist_to_correct, dist_to_other)`` as 1-D numpy arrays, one entry per
        evaluated encrypted token in ascending encrypted-id order.
    """
    skip_ids = set(special_token_ids)  # O(1) membership tests

    pretrained_diff = pretrained_weights - torch.roll(pretrained_weights, 1, dims=1)
    enc_diff = enc_fine_tuned_wieghts - torch.roll(enc_fine_tuned_wieghts, 1, dims=1)

    # Ground truth: encrypted id -> plain id whose encryption it is. Plain
    # tokens whose digest is missing from the encrypted vocabulary are skipped
    # instead of raising KeyError.
    enc_vocab_to_index = {token: index for index, token in enc_index_to_vocab.items()}
    enc_index_to_plain_index = {}
    print("special_token_ids: ", special_token_ids)
    for plain_idx, plain_token in index_to_vocab.items():
        if plain_idx in skip_ids:
            continue
        enc_index = enc_vocab_to_index.get(encrypt_token_using_blake(key, plain_token))
        if enc_index is not None:
            enc_index_to_plain_index[enc_index] = plain_idx

    # Encrypted rows that can be scored: non-special and with a known plaintext.
    enc_rows = [
        i for i in range(enc_diff.shape[0])
        if i not in skip_ids and i in enc_index_to_plain_index
    ]

    dist_to_correct = []
    dist_to_other = []
    step = block_size if block_size and block_size > 0 else max(len(enc_rows), 1)
    for start in tqdm(range(0, len(enc_rows), step)):
        rows = enc_rows[start:start + step]
        cols = torch.tensor([enc_index_to_plain_index[i] for i in rows], dtype=torch.long)
        # Rows = encrypted tokens, columns = pretrained tokens.
        block_dist = torch.cdist(enc_diff[rows], pretrained_diff, p=2)
        row_pos = torch.arange(len(rows))
        cols = cols.to(block_dist.device)
        row_pos = row_pos.to(block_dist.device)
        dist_to_correct.extend(block_dist[row_pos, cols].tolist())
        # Mask the correct column so the row minimum is the nearest *other* token.
        block_dist[row_pos, cols] = float("inf")
        dist_to_other.extend(block_dist.min(dim=1).values.tolist())

    dist_to_correct = np.array(dist_to_correct)
    dist_to_other = np.array(dist_to_other)
    if dist_to_correct.size == 0:
        print("embedding_analysis: no encrypted tokens could be matched to a plaintext token")
        return dist_to_correct, dist_to_other
    print(np.average(dist_to_correct))
    print(np.average(dist_to_other))

    sns.set(style="whitegrid")
    plt.figure(figsize=(10, 5))
    sns.histplot(dist_to_correct, bins=30, color="blue", alpha=0.7, log_scale=True, label="Distance to correct token", stat='density')
    sns.histplot(dist_to_other, bins=30, color="red", alpha=0.7, log_scale=True, label="Distance to other token", stat='density')
    plt.tick_params(axis='both', which='major', labelsize=14)
    plt.xlabel('Distance of element-wise differentials', fontsize=16)
    plt.ylabel('Density', fontsize=14)
    plt.legend()
    # Save before show(): after the window is closed (or on a non-interactive
    # backend) the current figure may be gone, and savefig would write a blank image.
    plt.savefig(output_path)
    plt.show()

    return dist_to_correct, dist_to_other

if __name__ == '__main__':
    # Command-line entry point: load both models' tokenizers and input
    # embeddings, then run the embedding analysis and/or the EDNN attack.
    parser = argparse.ArgumentParser(
        description="Run the EDNN attack / embedding analysis of an encrypted (fine-tuned) model against its public pretrained model.")

    # Model name or path arguments (mandatory)
    parser.add_argument("--enc_model_path", type=str, required=True, help="Name or path of the encrypted (finetuned) model")
    parser.add_argument("--pretrained_model_path", type=str, required=True, help="Name or path of the public (pretrained) model")
    parser.add_argument("--key", default="languagemodel123", type=str, help="Encryption key used to check whether attack success")
    # Default "analysis" keeps the previous behaviour (embedding analysis only).
    parser.add_argument("--mode", choices=("analysis", "ednn", "both"), default="analysis",
                        help="Which procedure to run: embedding analysis, EDNN attack, or both")
    args = parser.parse_args()

    enc_model_path = args.enc_model_path
    pretrained_model_path = args.pretrained_model_path
    key = args.key

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_path)
    index_to_vocab = {idx: token for token, idx in tokenizer.get_vocab().items()}

    enc_tokenizer = AutoTokenizer.from_pretrained(enc_model_path)
    enc_index_to_vocab = {idx: token for token, idx in enc_tokenizer.get_vocab().items()}

    pretrained_model = AutoModel.from_pretrained(pretrained_model_path)
    enc_model = AutoModel.from_pretrained(enc_model_path)

    pretrained_embedding_weights = pretrained_model.get_input_embeddings().weight.detach()
    enc_embedding_weights = enc_model.get_input_embeddings().weight.detach()
    del pretrained_model
    del enc_model

    # special_tokens_map values may be lists (e.g. additional_special_tokens),
    # which cannot be used as dict keys; all_special_ids flattens them to ids.
    special_token_ids = sorted(set(tokenizer.all_special_ids))

    if args.mode in ("ednn", "both"):
        print("running EDNN attack")
        ednn(special_token_ids, pretrained_embedding_weights, enc_embedding_weights, index_to_vocab, enc_index_to_vocab, key)

    if args.mode in ("analysis", "both"):
        print("running embedding analysis")
        embedding_analysis(special_token_ids, pretrained_embedding_weights, enc_embedding_weights, index_to_vocab, enc_index_to_vocab, key)
