import os
import sys
import numpy as np
sys.path.append("..")
from visualizeEmb import linear_CKA, readJsonl

# Usage: python3 pipeline_v1proj.py <prefix>
#   e.g. python3 pipeline_v1proj.py mbartV2_iadapterfix_adapterV1_encProjFix_projLN_newTA

def loadEmb(infile):
    """
    Collect the "embedding" field of every record in a jsonl file.

    Input:
        infile: path to a jsonl file, read via visualizeEmb.readJsonl; every
            record must carry an "embedding" key (a KeyError is raised otherwise).
    Output:
        list of the "embedding" values, in the order readJsonl yields them.
    """
    return [record["embedding"] for record in readJsonl(infile)]


def generate_suffixs(interval=8000, end=112000, epoch=15625, start=16000):
    """
    Build the "<epoch_idx>_<update>" suffix strings to evaluate.

    Update counts run from ``start`` up to ``end`` (inclusive) in steps of
    ``interval``; each is paired with ``update // epoch + 1``.

    Input:
        interval: step between consecutive update counts; must be positive.
        end: last update count to include (inclusive).
        epoch: number of updates per epoch, used to derive the epoch index.
        start: first update count (defaults to 16000, the value that used to
            be hard-coded, so existing callers see identical output).
    Output:
        list of strings such as "2_16000"; empty if start > end.
    Raises:
        ValueError: if interval is not positive (the loop would never end).
    """
    if interval <= 0:
        raise ValueError("interval must be positive, got {}".format(interval))
    results = []
    current = start
    while current <= end:
        epoch_idx = current // epoch + 1
        results.append("{}_{}".format(epoch_idx, current))
        current += interval
    return results


if __name__ == "__main__":
    import subprocess

    # Expects exactly one argument: the model prefix used both by the dump
    # script and in the embedding/output paths below.
    if len(sys.argv) < 2:
        sys.exit("usage: {} <prefix>".format(sys.argv[0]))
    prefix = sys.argv[1]
    split = "test"
    suffixs = generate_suffixs()
    outputfile = "{}_dis.out".format(prefix)
    # Root directory the embedding files are read from; depends only on
    # prefix, so it is computed once rather than per language/checkpoint.
    embfile_prefix = "/home/tiger/xgiga_dumpEmb_{}/embedding/".format(prefix)
    with open(outputfile, 'a') as fout:
        for suffix in suffixs:
            # "1_8000" is never re-dumped (presumably its embeddings are
            # produced elsewhere -- confirm); the default generate_suffixs()
            # does not emit it anyway.
            if suffix != "1_8000":
                # Pass arguments as a list so the user-supplied prefix is not
                # parsed by a shell, and abort if the dump fails so stale
                # embedding files from an earlier run are never scored.
                subprocess.run(
                    ["bash", "dumpEmbV1proj.sh", prefix, suffix, split],
                    check=True,
                )
            logging_str = suffix
            for tlg in ["zh", "fr"]:
                en_emb_file = os.path.join(
                    embfile_prefix, "{split}_x_en_{tlg}2en_{suffix}/document_embedding.jsonl".format(
                        split=split, tlg=tlg, suffix=suffix
                    )
                )
                tlg_emb_file = os.path.join(
                    embfile_prefix, "{split}_x_en2{tlg}_{tlg}_{suffix}/document_embedding.jsonl".format(
                        split=split, tlg=tlg, suffix=suffix
                    )
                )

                en_emb = loadEmb(en_emb_file)
                tlg_emb = loadEmb(tlg_emb_file)
                print("number of en emb: ", len(en_emb))
                print("number of {} emb: ".format(tlg), len(tlg_emb))
                # CKA compares two representations of the same set of
                # examples, so both files must hold the same number of rows
                # (presumably in matching document order -- not checked here).
                if len(en_emb) != len(tlg_emb):
                    raise ValueError(
                        "embedding count mismatch for {} at {}: en={} {}={}".format(
                            tlg, suffix, len(en_emb), tlg, len(tlg_emb)
                        )
                    )
                score = linear_CKA(np.array(en_emb), np.array(tlg_emb))
                logging_str += "\t{:.2f}".format(score)
            fout.write(logging_str + "\n")
            # Persist each checkpoint's row immediately; the loop is long and
            # a later failure should not discard results already computed.
            fout.flush()
