'''
Date: 2021-06-10 09:42:07
LastEditors: Wu Xianze (wuxianze.0@bytedance.com)
LastEditTime: 2021-06-30 21:31:33
'''
import os
import json
import argparse
import seaborn as sns
import numpy as np
from MulticoreTSNE import MulticoreTSNE as TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
# from normalizeEmb import json2array, array2json

def discard_small_values(embedding, threshold):
    """
    discard the value less than threshold
    Input
        embedding: np.array or List
    Return:
        normalized_embedding: np.array
    """
    is_list = False
    if isinstance(embedding, list):
        is_list = True
        embedding = np.array(embedding)

    embedding = np.where(
        embedding < threshold,
        0,
        embedding
    )
    if is_list:
        embedding = embedding.tolist()
    return embedding

def label_points(xs, ys, vals, ax):
    for (x, y, val) in zip(xs, ys, vals):
        ax.text(x + .02, y, str(val), fontdict={'size': 7})

def readJsonl(fname):
    datas = []
    with open(fname, 'r') as fin:
        for line in fin:
            datas.append(json.loads(line.strip()))
    return datas

def createHashMap(datas: list, key: str):
    results = {}
    for data in datas:
        newkey = data[key].strip().lower()
        newkey = normalize(newkey)
        results[newkey] = data
    return results

def findNN(query: str, keys: list):
    max_f1, nn = 0.0, ""
    qtokens = query.split()
    qtokens = set(qtokens)
    for key in keys:
        ktokens = set(key.split())
        overlap = qtokens.intersection(ktokens)
        precision = len(overlap) / (len(ktokens) + 1e-3)
        recall = len(overlap) / (len(qtokens) + 1e-3)
        f1 = 2 * precision * recall / (precision + recall + 1e-3)
        if f1 > max_f1:
            max_f1 = f1
            nn = key
    return nn

def normalize(s: str):
    s = s.replace("\\", "")
    s = s.replace("\"", "")
    s = s.replace(".", "")
    s = s.strip().lower()
    s = " ".join(s.split())
    return s

def matchEmb(input_dir, embkey, datakey, lg='en'):
    """
    return the lsit of sent embedding info
    """
    emb_infos = readJsonl(os.path.join(input_dir, embkey))
    data_infos = readJsonl(os.path.join(input_dir, datakey))
    hashed_emb_info = createHashMap(emb_infos, "sent")

    missed_num = 0
    missed_sents = []
    all_num = 0
    matched_infos = []
    for dinfo in data_infos[5:6]:
        sents = dinfo['document']
        doc_id = dinfo['id']
        oracles = dinfo['label']
        all_num += len(sents)
        for (i, sent) in enumerate(sents):
            sent = sent.strip().lower()
            sent = normalize(sent)
            if sent != "":
                emb_info = hashed_emb_info.get(sent, None)
                if emb_info is None:
                    missed_num += 1
                    missed_sents.append([sent, i])
                else:
                    emb_info['id'] = "{}_{}".format(doc_id, i)
                    emb_info['is_oracle'] = (i in oracles)
                    emb_info['lg'] = lg
                    matched_infos.append(emb_info)
    return matched_infos

def visualizeDocEmbPair(args):
    if not os.path.exists(args.o):
        os.makedirs(args.o)

    system_files = args.system_files
    all_infos = {}
    sys_names = []
    for input_path in system_files:
        infos = readJsonl(input_path)
        name = input_path.split('/')[-2]
        sys_names.append(name)
        all_infos[name] = infos

    # calculate sentence embedding
    system_embs = dict()
    for sys_name in all_infos:
        emb_list = []
        for emb_info in all_infos[sys_name][:args.topk]:
            emb_list.append(emb_info["embedding"])
        system_embs[sys_name] = emb_list

    distances = []
    namea = sys_names[0]
    nameb = sys_names[1]
    if args.metric == "cka":
        # avg_distance = ckaDistance(np.array(system_embs[namea]), np.array(system_embs[nameb]))
        avg_distance = linear_CKA(np.array(system_embs[namea]), np.array(system_embs[nameb]))
    else:
        for (a, b) in zip(system_embs[namea], system_embs[nameb]):
            distances.append(distanceFn(a, b, distance_type=args.metric))

        avg_distance = np.mean(distances)
    print("{} distance between {} and {} is {:.3f}".format(
        args.metric, namea, nameb, avg_distance
    ))

    # visualize
    emb_list = []
    names = []
    for sys_name in all_infos:
        for emb_info in all_infos[sys_name][:args.topk]:
            emb_list.append(emb_info["embedding"])
            names.append(sys_name)

    embedding_array = np.array(emb_list)
    model = TSNE(n_components=2, n_jobs=30, n_iter=1000, verbose=1)
    node_pos = model.fit_transform(embedding_array)
    x_pos = []
    y_pos = []
    for (_, item) in enumerate(node_pos):
        x_pos.append(item[0])
        y_pos.append(item[1])
    
    selected_x, selected_y, selected_i, selected_names = [], [], [], []
    for i, (xp, yp) in enumerate(zip(x_pos, y_pos)):
        selected_x.append(xp)
        selected_y.append(yp)
        selected_names.append(names[i])

    plt.xlim(-40, 40)
    plt.ylim(-40, 40)
    sns.scatterplot(x=selected_x, y=selected_y, hue=selected_names)
    graph_path = os.path.join(args.o, "{}_{}_{:.3f}.png".format(args.n, args.metric, avg_distance))
    plt.savefig(graph_path, dpi=200)
    print("save an image to {}".format(graph_path))
    plt.clf()

def visualizeDocEmbPairPCA(args):
    if not os.path.exists(args.o):
        os.makedirs(args.o)

    system_files = args.system_files
    # print(system_files)
    all_infos = {}
    sys_names = []
    for input_path in system_files:
        infos = readJsonl(input_path)
        name = input_path.split('/')[-2]
        sys_names.append(name)
        all_infos[name] = infos

    # calculate sentence embedding
    system_embs = dict()
    for sys_name in all_infos:
        emb_list = []
        for emb_info in all_infos[sys_name][:args.topk]:
            embedding = emb_info["embedding"]
            normalized_embedding = discard_small_values(embedding, args.t)
            emb_list.append(normalized_embedding)
        system_embs[sys_name] = emb_list

    distances = []
    namea = sys_names[0]
    nameb = sys_names[1]
    if args.metric == "cka":
        avg_distance = linear_CKA(np.array(system_embs[namea]), np.array(system_embs[nameb]))
    else:
        for (a, b) in zip(system_embs[namea], system_embs[nameb]):
            distances.append(distanceFn(a, b, distance_type=args.metric))

        avg_distance = np.mean(distances)
    print("{} distance between {} and {} is {:.3f}".format(
        args.metric, namea, nameb, avg_distance
    ))

    # visualize
    emb_list = []
    names = []
    for sys_name in all_infos:
        for emb_info in all_infos[sys_name][:args.topk]:
            emb_list.append(emb_info["embedding"])
            names.append(sys_name)

    embedding_array = np.array(emb_list)

    pca = PCA(n_components=2).fit(embedding_array)
    node_pos = pca.transform(embedding_array)
    x_pos = []
    y_pos = []
    for (_, item) in enumerate(node_pos):
        x_pos.append(item[0])
        y_pos.append(item[1])
    
    selected_x, selected_y, selected_i, selected_names = [], [], [], []
    for i, (xp, yp) in enumerate(zip(x_pos, y_pos)):
        selected_x.append(xp)
        selected_y.append(yp)
        selected_names.append(names[i])

    sns.kdeplot(x=selected_x, y=selected_y, hue=selected_names)
    graph_path = os.path.join(args.o, "{}_{}_{:.3f}.png".format(args.n, args.metric, avg_distance))
    plt.savefig(graph_path, dpi=200)
    print("save an image to {}".format(graph_path))
    plt.clf()

def visualizeTokenEmbPair(args):
    if not os.path.exists(args.o):
        os.makedirs(args.o)

    system_files = args.system_files
    # print(system_files)
    all_infos = {}
    sys_names = []
    for input_path in system_files:
        infos = readJsonl(input_path)
        name = input_path.split('/')[-2]
        sys_names.append(name)
        all_infos[name] = infos

    # calculate sentence embedding
    system_embs = dict()
    for sys_name in all_infos:
        emb_list = []
        for emb_info in all_infos[sys_name][:args.topk]:
            for token_info in emb_info["embeddings"]:
                emb_list.append(token_info["embedding"])
        system_embs[sys_name] = emb_list

    distances = []
    namea = sys_names[0]
    nameb = sys_names[1]
    if args.metric == "cka":
        # avg_distance = ckaDistance(np.array(system_embs[namea]), np.array(system_embs[nameb]))
        avg_distance = linear_CKA(np.array(system_embs[namea]), np.array(system_embs[nameb]))
    else:
        for (a, b) in zip(system_embs[namea], system_embs[nameb]):
            distances.append(distanceFn(a, b, distance_type=args.metric))

        avg_distance = np.mean(distances)
    print("{} distance between {} and {} is {:.3f}".format(
        args.metric, namea, nameb, avg_distance
    ))
    avg_distance = 0

    # visualize
    emb_list = system_embs[namea] + system_embs[nameb]
    names = [namea for _ in range(len(system_embs[namea]))] + [nameb for _ in range(len(system_embs[nameb]))]

    embedding_array = np.array(emb_list)
    model = TSNE(n_components=2, n_jobs=30, n_iter=1000, verbose=1)
    node_pos = model.fit_transform(embedding_array)
    x_pos = []
    y_pos = []
    for (_, item) in enumerate(node_pos):
        x_pos.append(item[0])
        y_pos.append(item[1])
    
    selected_x, selected_y, selected_i, selected_names = [], [], [], []
    for i, (xp, yp) in enumerate(zip(x_pos, y_pos)):
        selected_x.append(xp)
        selected_y.append(yp)
        selected_names.append(names[i])

    sns.kdeplot(x=selected_x, y=selected_y, hue=selected_names)
    graph_path = os.path.join(args.o, "{}_{}_{:.3f}.png".format(args.n, args.metric, avg_distance))
    plt.savefig(graph_path, dpi=200)
    print("save an image to {}".format(graph_path))
    plt.clf()

def centering(K):
    n = K.shape[0]
    unit = np.ones([n, n])
    I = np.eye(n)
    H = I - unit / n

    return np.dot(np.dot(H, K), H)  # HKH are the same with KH, KH is the first centering, H(KH) do the second time, results are the sme with one time centering
    # return np.dot(H, K)  # KH

def linear_HSIC(X, Y):
    L_X = np.dot(X, X.T)
    L_Y = np.dot(Y, Y.T)
    return np.sum(centering(L_X) * centering(L_Y))

def linear_CKA(X, Y):
    hsic = linear_HSIC(X, Y)
    var1 = np.sqrt(linear_HSIC(X, X))
    var2 = np.sqrt(linear_HSIC(Y, Y))
    return hsic / (var1 * var2)

def distanceFn(emba, embb, distance_type="cosine"): 
    if isinstance(emba, list):
        emba = np.array(emba)
    if isinstance(embb, list):
        embb = np.array(embb)

    if distance_type == "euclidean":
        distance = np.sqrt(np.square(emba - embb).sum())
    elif distance_type == "cosine":
        norma = np.sqrt(np.sum(emba ** 2))
        normb = np.sqrt(np.sum(embb ** 2))
        distance = np.sum(emba * embb) / (norma * normb)
    elif distance_type == "cka":
        distance = linear_CKA(emba, embb)
    return distance

def embDistanceFn(all_infos: dict, namea: str, nameb: str, metric: str):
    infos_a = all_infos[namea]
    infos_b = all_infos[nameb]
    
    avg_distance = 0.0
    for (infoa, infob) in zip(infos_a, infos_b):
        emba = infoa['embedding']
        embb = infob['embedding']
        distance = distanceFn(emba, embb, distance_type=metric)
        avg_distance += distance
        
    return avg_distance / len(infos_a)

def embDistanceMatrixFn(embeddings1, embeddings2):
    """
    Input
        embeddings1: np.ndarray
        embeddings2: np.ndarray
    Return:
        distances: np.ndarray
    """
    distances = []
    for i in range(embeddings1.shape[0]):
        row = []
        for j in range(embeddings2.shape[0]):
            row.append(distanceFn(embeddings1[i], embeddings2[j], distance_type="cosine"))
        distances.append(row)
    # distance_matrix = np.matmul(embeddings1, embeddings2.T)
    return np.array(distances)

def matchAndCalEmbDistance(args):
    system_files = args.system_files

    all_infos = {}
    names = []
    for (i, input_path) in enumerate(system_files):
        infos = readJsonl(input_path)
        names.append(input_path.split('/')[-2])
        for (j, info) in enumerate(infos):
            embedding = info['embedding']
            normalized_embedding = discard_small_values(embedding, args.t)
            infos[j]['embedding'] = normalized_embedding
        
        all_infos[names[i]] = infos
    
    avg_distance = embDistanceFn(all_infos, names[0], names[1], metric=args.metric)
    print("average embedding distance between {} and {} is {:.3f}".format(
        names[0], names[1], avg_distance
    ))

# def matchAndCalEmbDiffScore(args):
#     system_files = args.system_files
#     all_infos = {}
#     names = []
#     for (i, input_path) in enumerate(system_files):
#         infos = readJsonl(input_path)[:args.topk]
#         embeddings = json2array(infos)
#         names.append(input_path.split('/')[-2])
#         all_infos[names[i]] = embeddings
    
#     embedding_distance_matrix = embDistanceMatrixFn(
#         all_infos[names[0]], 
#         all_infos[names[1]]
#     )
#     print(embedding_distance_matrix.shape)
#     eye_mask = np.eye(embedding_distance_matrix.shape[0])
#     diff_scores = (embedding_distance_matrix * eye_mask).sum(axis=-1) - embedding_distance_matrix.mean(axis=-1)
#     avg_score = diff_scores.mean()

#     print("average diff score {} and {} is {:.3f}".format(
#         names[0], names[1], avg_score
#     ))

# def visualizeEmbPair(args):
#     # assert len(args.dirs) == 1
#     # input_dir = args.dirs[0]
#     system_files = args.system_files
#     langs = args.langs

#     all_infos = []
#     for (i, input_path) in enumerate(system_files):
#         # input_path = os.path.join(input_dir, file_name)
#         system_name = input_path.split('/')[-1].split('.')[0]
#         infos = readJsonl(input_path)
#         for (j, info) in enumerate(infos):
#             infos[j]['lg'] = langs[i]
#         all_infos.extend(infos)
    
#     visualizeFn(
#         all_infos, embedding_key="embedding", 
#         hue_key='lg', text_key='doc_id',
#         graph_path=os.path.join(args.o, "graph.png")
#     )

def retrievalGivenArray(args):
    """
    retrieval embedding given two array
    We assume that the first embeddings are queries, the second embeddings are keys
    """
    assert len(args.system_files) == 2, "two systems is given for retrieving"
    for system in args.system_files:
        assert system.endswith("npy"), "we assume that inputs are np array"
    import torch
    np_embedding1 = np.load(args.system_files[0])
    np_embedding2 = np.load(args.system_files[1])
    tensor_emb1 = torch.tensor(np_embedding1).cuda(0) # [N, H]
    tensor_emb2 = torch.tensor(np_embedding2).cuda(0) # [N, H]
    similarity = tensor_emb1.mm(tensor_emb2.transpose(0, 1))
    pred = torch.argmax(similarity, dim=-1)
    label = torch.arange(pred.size(0), dtype=torch.long).to(pred)
    acc = torch.sum((pred == label)).item() / pred.size(0)
    log_out = "Result:\n"
    log_out += "System A: {}\n".format(args.system_files[0])
    log_out += "System B: {}\n".format(args.system_files[1])
    log_out += "Retrieval Acc: {:.2f}\n".format(acc * 100)
    print(log_out)
    return acc

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-dirs', type=str, help="input dirs, the directory contain *.emb (dumped emnbedding) and *.jsonl (data with oracle label)", nargs="+")
    parser.add_argument('--system-files', type=str, help="", nargs="+",
        default=[
            "/home/tiger/xgiga_dumpEmb_finetune_pretrainedMspm4_freezeEncDec_freezeEncLayerNorm_adapterV3_2layer_fusedGated_rougeEval_fixValid_eval_zh/embedding/train_x_en_first10000/sent_embedding.jsonl",
            "/home/tiger/xgiga_dumpEmb_finetune_pretrainedMspm4_freezeEncDec_freezeEncLayerNorm_adapterV3_2layer_fusedGated_rougeEval_fixValid_eval_zh/embedding/train_x_en2zh_first10000/sent_embedding.jsonl"
        ]
    )
    parser.add_argument('--langs', type=str, help="languages that the systems are processing", nargs="+")
    # parser.add_argument('-e', type=str, help="the keyfield of embedding file", default="sent_embedding.jsonl")
    parser.add_argument('-d', type=str, help="the keyfield of data file", default="data.jsonl")
    # parser.add_argument('-m', type=str, help="mode", default="matchAndVisualizeEmb")
    parser.add_argument('-m', type=str, help="mode", default="matchAndCalEmbDistance")
    parser.add_argument('-o', type=str, help="output_dir", default="/opt/tiger/sumtest/graphs/")
    parser.add_argument('-n', type=str, help="output_name", default="")
    parser.add_argument('--metric', type=str, help="type of the distance metric", choices=["cosine", "euclidean", "cka"], default="cka")
    parser.add_argument('--topk', type=int, default=1000)
    parser.add_argument(
        '--n-iter', type=int, help="number of iteration", 
        default=5000
    )
    parser.add_argument('-p', type=int, help="perplexity", default=30)
    parser.add_argument('-t', type=float, default="1e-6", help="discard embedding values less than t")
    args = parser.parse_args()

    eval("{}(args)".format(args.m))

    """
    python3 visualizeEmb.py -m matchAndCalEmbDistance --system-files /home/tiger/xgiga_dumpEmb_encoder/embedding/encoder_dev_fr /home/tiger/xgiga_dumpEmb_encoder/embedding/encoder_dev_zh

    python3 visualizeEmb.py -m visualizeDocEmbPair --system-files /home/tiger/xgiga_dumpEmb_encoder/embedding/encoder_dev_fr /home/tiger/xgiga_dumpEmb_encoder/embedding/encoder_dev_zh

    python3 visualizeEmb.py -m retrievalGivenArray --system-files /home/tiger/xgiga_dumpEmb_encoder/embedding/encoder_dev_en/document_embedding.npy /home/tiger/xgiga_dumpEmb_encoder/embedding/encoder_dev_fr/document_embedding.npy

    python3 visualizeEmb.py -m retrievalGivenArray --system-files /home/tiger/xgiga_dumpEmb_proj/embedding/proj_wiki40b_3_dev_en/document_embedding.npy /home/tiger/xgiga_dumpEmb_proj/embedding/proj_wiki40b_3_dev_fr/document_embedding.npy
    """