import numpy as np
import argparse
from ioFn import readJsonl, dumpJsonl
from typing import List

def json2array(infos: List):
    """
    Collect the "embedding" entry of every record into one array.

    Input:
        infos: List of json records, each holding an "embedding" entry
    Return:
        array: np.ndarray built from the embeddings, in input order
    """
    return np.array([record["embedding"] for record in infos])

def array2json(embeddings: np.array, input_infos: List):
    """
    Write row i of the embeddings back into record i as a plain list.

    Input:
        embeddings: np.ndarray, indexed by record position
        input_infos: List of json records; modified in place
    Return:
        infos: the same list object, with every record's "embedding"
            entry overwritten
    """
    for row_idx, record in enumerate(input_infos):
        record["embedding"] = embeddings[row_idx].tolist()
    return input_infos

def tokenJson2array(infos: List):
    """
    Flatten per-document token embeddings into one array plus an index.

    Input:
        infos: List of json records; each has an "embeddings" list whose
            items carry an "embedding" and a "token" entry
    Return:
        array: np.ndarray with one row per token, documents in input order
        tokens: list of (doc_id, token_id, token), aligned with the rows
    """
    # One pass over every token; each item pairs the vector with its index.
    pairs = [
        (entry["embedding"], (doc_idx, tok_idx, entry["token"]))
        for doc_idx, doc in enumerate(infos)
        for tok_idx, entry in enumerate(doc["embeddings"])
    ]
    vectors = [vector for vector, _ in pairs]
    tokens = [position for _, position in pairs]
    return np.array(vectors), tokens

def tokenArray2json(embeddings: np.array, input_infos: List):
    """
    Write flattened token embeddings back into their per-document records.

    Input:
        embeddings: np.ndarray, one row per token, ordered document by
            document and token by token
        input_infos: List of json records with an "embeddings" list each;
            modified in place
    Return:
        infos: the same list object, with every token's "embedding"
            entry overwritten
    """
    offset = 0  # row of `embeddings` belonging to the next token
    for doc in input_infos:
        for token_info in doc["embeddings"]:
            token_info["embedding"] = embeddings[offset].tolist()
            offset += 1
    return input_infos

def z_norm(inputs, batch_dim=0):
    """
    Standardize inputs along one axis.

    Input:
        inputs: np.ndarray
        batch_dim: axis over which mean and variance are computed
    Return:
        array: (inputs - mean) / sqrt(var + 1e-9), same shape as inputs
    """
    centered = inputs - inputs.mean(batch_dim, keepdims=True)
    # The small constant keeps the divisor non-zero for constant slices.
    scale = np.sqrt(inputs.var(batch_dim, keepdims=True) + 1e-9)
    return centered / scale

def batchNormEmb(embeddings: np.ndarray, batch_size: int):
    """
    Z-normalize embeddings independently within consecutive row chunks.

    Rows are split into chunks of batch_size rows (the last chunk may be
    shorter) and each chunk is passed through z_norm along axis 0. A chunk
    that holds a single row therefore normalizes to zeros.

    Input:
        embeddings: np.ndarray, one embedding per row
        batch_size: number of rows per chunk, must be positive
    Return:
        array: np.ndarray with the normalized chunks concatenated in order;
            an empty float copy of the input when it has no rows
    Raises:
        ValueError: if batch_size is not positive
    """
    # A non-positive step would never advance through the rows.
    if batch_size <= 0:
        raise ValueError(
            "batch_size must be a positive integer, got {}".format(batch_size)
        )
    # np.concatenate rejects an empty list, so answer empty input directly.
    if len(embeddings) == 0:
        return np.array(embeddings, dtype=float)
    normalized_chunks = [
        z_norm(embeddings[start:start + batch_size], 0)
        for start in range(0, len(embeddings), batch_size)
    ]
    return np.concatenate(normalized_chunks, axis=0)

def batchNormMode(args):
    """
    Mode: chunk-wise z-normalization of the embeddings in a jsonl file.

    Input:
        args: parsed CLI arguments; args.i is read with readJsonl and the
            updated records are written to args.o with dumpJsonl
    """
    records = readJsonl(args.i)
    # Chunk size is fixed at 8 rows for this mode.
    normalized = batchNormEmb(json2array(records), batch_size=8)
    dumpJsonl(array2json(normalized, records), args.o)

def lir(embed, c, r = 8):
    """
    Subtract from each embedding its norm-scaled projection onto c[:, :r].

    Input:
        embed: np.ndarray of embeddings, one per row
        c: np.ndarray whose columns are the components; only the first r
            columns are used
        r: number of leading columns of c to use
    Return:
        array: embed - (embed @ c[:, :r] / ||embed||) @ c[:, :r]^T, where
            ||embed|| is the L2 norm taken over the last axis
    """
    print("embed.shape: ", embed.shape)
    basis = c[:, :r]
    # Per-row L2 norm; an all-zero row divides by zero and produces nan.
    row_norms = np.sqrt(np.sum(embed ** 2, axis=-1, keepdims=True))
    coefficients = np.matmul(embed, basis) / row_norms
    return embed - np.matmul(coefficients, np.transpose(basis))

def extractComponent(embeddings, r):
    """
    Return the leading left singular vectors of the transposed embeddings.

    Input:
        embeddings: [N, H], np.ndarray
        r: number of singular vectors to keep
    Return:
        results: np.ndarray, the first r columns of U from the reduced SVD
            of embeddings^T (fewer columns if the SVD yields fewer than r)
    """
    transposed = np.transpose(embeddings)
    left_vectors = np.linalg.svd(transposed, full_matrices=False)[0]
    return left_vectors[:, :r]

def projNormMode(args):
    """
    Mode: remove component projections from the embeddings in a jsonl file.

    Input:
        args: parsed CLI arguments; reads
            args.i: input file, read with readJsonl
            args.component: optional .npy file of components; when not
                given, components are extracted from the input itself
            args.r: number of components to extract and/or use
            args.o: output file, written with dumpJsonl
    """
    records = readJsonl(args.i)
    matrix = json2array(records)
    if not args.component:
        basis = extractComponent(matrix, r=args.r)
    else:
        basis = np.load(args.component)
        print("load component")
        print(basis.sum())
    projected = lir(matrix, basis, r=args.r)
    dumpJsonl(array2json(projected, records), args.o)

def extractComponentMode(args):
    """
    Mode: extract components from a jsonl embedding file and save them.

    Input:
        args: parsed CLI arguments; args.i is read with readJsonl, args.r
            is the number of components kept, and the result is stored at
            args.o with np.save
    """
    matrix = json2array(readJsonl(args.i))
    print(matrix.shape)
    np.save(args.o, extractComponent(matrix, r=args.r))

def projNormTokenMode(args):
    """
    Mode: remove component projections from token-level embeddings.

    Components are extracted from the flattened token embeddings of the
    input itself, removed with lir, and the result is written back into
    the per-document records.

    Input:
        args: parsed CLI arguments; reads
            args.i: input file, read with readJsonl
            args.r: number of components to extract and to remove
            args.o: output file, written with dumpJsonl
    """
    input_infos = readJsonl(args.i)
    # The (doc_id, token_id, token) index is not needed here:
    # tokenArray2json walks input_infos in the same order as the flattening.
    embedding_matrix, _ = tokenJson2array(input_infos)
    component = extractComponent(embedding_matrix, r=args.r)
    # Pass r explicitly; lir's default (8) would otherwise cap the number of
    # removed components whenever args.r is larger than 8.
    normalized_embedding = lir(embedding_matrix, component, r=args.r)
    emb_infos = tokenArray2json(normalized_embedding, input_infos)
    dumpJsonl(emb_infos, args.o)

def centeringMode(args):
    """
    Mode: subtract the mean of a reference embedding array from the input.

    Input:
        args: parsed CLI arguments; reads
            args.i: input embeddings, either a ".jsonl" file (read with
                readJsonl and json2array) or a ".npy" array
            args.component: ".npy" array of reference embeddings; its mean
                over axis 0 is the centroid that gets subtracted
            args.o: output path; a ".npy" suffix stores the array with
                np.save, any other path is written with dumpJsonl as one
                {"embedding": [...]} record per row
    Raises:
        ValueError: if args.component is not given, or args.i ends with
            neither ".jsonl" nor ".npy"
    """
    import os

    if not args.component:
        raise ValueError(
            "centeringMode requires --component (reference embeddings, .npy)"
        )
    if args.i.endswith(".jsonl"):
        embedding_matrix = json2array(readJsonl(args.i))
    elif args.i.endswith(".npy"):
        embedding_matrix = np.load(args.i)
    else:
        # Fail early instead of subtracting from None further down.
        raise ValueError(
            "unsupported input file {!r}: expected .jsonl or .npy".format(args.i)
        )

    references = np.load(args.component)
    centroid = np.mean(references, axis=0, keepdims=True)
    results = embedding_matrix - centroid

    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the output path actually names one.
    outdir = os.path.dirname(args.o)
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    if args.o.endswith(".npy"):
        np.save(args.o, results)
    else:
        emb_infos = [{"embedding": embedding.tolist()} for embedding in results]
        dumpJsonl(emb_infos, args.o)

if __name__ == "__main__":
    # Modes selectable with -m. Dispatching through an explicit table (rather
    # than eval() on the raw argument) means -m can only ever name one of
    # these handlers and cannot execute an arbitrary expression.
    modes = {
        "batchNormMode": batchNormMode,
        "projNormMode": projNormMode,
        "extractComponentMode": extractComponentMode,
        "projNormTokenMode": projNormTokenMode,
        "centeringMode": centeringMode,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument("-i", help="input embedding file (json or npy)")
    parser.add_argument(
        "--component", default=None,
        help="input component file for projection, or reference embedding array for centering, .npy file",
    )
    parser.add_argument("-o", help="output embedding file (json) / np.array")
    # `choices` makes argparse reject an unknown mode with a usage message.
    parser.add_argument(
        "-m", help="mode", type=str, default="batchNormMode",
        choices=sorted(modes),
    )
    parser.add_argument("-r", help="the number of dumped/used component", type=int, default=4)
    args = parser.parse_args()

    modes[args.m](args)

    # python3 projNormMode.py -i /home/tiger/xgiga_dumpEmb_mbart/embedding/cc100_fr_first1w_layer10/sent_embedding.jsonl -o /home/tiger/xgiga_dumpEmb_mbart/embedding/cc100_fr_first1w_layer10/sent_emb_component.npy -m extractComponentMode

    # python3 normalizeEmb.py -i /home/tiger/xgiga_dumpEmb_mbart/embedding/train_x_en2zh_1w_2w/sent_embedding.jsonl -o /home/tiger/xgiga_dumpEmb_mbart/embedding/train_x_en2zh_1w_2w/sent_emb_proj.jsonl -m projNormMode

# python3 normalizeEmb.py -i /home/tiger/xgiga_dumpEmb_encoder/embedding/encoder_dev_en/document_embedding.npy -o /home/tiger/decentralized_Unv1/en/dev.jsonl -m centeringMode --component /home/tiger/xgiga_dumpEmb_encoder/embedding/wiki40b_encoder_en/document_embedding.npy
# python3 normalizeEmb.py -i /home/tiger/xgiga_dumpEmb_encoder/embedding/encoder_dev_zh/document_embedding.npy -o /home/tiger/decentralized_Unv1/zh/dev.jsonl -m centeringMode --component /home/tiger/xgiga_dumpEmb_encoder/embedding/wiki40b_encoder_zh/document_embedding.npy
# python3 normalizeEmb.py -i /home/tiger/xgiga_dumpEmb_encoder/embedding/encoder_dev_fr/document_embedding.npy -o /home/tiger/decentralized_Unv1/fr/dev.jsonl -m centeringMode --component /home/tiger/xgiga_dumpEmb_encoder/embedding/wiki40b_encoder_fr/document_embedding.npy
