import sys
import json
import argparse
import numpy as np

def readJsonl(fname, topk=None):
    """Load JSON objects from a JSON Lines file.

    Args:
        fname: Path of the file to read. Every non-blank line must hold one
            JSON document.
        topk: Maximum number of instances to load. ``None`` loads the whole
            file.

    Returns:
        A list with the decoded object of each loaded line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    datas = []
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(fname, 'r', encoding='utf-8') as fin:
        for line in fin:
            # Check the limit before parsing so at most ``topk`` instances
            # are kept.
            if topk is not None and len(datas) >= topk:
                break
            line = line.strip()
            if not line:
                # Tolerate blank lines, e.g. an empty line at the end of file.
                continue
            datas.append(json.loads(line))
    print("read {} instances from {}".format(len(datas), fname))
    return datas

def embNorm(embedding: list):
    """Return the square root of the sum of squared elements of ``embedding``.

    The input is converted to a NumPy array first; for a real-valued vector
    the result is its Euclidean (L2) norm.
    """
    vector = np.array(embedding)
    squared = np.power(vector, 2)
    return np.sqrt(squared.sum())

def printEmbNorm(args):
    """Print the per-token embedding norms of the instances in each input file.

    For every path in ``args.input_files`` the instances are loaded with
    ``readJsonl(path, args.topk)``. For each instance the function prints its
    ``src_str``, then a list of ``[token, norm]`` pairs (norm rounded to three
    decimals) sorted by norm in descending order, then a blank line.
    """
    for path in args.input_files:
        print("input_file: {}".format(path))
        for instance in readJsonl(path, args.topk):
            token_norms = [
                [entry["token"], round(embNorm(entry["embedding"]), 3)]
                for entry in instance["embeddings"]
            ]
            # Stable in-place sort, largest norm first.
            token_norms.sort(key=lambda pair: pair[1], reverse=True)
            print(instance["src_str"])
            print(token_norms)
            print()

def pairEmbNormDiff(args):
    """Print per-token embedding-norm differences between two systems.

    The first two paths in ``args.input_files`` are treated as system 1 and
    system 2. Every listed file is loaded with ``readJsonl(path, args.topk)``.
    For each instance of system 1 the function prints its ``src_str``, then a
    list of ``[token, norm1 - norm2]`` pairs (rounded to three decimals)
    sorted in descending order of the difference, then a blank line.

    Instances are paired by position in the files and tokens by their string.
    If a token occurs more than once within one instance, the norm of its
    last occurrence is used.

    Raises:
        ValueError: If fewer than two input files are given.
        KeyError: If a token of system 1 is absent from the matching
            instance of system 2.
        IndexError: If system 2 holds fewer instances than system 1.
    """
    input_files = args.input_files
    # Fail fast with a clear message instead of an IndexError after all the
    # data has already been read.
    if len(input_files) < 2:
        raise ValueError(
            "pairEmbNormDiff needs at least two input files, got {}".format(
                len(input_files)))

    src_strs = []
    # One entry per input file (by position); each entry is a list holding a
    # {token: norm} dict for every instance of that file.
    system_norms = []
    for file_index, input_file in enumerate(input_files):
        instance_norms = []
        for info in readJsonl(input_file, args.topk):
            if file_index == 0:
                src_strs.append(info['src_str'])
            instance_norms.append({
                item["token"]: embNorm(item["embedding"])
                for item in info["embeddings"]
            })
        system_norms.append(instance_norms)

    # calculate norm diff
    system1, system2 = system_norms[0], system_norms[1]
    for i, src_str in enumerate(src_strs):
        print(src_str)
        norm_diff = [
            [token, round(norm - system2[i][token], 3)]
            for token, norm in system1[i].items()
        ]
        norm_diff.sort(key=lambda pair: pair[1], reverse=True)
        print(norm_diff)
        print()


if __name__ == "__main__":
    # Explicit table of runnable modes. The mode name comes from the command
    # line, so it is looked up here rather than passed to eval(), which would
    # execute arbitrary expressions.
    modes = {
        "printEmbNorm": printEmbNorm,
        "pairEmbNormDiff": pairEmbNormDiff,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("--input-files", nargs="+", default=[
        "/home/tiger/xgiga_dumpEmb_mbart/embedding/test_x_zh/token_embedding.jsonl",
        "/home/tiger/xgiga_dumpEmb_finetune_pretrainedMspm4_freezeDecoder_bsz32i/embedding/test_x_zh/token_embedding.jsonl"
    ])
    parser.add_argument("--topk", default=5, type=int, help="analyze the token embedding of first ${topk} instances")
    # `choices` makes argparse reject unknown modes with a usage message.
    parser.add_argument("-m", default="pairEmbNormDiff", type=str, choices=sorted(modes), help="mode")
    args = parser.parse_args()

    modes[args.m](args)
