import argparse
import shutil
import timeit
from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd
from sklearn.metrics import classification_report, roc_auc_score

import hedal
import hedal.ml as hml
from hedal.load_heaan import PARAMETER_PRESET
from hedal.ml import preprocessing


def load_data(
    train_path: Path, val_path: Path, test_path: Path
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Read the train/val/test CSV splits and separate features from labels.

    Every CSV must contain a ``label`` column.  That column is removed from the
    frame and cast to ``int``; all remaining columns form the feature matrix.

    Args:
        train_path: CSV file for the training split.
        val_path: CSV file for the validation split.
        test_path: CSV file for the test split.

    Returns:
        ``(X_train, X_val, X_test, y_train, y_val, y_test)`` as NumPy arrays.
    """
    # Read all three splits first, in train/val/test order.
    frames = [pd.read_csv(csv_path) for csv_path in (train_path, val_path, test_path)]
    # `pop` strips the label column in place, so `frames` keeps only features afterwards.
    labels = [frame.pop("label").to_numpy().astype(int) for frame in frames]
    features = [frame.to_numpy() for frame in frames]

    n_train, n_val, n_test = (matrix.shape[0] for matrix in features)
    print(f"Number of samples - train: {n_train}, val: {n_val}, test: {n_test}")
    return (*features, *labels)


def _get_directory_size(path: Path):
    directory_size = sum(f.stat().st_size for f in path.glob("**/*") if f.is_file())
    return directory_size


def run_snips_experiment(
    generate_keys: bool = False,
    heaan_preset: str = "FGb",
    model_type: str = "multinomial",
    with_piheaan: bool = False,
    encrypted_train: bool = False,
    batch_size: int = 0,
    num_epoch: int = 5,
):
    """Train and evaluate a hedal logistic-regression model on the SNIPS dataset.

    The run proceeds in these steps:

    1. Load (or first generate) the key packs and build a hedal context.
    2. Preprocess the training CSV and save it under ``logit_reg_snips/train``.
    3. Fit the model, time it and save it.
    4. Encrypt the validation and test splits, predict on them and decrypt the
       scores.
    5. Report accuracy, a classification report and macro AUC.
    6. Write a summary log to ``./logs/`` and delete ``logit_reg_snips``.

    Args:
        generate_keys: If True, generate a new secret/public key pack before
            loading the keys from ``keys-{preset}-{type}/``.
        heaan_preset: Preset name passed to ``HedalParameter.from_preset``.
        model_type: ``"ovr"`` for ``LogisticRegressionOVR`` or
            ``"multinomial"`` for ``LogisticRegressionMultinomial``.
        with_piheaan: Use the ``"pi"`` HEaaN type instead of ``"real"``. Also
            print the context history at the end.
        encrypted_train: Encrypt the training matrix before fitting. The
            trained model is then decrypted before it is saved.
        batch_size: Forwarded to ``model.fit``.
        num_epoch: Forwarded to ``model.fit``.

    Raises:
        ValueError: If ``model_type`` is neither ``"ovr"`` nor ``"multinomial"``.
    """
    # Validate before any expensive key/context work. Previously an unknown
    # type failed much later with an UnboundLocalError on `Z_train`.
    if model_type not in ("ovr", "multinomial"):
        raise ValueError(f"model_type must be 'ovr' or 'multinomial', got {model_type!r}")

    # setup
    print("Setup: generate keys, build context, ...")
    logit_path = Path("logit_reg_snips")
    train_path = logit_path / "train"
    model_path = logit_path / "model"
    val_path = logit_path / "val"
    test_path = logit_path / "test"
    val_report_path = logit_path / "val_report"
    test_report_path = logit_path / "test_report"

    heaan_type = "pi" if with_piheaan else "real"
    key_dir = Path(f"keys-{heaan_preset}-{heaan_type}")
    secret_key_path = key_dir / "secret_keypack"
    public_key_path = key_dir / "public_keypack"

    params = hedal.HedalParameter.from_preset(heaan_preset, heaan_type)
    context = hedal.Context(params, make_bootstrappable=True)

    if generate_keys:
        hedal.KeyPack.generate_secret_key(context, secret_key_path)
        hedal.KeyPack.generate_public_key(context, secret_key_path, public_key_path)
    context.load_sk(secret_key_path)
    context.load_pk(public_key_path)
    context.generate_homevaluator()

    classes = [0, 1, 2, 3, 4, 5, 6]

    # load data
    print("Load data ...")
    data_dir = Path("examples") / "snips_preprocessed"
    X_train, X_val, X_test, y_train, y_val, y_test = load_data(
        data_dir / "train.csv", data_dir / "val.csv", data_dir / "test.csv"
    )

    # preprocessing
    print("Preprocessing ...")
    scale_type = "normal"
    if model_type == "ovr":
        Z_train = preprocessing.fit_train_set_ovr(
            context, X_train, y_train, path=train_path, scale_type=scale_type, multiclass=True,
        )
    else:
        Z_train = preprocessing.fit_train_set_multinomial(
            context, X_train, y_train, path=train_path, scale_type=scale_type
        )
    train_data_level = None
    if encrypted_train:
        Z_train.encrypt()
        train_data_level = 3
    Z_train.save(target_level=train_data_level)
    print(f"Train data size: {_get_directory_size(train_path) / 1024 / 1024 / 1024 :.4f} GB")

    # set model
    print("Set model ...")
    if model_type == "ovr":
        model = hml.LogisticRegressionOVR(context, path=model_path, scale_type=scale_type)
    else:
        model = hml.LogisticRegressionMultinomial(context, classes=classes, path=model_path, scale_type=scale_type)

    # train: the timing covers the bootstrap of the train set, fit, optional
    # model decryption and saving the model.
    print("Train ...")
    st = timeit.default_timer()
    print("train set bootstrap")
    Z_train.bootstrap()
    print("train set bootstrap finished")
    lr = 2.0
    model.fit(Z_train, learning_rate=lr, num_epoch=num_epoch, batch_size=batch_size)
    if encrypted_train:
        model.decrypt()
    model.save()
    train_time = timeit.default_timer() - st
    print(f"Train takes {train_time:.2f} seconds for total {num_epoch} epochs.")
    if Z_train.encrypted:
        Z_train.decrypt()

    # validation & test
    auc, acc, pred_time, cls_report = {}, {}, {}, {}
    # split key -> (display name, features, labels, matrix path, report path)
    splits = {
        "val": ("Validation", X_val, y_val, val_path, val_report_path),
        "test": ("Test", X_test, y_test, test_path, test_report_path),
    }
    # Target level for the encrypted evaluation data; the multinomial model
    # uses one more level than OVR.
    test_data_level = 4 if model_type == "multinomial" else 3
    for key, (key_str, X, y, z_path, report_path) in splits.items():
        print(f"{key_str} ...")
        st = timeit.default_timer()
        Z_test = hedal.HedalMatrix.from_ndarray(context, X, path=z_path)
        Z_test.encrypt()
        Z_test.level_down(target_level=test_data_level)

        report = model.predict(Z_test, path=report_path, last_activation=False)
        report.decrypt()
        # The activation is skipped in the encrypted predict and applied here
        # on the decrypted scores instead.
        if model_type == "ovr":
            report_df = report.to_dataframe(apply_sigmoid=True, normalize=False)
        else:
            report_df = report.to_dataframe(apply_softmax=True)
        test_time = timeit.default_timer() - st
        print(f"{key_str} takes {test_time:.2f} seconds.")
        pred_time[key] = test_time

        # Per-label binary accuracy (OVR only): each label's score is
        # thresholded at 0.5 and compared against a +1/-1 target.
        if model_type == "ovr":
            for label in classes:
                report_pred = np.where(report_df[str(label)].to_numpy() > 0.5, 1, -1)
                y_label = np.where(y == int(label), 1, -1)
                class_acc = (report_pred == y_label).sum() / len(y_label)
                print(f"Label {label} {key} acc: {class_acc * 100: .2f}%")

        # Overall accuracy: the prediction is the column with the highest score.
        # NOTE(review): assumes report_df's column order matches `classes`
        # (0..6), so the argmax index equals the label. Confirm.
        report_arr = report_df.to_numpy()
        predictions = report_arr.argmax(axis=1)
        acc_ = (predictions == y).sum() / len(y)
        print(f"{key_str} accuracy: {acc_ * 100: .2f}%")
        acc[key] = acc_

        # classification report (computed once, printed and stored)
        cls_report[key] = classification_report(y, predictions, digits=4)
        print(cls_report[key])

        # Macro AUC over the classes, scored against THIS split's labels `y`.
        if model_type == "ovr":
            macro_auc = sum(roc_auc_score(y == int(c), report_df[str(c)]) for c in classes) / len(classes)
        else:
            macro_auc = roc_auc_score(y, report_arr, multi_class="ovr")
        print(f"AUC: {macro_auc: .4f}")
        auc[key] = macro_auc

    # log data
    result_log = f"""
        heaan preset: {heaan_preset},
        with piheaan: {with_piheaan},
        model type: {model_type},
        num epoch: {num_epoch},
        batch size: {batch_size},
        learning rate: {lr},
        encrypted train: {encrypted_train},
        train_time: {train_time:.2f} seconds,
        val_time: {pred_time['val']:.2f} seconds,
        val acc: {acc['val']: .4f},
        val auc: {auc['val']: .4f},
        val classification report:
        {cls_report['val']},
        test_time: {pred_time['test']:.2f} seconds,
        test acc: {acc['test']: .4f},
        test auc: {auc['test']: .4f},
        test classification report:
        {cls_report['test']},
    """
    log_path = Path(
        f"./logs/{heaan_preset}_{with_piheaan}_{model_type}_{num_epoch}_{batch_size}_{lr}_{encrypted_train}.log"
    )
    # Without this, open() raises FileNotFoundError when ./logs is missing.
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with open(log_path, "w") as f:
        f.write(result_log)

    # Remove the intermediate matrices, model and reports written under logit_path.
    shutil.rmtree(logit_path)

    if with_piheaan:
        print(context.history(verbose=False))


if __name__ == "__main__":
    # Command-line entry point: parse options and run the SNIPS experiment once.
    parser = argparse.ArgumentParser(description="Run SNIPS experiment")
    parser.add_argument("--generate_keys", action="store_true", help="generate new secret/public key packs first")
    parser.add_argument("--heaan_preset", type=str, default="FGb", choices=PARAMETER_PRESET)
    parser.add_argument("--with_piheaan", action="store_true", help="use the 'pi' HEaaN type instead of 'real'")
    parser.add_argument("--model_type", type=str, default="multinomial", choices=["ovr", "multinomial"])
    parser.add_argument("--batch_size", type=int, default=0, help="forwarded to model.fit")
    parser.add_argument("--num_epoch", type=int, default=5, help="number of training epochs")
    parser.add_argument("--encrypted_train", action="store_true", help="encrypt the training data before fitting")
    args = parser.parse_args()
    run_snips_experiment(
        generate_keys=args.generate_keys,
        heaan_preset=args.heaan_preset,
        model_type=args.model_type,
        with_piheaan=args.with_piheaan,
        encrypted_train=args.encrypted_train,
        batch_size=args.batch_size,
        num_epoch=args.num_epoch,
    )
