import argparse
import io
import pstats
import shutil
import timeit
from cProfile import Profile
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd

import hedal
from hedal.load_heaan import PARAMETER_PRESET

# Every job below writes its pstats text report into this directory. It is
# created when the module is imported and silently reused if it already exists.
profile_dir = Path("profile_results")
profile_dir.mkdir(exist_ok=True, mode=0o775)


def generate_hedal_frame(
    context,
    frame_path: Path,
    num_rows: int = 65536,
    num_nn: int = 2,
    num_nv: int = 2,
    num_cn: int = 2,
    num_cell_cls: int = 2,
    missing_frac: float = 0.2,
    seed: Optional[int] = None,
) -> hedal.HedalFrame:
    """Build a random HedalFrame used as benchmark input.

    Columns are created in this order:

    * ``nn_<i>`` (``num_nn`` columns): standard-normal floats, no missing values.
    * ``nv_<i>`` (``num_nv`` columns): standard-normal floats in which a
      ``missing_frac`` share of the rows is set to ``NaN``.
    * ``cn_<i>`` (``num_cn`` columns): labels ``"1"`` .. ``str(num_cell_cls)``
      stored as strings, in which a ``missing_frac`` share of the rows is set
      to ``""``. These names are passed as the last argument of
      ``HedalFrame.from_dataframe`` (the profiling jobs treat them as cell
      columns).

    Args:
        context: hedal context, forwarded to ``HedalFrame.from_dataframe``.
        frame_path: location forwarded to ``HedalFrame.from_dataframe``.
        num_rows: number of rows in every column.
        num_nn: number of ``nn_`` columns.
        num_nv: number of ``nv_`` columns.
        num_cn: number of ``cn_`` columns.
        num_cell_cls: number of distinct labels in each ``cn_`` column.
        missing_frac: fraction of rows blanked out in ``nv_`` and ``cn_``
            columns (rounded by ``pandas.Series.sample``).
        seed: if given, data is drawn from a private ``RandomState(seed)`` so
            repeated benchmark runs see identical inputs. If ``None``, NumPy's
            global generator is used exactly as before, so ``np.random.seed``
            still controls the output.

    Returns:
        The frame returned by ``HedalFrame.from_dataframe``.
    """
    # With seed=None, draw from the global legacy generator (and let pandas do
    # the same via random_state=None) so the sequence of draws is unchanged.
    if seed is None:
        rng, sample_state = np.random, None
    else:
        rng = sample_state = np.random.RandomState(seed)

    nn_list = [f"nn_{idx}" for idx in range(num_nn)]
    nv_list = [f"nv_{idx}" for idx in range(num_nv)]
    cn_list = [f"cn_{idx}" for idx in range(num_cn)]
    header = nn_list + nv_list + cn_list

    columns = {}
    # Iterate in header order so the random draws happen in the same sequence.
    for col_name in header:
        if col_name in cn_list:
            series = pd.Series(rng.randint(1, num_cell_cls + 1, size=(num_rows,))).astype(str)
            series[series.sample(frac=missing_frac, random_state=sample_state).index] = ""
        else:
            series = pd.Series(rng.randn(num_rows))
            if col_name in nv_list:
                series[series.sample(frac=missing_frac, random_state=sample_state).index] = np.nan
        columns[col_name] = series

    # Build the frame in one step rather than inserting columns one by one
    # into an empty DataFrame.
    df = pd.DataFrame(columns, columns=header)
    return hedal.HedalFrame.from_dataframe(context, df, frame_path, cn_list)


def enc_dec(context, frame_path: Path):
    """Profile encryption and decryption of one column of each kind.

    A default-sized frame is generated. Then ``nn_0`` (numeric, no NaN),
    ``nv_0`` (numeric with NaN) and ``cn_0`` (string labels) are each
    encrypted and then decrypted under cProfile. One pstats report per
    operation is written to ``profile_dir``, and both wall-clock times are
    printed.

    The timer brackets the profiled region, so the printed times include
    cProfile overhead.
    """
    print("[ENCRYPTION_DECRYPTION]")
    frame = generate_hedal_frame(context, frame_path)

    # column -> (heading printed before timing, tag used in report file names)
    column_specs = {
        "nn_0": ("NumColumn, without vbit", "num_column_wo_vbit"),
        "nv_0": ("NumColumn, with vbit", "num_column_w_vbit"),
        "cn_0": ("CellColumn", "cell_column"),
    }

    for col_name, (heading, col_type) in column_specs.items():
        print(heading)
        elapsed = {}
        for op in ("encrypt", "decrypt"):
            profiler = Profile()
            start = timeit.default_timer()
            profiler.enable()
            # Explicit branch instead of a helper/getattr, so the profiled
            # region contains no extra call frames of our own.
            if op == "encrypt":
                frame[col_name].encrypt()
            else:
                frame[col_name].decrypt()
            profiler.disable()
            elapsed[op] = timeit.default_timer() - start

            buffer = io.StringIO()
            pstats.Stats(profiler, stream=buffer).sort_stats("tottime").print_stats()
            report_path = profile_dir / f"profile_{op}_{col_type}.txt"
            with open(report_path, "w") as report:
                report.write(buffer.getvalue())

        print(f"encrypt: {elapsed['encrypt']: .4f}, decrypt: {elapsed['decrypt']: .4f}sec")


def add(context, frame_path: Path):
    """Profile addition of two encrypted numeric columns.

    A frame with ``nn_0``, ``nn_1`` (no NaN) and ``nv_0``, ``nv_1`` (with NaN)
    columns is generated and encrypted. For each of the four
    (no_vbit|vbit) + (no_vbit|vbit) pairings, ``hf[left] + hf[right]`` is
    timed under cProfile. The time is printed, and the pstats report is
    written to ``profile_dir/profile_add_<left>_+_<right>.txt``.

    The timer brackets the profiled region, so printed times include cProfile
    overhead.
    """
    print("[ADDITION]")
    hf = generate_hedal_frame(context, frame_path, num_nn=2, num_nv=2, num_cn=0)
    hf.encrypt()

    combs = [(False, False), (False, True), (True, False), (True, True)]
    for has_vbit1, has_vbit2 in combs:
        s1 = "vbit" if has_vbit1 else "no_vbit"
        s2 = "vbit" if has_vbit2 else "no_vbit"
        col1_name = "nv_0" if has_vbit1 else "nn_0"
        col2_name = "nv_1" if has_vbit2 else "nn_1"

        pr = Profile()
        st = timeit.default_timer()
        pr.enable()
        result = hf[col1_name] + hf[col2_name]
        pr.disable()
        et = timeit.default_timer()
        # Release the result outside the measured region. Rebinding a name that
        # still held the previous iteration's result would otherwise free it
        # inside the next measurement and inflate that timing.
        del result
        add_time = et - st
        print(f"NumColumn, {s1} + {s2}: {add_time:.4f}")
        s = io.StringIO()
        ps = pstats.Stats(pr, stream=s).sort_stats("tottime")
        ps.print_stats()

        with open(profile_dir / f"profile_add_{s1}_+_{s2}.txt", "w") as f:
            f.write(s.getvalue())


def sub(context, frame_path: Path):
    """Profile subtraction of two encrypted numeric columns.

    A frame with ``nn_0``, ``nn_1`` (no NaN) and ``nv_0``, ``nv_1`` (with NaN)
    columns is generated and encrypted. For each of the four
    (no_vbit|vbit) - (no_vbit|vbit) pairings, ``hf[left] - hf[right]`` is
    timed under cProfile. The time is printed, and the pstats report is
    written to ``profile_dir/profile_sub_<left>_-_<right>.txt``.

    The timer brackets the profiled region, so printed times include cProfile
    overhead.
    """
    print("[SUBTRACTION]")
    hf = generate_hedal_frame(context, frame_path, num_nn=2, num_nv=2, num_cn=0)
    hf.encrypt()

    combs = [(False, False), (False, True), (True, False), (True, True)]
    for has_vbit1, has_vbit2 in combs:
        s1 = "vbit" if has_vbit1 else "no_vbit"
        s2 = "vbit" if has_vbit2 else "no_vbit"
        col1_name = "nv_0" if has_vbit1 else "nn_0"
        col2_name = "nv_1" if has_vbit2 else "nn_1"

        pr = Profile()
        st = timeit.default_timer()
        pr.enable()
        result = hf[col1_name] - hf[col2_name]
        pr.disable()
        et = timeit.default_timer()
        # Release the result outside the measured region. Rebinding a name that
        # still held the previous iteration's result would otherwise free it
        # inside the next measurement and inflate that timing.
        del result
        sub_time = et - st
        print(f"NumColumn, {s1} - {s2}: {sub_time:.4f}")
        s = io.StringIO()
        ps = pstats.Stats(pr, stream=s).sort_stats("tottime")
        ps.print_stats()

        with open(profile_dir / f"profile_sub_{s1}_-_{s2}.txt", "w") as f:
            f.write(s.getvalue())


def mul(context, frame_path: Path):
    """Profile multiplication of two encrypted numeric columns.

    A frame with ``nn_0``, ``nn_1`` (no NaN) and ``nv_0``, ``nv_1`` (with NaN)
    columns is generated and encrypted. For each of the four
    (no_vbit|vbit) x (no_vbit|vbit) pairings, ``hf[left] * hf[right]`` is
    timed under cProfile. The time is printed, and the pstats report is
    written to ``profile_dir/profile_mul_<left>_x_<right>.txt``.

    The timer brackets the profiled region, so printed times include cProfile
    overhead.
    """
    print("[MULTIPLICATION]")
    hf = generate_hedal_frame(context, frame_path, num_nn=2, num_nv=2, num_cn=0)
    hf.encrypt()

    combs = [(False, False), (False, True), (True, False), (True, True)]
    for has_vbit1, has_vbit2 in combs:
        s1 = "vbit" if has_vbit1 else "no_vbit"
        s2 = "vbit" if has_vbit2 else "no_vbit"
        col1_name = "nv_0" if has_vbit1 else "nn_0"
        col2_name = "nv_1" if has_vbit2 else "nn_1"

        pr = Profile()
        st = timeit.default_timer()
        pr.enable()
        result = hf[col1_name] * hf[col2_name]
        pr.disable()
        et = timeit.default_timer()
        # Release the result outside the measured region. Rebinding a name that
        # still held the previous iteration's result would otherwise free it
        # inside the next measurement and inflate that timing.
        del result
        mul_time = et - st
        print(f"NumColumn, {s1} x {s2}: {mul_time:.4f}")
        s = io.StringIO()
        ps = pstats.Stats(pr, stream=s).sort_stats("tottime")
        ps.print_stats()

        # Use "x" rather than "*" in the filename: "*" is not a legal filename
        # character on Windows and is a glob metacharacter on POSIX shells.
        with open(profile_dir / f"profile_mul_{s1}_x_{s2}.txt", "w") as f:
            f.write(s.getvalue())


def filter(context, frame_path: Path):
    """Profile filtering an encrypted numeric column by a cell-column mask.

    A frame with ``nn_0`` (no NaN), ``nv_0`` (with NaN) and ``cn_0`` (string
    labels) is generated and encrypted. For ``nn_0`` and ``nv_0``,
    ``hf[col].filter(hf["cn_0"]["1"])`` is timed under cProfile. The time is
    printed, and the pstats report is written to
    ``profile_dir/profile_filter_<no_vbit|vbit>.txt``.

    NOTE: this name shadows the builtin ``filter`` at module level. It is kept
    because the ``__main__`` section dispatches to it by this name.

    The timer brackets the profiled region, so printed times include cProfile
    overhead.
    """
    print("[FILTER]")
    hf = generate_hedal_frame(context, frame_path, num_nn=1, num_nv=1, num_cn=1)
    hf.encrypt()

    for has_vbit in [False, True]:
        col_type = "vbit" if has_vbit else "no_vbit"
        col_name = "nv_0" if has_vbit else "nn_0"

        pr = Profile()
        st = timeit.default_timer()
        pr.enable()
        # presumably hf["cn_0"]["1"] is the row mask "cn_0 == '1'" — confirm
        # against hedal's CellColumn indexing semantics.
        result = hf[col_name].filter(hf["cn_0"]["1"])
        pr.disable()
        et = timeit.default_timer()
        # Release the result outside the measured region so the previous
        # iteration's result is not freed inside the next measurement.
        del result
        print(f"NumColumn {col_type}: {et - st:.4f}")
        s = io.StringIO()
        ps = pstats.Stats(pr, stream=s).sort_stats("tottime")
        ps.print_stats()

        with open(profile_dir / f"profile_filter_{col_type}.txt", "w") as f:
            f.write(s.getvalue())


if __name__ == "__main__":

    # Every profiling job, in the order "--job all" runs them. Each job gets a
    # scratch sub-directory of `path` named after it.
    jobs = {
        "enc_dec": enc_dec,
        "add": add,
        "sub": sub,
        "mul": mul,
        "filter": filter,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--job",
        type=str,
        default="all",
        help="job name, default is 'all', which check performances for all operations",
        choices=["all", *jobs],
    )
    parser.add_argument(
        "--preset", type=str, default="FVa", choices=PARAMETER_PRESET,
    )
    parser.add_argument("--keygen", action="store_true")
    parser.add_argument("--bootstrap", action="store_true")
    args = parser.parse_args()

    # setup: scratch directory for the frames the jobs create
    path = Path("profiling")
    path.mkdir(mode=0o775, exist_ok=True)

    try:
        # params, context. The --bootstrap flag was previously parsed but
        # ignored; it now controls whether the context is made bootstrappable.
        params = hedal.HedalParameter.from_preset(args.preset)
        context = hedal.Context(params, make_bootstrappable=args.bootstrap)

        # generate keys (only with --keygen), then load them
        sk_path = "./keys/secret_keypack"
        pk_path = "./keys/public_keypack"
        if args.keygen:
            hedal.KeyPack.generate_secret_key(context, sk_path)
            hedal.KeyPack.generate_public_key(context, sk_path, pk_path)
        context.load_pk(pk_path)
        context.load_sk(sk_path)
        context.generate_homevaluator()

        selected = list(jobs) if args.job == "all" else [args.job]
        for job_name in selected:
            jobs[job_name](context, path / job_name)
    finally:
        # remove test files, even when a job fails
        shutil.rmtree(path)
