import numpy as np
from statistics import mean, stdev
from tqdm import tqdm

from utils import augmentation, common, processing, svm


# Per-dataset sample count that the script entry point passes as ``n_1``
# to ``generate_label_to_n`` (the count given to odd-numbered labels).
dataset_to_n1 = dict(
    huff=700,
    clinc=100,
    trec=1000,
    subj=1000,
    sst2=1000,
    cov=1000,
    snips=1800,
    fewrel=500,
)

def generate_label_to_n(n_0, n_1, num_labels=200):
    """Return a dict mapping each label index to its per-label sample count.

    Labels ``0 .. num_labels - 1`` alternate between the two counts:
    even-numbered labels get ``n_0`` and odd-numbered labels get ``n_1``.

    Args:
        n_0: Count assigned to even-numbered labels.
        n_1: Count assigned to odd-numbered labels.
        num_labels: Total number of label indices to generate.

    Returns:
        A dict with exactly ``num_labels`` entries, keyed in ascending order.
        An odd ``num_labels`` no longer produces a spurious extra label
        ``num_labels``; a non-positive ``num_labels`` yields an empty dict.
    """
    return {
        label: (n_1 if label % 2 else n_0)
        for label in range(num_labels)
    }


def _leading_count(aug_type):
    """Parse the integer before the first hyphen of ``aug_type`` (``"4-uniform"`` -> 4)."""
    return int(aug_type.split('-')[0])


def _augment_embeddings(aug_type, label_to_embedding_np):
    """Apply the embedding-level augmentation selected by ``aug_type``, if any.

    Markers are tested in a fixed order and the first one contained in
    ``aug_type`` wins. When no marker matches (for example ``'no'`` or the
    text-level augmentation names) the input is returned unchanged.

    NOTE: ``'extrapolate-none'`` and ``'extrapolate-half'`` contain none of
    the markers below, so they currently fall through without augmentation.
    """
    if "-extrapolate" in aug_type:
        return augmentation.extrapolate_augmentation(
            label_to_embedding_np,
            num_extrapolations = _leading_count(aug_type),
        )
    if "-interpolate" in aug_type:
        return augmentation.interpolate_augmentation(label_to_embedding_np, _leading_count(aug_type))
    if "-linear-delta" in aug_type:
        return augmentation.linear_delta_augmentation(label_to_embedding_np, _leading_count(aug_type))
    if "-within-extra" in aug_type:
        return augmentation.within_extra_augmentation(label_to_embedding_np, _leading_count(aug_type))
    if "-gaussian" in aug_type:
        return augmentation.gaussian_augmentation(label_to_embedding_np, _leading_count(aug_type))
    if "-uniform" in aug_type:
        return augmentation.uniform_augmentation(label_to_embedding_np, _leading_count(aug_type))
    return label_to_embedding_np


def run_experiments(
    aug_type,
    input_folder,
    dataset_name,
    noaug_bert,
    test_x,
    test_y,
    label_to_n,
    n_seeds = 10,
    ):
    """Train and evaluate once per seed, then print the mean/stdev accuracy.

    For every seed in ``range(n_seeds)`` the training split
    ``{input_folder}/{dataset_name}/train_s{seed}.txt`` is turned into
    per-label embeddings, optionally augmented in embedding space, and
    passed to ``svm.train_eval_lr`` together with the fixed test set.

    Args:
        aug_type: Augmentation selector. The text-level names ('synonym',
            'insert', 'delete', 'swap', 'backtrans') switch the embedding
            lookup to ``berts/{dataset_name}_trainaug_{aug_type}.pkl``;
            ``"<k>-<method>"`` names trigger an embedding-level augmentation
            (see ``_augment_embeddings``); anything else means no augmentation.
        input_folder: Directory containing one sub-folder per dataset.
        dataset_name: Dataset identifier used to build file paths.
        noaug_bert: Embedding lookup used when ``aug_type`` is not a
            text-level augmentation.
        test_x: Test features forwarded to ``svm.train_eval_lr``.
        test_y: Test labels forwarded to ``svm.train_eval_lr``.
        label_to_n: Mapping forwarded to the training-set loader; entries
            ``0`` and ``1`` are also shown in the printed summary.
        n_seeds: Number of training splits (seeds) to evaluate.

    Returns:
        None. The summary is printed to stdout.

    Raises:
        statistics.StatisticsError: If ``n_seeds`` is 0, since there is no
            accuracy to average.
    """
    # The embedding lookup depends only on the arguments, so resolve it once
    # instead of re-reading the pickle from disk for every seed.
    if aug_type in ['synonym', 'insert', 'delete', 'swap', 'backtrans']:
        bert_as_dict = common.load_pickle(f"berts/{dataset_name}_trainaug_{aug_type}.pkl")
    else:
        bert_as_dict = noaug_bert

    # "huff" and "clinc" use a single classifier seed per split; every other
    # dataset uses three.
    classifier_seeds = 1 if dataset_name in ["huff", "clinc"] else 3

    acc_list = []
    for seed in tqdm(range(n_seeds)):

        label_to_embedding_np = processing.get_label_to_embedding_list_unbalanced(
            txt_path = f"{input_folder}/{dataset_name}/train_s{seed}.txt",
            bert_as_dict = bert_as_dict,
            label_to_n = label_to_n,
            aug_type = aug_type,
        )
        label_to_embedding_np = _augment_embeddings(aug_type, label_to_embedding_np)

        train_x, train_y = processing.transform_to_xy(label_to_embedding_np)

        acc = svm.train_eval_lr(
            train_x, train_y,
            test_x, test_y,
            num_seeds = classifier_seeds,
        )
        acc_list.append(acc)

    acc_mean = mean(acc_list)
    # statistics.stdev needs at least two samples; report zero spread for a
    # single seed instead of crashing after all the work has been done.
    acc_stdev = stdev(acc_list) if len(acc_list) > 1 else 0.0

    print(f"{dataset_name} at {label_to_n[0]} to {label_to_n[1]} with aug {aug_type}: mean={acc_mean*100:.1f}, stdev={acc_stdev*100:.1f}")


if __name__ == "__main__":

    # Root folder holding one sub-directory per dataset.
    input_folder = "full-datasets"

    # Datasets included in this run.
    # Currently disabled: 'trec', 'clinc', 'sst2', 'subj', 'cov'
    active_datasets = (
        'snips',
        'huff',
        'fewrel',
    )

    # n_0 values swept for each dataset (n_1 comes from dataset_to_n1).
    # Currently disabled: 10, 30, 40, 60, 80, 100, 200, 300, 400, 500
    n_0_values = (20, 50)

    # Augmentation settings evaluated for every (dataset, n_0) pair.
    # Currently disabled:
    #   'no', 'synonym', 'insert', 'swap', 'delete',
    #   '2-gaussian', '4-gaussian', '5-gaussian', '8-gaussian', '10-gaussian',
    #   '16-gaussian', '20-gaussian', '32-gaussian', '64-gaussian',
    #   '3-uniform', '5-uniform', '10-uniform',
    #   '3-linear-delta', '5-linear-delta', '10-linear-delta', '20-linear-delta',
    #   '3-within-extra', '5-within-extra', '10-within-extra', '20-within-extra',
    #   '3-interpolate', '5-interpolate', '10-interpolate', '20-interpolate',
    #   '2-extrapolate', '4-extrapolate', '8-extrapolate', '16-extrapolate',
    #   '32-extrapolate', '999-extrapolate',
    #   'extrapolate-none', 'extrapolate-half'
    active_aug_types = (
        '2-uniform',
        '4-uniform',
        '8-uniform',
        '16-uniform',
        '32-uniform',
        '64-uniform',
    )

    for dataset_name in active_datasets:

        # The un-augmented embeddings and the test set are loaded once per
        # dataset and shared by every configuration below.
        noaug_bert = common.load_pickle(f"berts/{dataset_name}_noaug.pkl")
        test_x, test_y = processing.get_x_y_test(
            txt_path = f"{input_folder}/{dataset_name}/test.txt",
            bert_as_dict = noaug_bert,
        )

        n_1 = dataset_to_n1[dataset_name]

        for n_0 in n_0_values:
            label_to_n = generate_label_to_n(n_0=n_0, n_1=n_1)

            for aug_type in active_aug_types:
                run_experiments(
                    aug_type = aug_type,
                    input_folder = input_folder,
                    dataset_name = dataset_name,
                    noaug_bert = noaug_bert,
                    test_x = test_x,
                    test_y = test_y,
                    label_to_n = label_to_n,
                    n_seeds = 5,
                )