import json
import sys
import numpy as np

def readJsonl(fname):
    """Load a JSON Lines file into a list of decoded objects.

    Args:
        fname: Path of the file to read; each non-blank line must hold one
            JSON document.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    datas = []
    # Read as UTF-8 explicitly: dumpJsonl writes with ensure_ascii=False, so
    # the file may contain raw non-ASCII text that a locale-dependent default
    # encoding could fail to decode.
    with open(fname, 'r', encoding='utf-8') as fin:
        for line in fin:
            line = line.strip()
            # Blank lines (e.g. a trailing empty line) carry no record and
            # would otherwise make json.loads raise.
            if not line:
                continue
            datas.append(json.loads(line))
    return datas

def dumpJsonl(datas, fname):
    """Write an iterable of JSON-serialisable objects as a JSON Lines file.

    Args:
        datas: Iterable of objects accepted by ``json.dumps``.
        fname: Destination path; an existing file is overwritten.
    """
    # ensure_ascii=False emits non-ASCII characters verbatim, so the file is
    # opened as UTF-8 explicitly rather than relying on the locale default,
    # which may be unable to encode them.
    with open(fname, 'w', encoding='utf-8') as fout:
        for data in datas:
            fout.write(json.dumps(data, ensure_ascii=False) + '\n')

def readTxt(file):
    """Return the lines of a text file with surrounding whitespace removed.

    Args:
        file: Path of the text file to read.

    Returns:
        A list with one stripped string per line, in file order. Lines that
        are empty after stripping are kept as empty strings.
    """
    with open(file, 'r') as handle:
        return [raw.strip() for raw in handle]

def printTokenIdx(inFile):
    """Print every token in a JSONL embedding dump with its positions.

    Each output line holds ``[record_index, token_index]``, the token, and the
    token's running position across all records.

    Args:
        inFile: Path of a JSONL file whose records have an ``"embeddings"``
            list of items carrying a ``"token"`` field.
    """
    records = readJsonl(inFile)
    # Gather everything before printing so a malformed record fails before
    # any output is produced.
    entries = [
        ([recNo, tokNo], item["token"])
        for recNo, rec in enumerate(records)
        for tokNo, item in enumerate(rec["embeddings"])
    ]
    for flatPos, (pair, tok) in enumerate(entries):
        print(pair, tok, flatPos)


def dumpEmbByIndex(inFile, indexFile, outFile):
    """Select token embeddings by flat index and save them as a NumPy array.

    The embeddings of all records in ``inFile`` are flattened in file order,
    giving each token the same running position that ``printTokenIdx`` prints
    as its last column. Each index listed in ``indexFile`` picks one of them.

    Args:
        inFile: Path of a JSONL file whose records have an ``"embeddings"``
            list of items carrying an ``"embedding"`` field.
        indexFile: Path of a text file with one integer index per line. Blank
            lines are ignored. Negative values follow normal Python list
            indexing and count from the end.
        outFile: Destination passed to ``np.save`` (NumPy appends ``.npy``
            when the name does not already end with it).

    Raises:
        ValueError: If a non-blank index line is not an integer.
        IndexError: If an index is outside the flattened embedding list.
    """
    indatas = readJsonl(inFile)
    embeddings = [
        info["embedding"]
        for data in indatas
        for info in data["embeddings"]
    ]
    results = []
    for idx in readTxt(indexFile):
        # A blank line (commonly a stray one at the end of the file) is not
        # an index; skip it instead of failing in int('').
        if not idx:
            continue
        results.append(embeddings[int(idx)])
    np.save(outFile, np.array(results))

if __name__ == "__main__":
    # Command line:
    #   <script> <inFile> printTokenIdx
    #   <script> <inFile> dumpEmbByIndex <indexFile> <outFile>
    usage = ("usage: %s <inFile> printTokenIdx\n"
             "       %s <inFile> dumpEmbByIndex <indexFile> <outFile>"
             % (sys.argv[0], sys.argv[0]))
    if len(sys.argv) < 3:
        # Both the input file and the mode are mandatory.
        sys.exit(usage)
    inFile = sys.argv[1]
    mode = sys.argv[2]

    if mode == "printTokenIdx":
        printTokenIdx(inFile)
    elif mode == "dumpEmbByIndex":
        # This mode needs two extra arguments; report that clearly rather
        # than failing later on an undefined variable.
        if len(sys.argv) < 5:
            sys.exit(usage)
        dumpEmbByIndex(inFile, sys.argv[3], sys.argv[4])
    else:
        # An unrecognised mode used to be ignored silently.
        sys.exit("unknown mode: %s\n%s" % (mode, usage))

    # python3 handleEmbedding.py /home/tiger/xgiga_dumpEmb_finetune_pretrainedMspm4_freezeEncDec_freezeEncLayerNorm_adapterV3_2layer_fusedGated_rougeEval_fixValid_eval_zh/embedding/train_x_en_first10000/token_embedding_projNormalizeEmb_first20.jsonl dumpEmbByIndex /opt/tiger/sumtest/en_idx.txt en_embedding.npy