from tqdm import tqdm

from utils import bert, common, eda


def cache_bert_aug(
    dataset_path,
    aug_output_path,
    aug_type,
):
    """Cache BERT embeddings of training sentences and their augmentations.

    Reads ``{dataset_path}/train.txt``, takes the second tab-separated
    field of each line as the sentence, generates augmented variants with
    the method selected by ``aug_type`` (called with ``n_aug=4`` and
    ``alpha=0.1``), embeds the original sentence and every variant, and
    pickles the resulting dict to ``aug_output_path``.

    The saved dict maps each original sentence to a list of embeddings:
    the original sentence's embedding first, followed by one embedding per
    augmented sentence. If the same sentence occurs on several lines, the
    last occurrence overwrites the earlier entry.

    Args:
        dataset_path: Directory containing ``train.txt``.
        aug_output_path: Destination handed to ``common.save_pickle``.
        aug_type: Name handed to ``eda.get_aug_method`` to select the
            augmentation method.

    Raises:
        IndexError: If a non-blank line, once stripped, has no second
            tab-separated field.
    """
    input_path = f"{dataset_path}/train.txt"

    bert_as_dict = {}
    aug_method = eda.get_aug_method(aug_type)

    # Read everything up front so the file handle is released before the
    # long embedding loop, and so tqdm knows the total line count.
    with open(input_path, 'r') as input_file:
        lines = input_file.readlines()

    for line in tqdm(lines):
        stripped = line.strip()
        if not stripped:
            # Blank lines (e.g. a trailing newline at end of file) carry
            # no sentence; skip them instead of failing on parts[1].
            continue

        parts = stripped.split('\t')
        sentence = parts[1]
        # The result is concatenated with a list below, so the
        # augmentation method is expected to return a list of strings.
        augmented_sentences = aug_method(
            sentence,
            n_aug=4,
            alpha=0.1,
        )

        # Index 0 is always the embedding of the un-augmented sentence.
        bert_as_dict[sentence] = [
            bert.get_embedding(s)
            for s in [sentence] + augmented_sentences
        ]

    common.save_pickle(aug_output_path, bert_as_dict)
    print(f"augmented bert as dict len {len(bert_as_dict)} saved in {aug_output_path}")


def cache_bert_noaug(
    dataset_path,
    output_path,
):
    """Cache BERT embeddings of every train and test sentence.

    Reads ``{dataset_path}/train.txt`` and ``{dataset_path}/test.txt`` in
    that order, takes the second tab-separated field of each line as the
    sentence, embeds it, and pickles a dict mapping sentence to embedding
    to ``output_path``. A sentence that occurs more than once (within or
    across the two files) is embedded again and overwrites its earlier
    entry.

    Args:
        dataset_path: Directory containing ``train.txt`` and ``test.txt``.
        output_path: Destination handed to ``common.save_pickle``.

    Raises:
        IndexError: If a non-blank line, once stripped, has no second
            tab-separated field.
    """
    input_paths = [
        f"{dataset_path}/train.txt",
        f"{dataset_path}/test.txt",
    ]

    bert_as_dict = {}
    for input_path in input_paths:
        # Read everything up front so the file handle is released before
        # the embedding loop, and so tqdm knows the total line count.
        with open(input_path, 'r') as input_file:
            lines = input_file.readlines()

        for line in tqdm(lines):
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline at end of file)
                # carry no sentence; skip them instead of failing on
                # parts[1].
                continue

            parts = stripped.split('\t')
            sentence = parts[1]
            bert_as_dict[sentence] = bert.get_embedding(sentence)

    common.save_pickle(output_path, bert_as_dict)
    print(f"bert as dict len {len(bert_as_dict)} saved in {output_path}")


if __name__ == "__main__":

    # Directory prefix used for every generated pickle path.
    output_folder = "berts"

    # Datasets processed in this run.
    # Currently not enabled: "subj", "sst2", "cov", "trec", "clinc"
    dataset_names = ("snips", "fewrel", "huff")

    # Augmentation types cached for each dataset.
    # Currently not enabled: 'delete', 'synonym', 'insert', 'swap'
    aug_types = ("backtrans",)

    for dataset_name in dataset_names:
        dataset_path = f"full-datasets/{dataset_name}"

        # The un-augmented cache is disabled for now; re-enable with:
        # cache_bert_noaug(
        #     dataset_path = dataset_path,
        #     output_path = f"{output_folder}/{dataset_name}_noaug.pkl",
        # )

        for aug_type in aug_types:
            aug_output_path = f"{output_folder}/{dataset_name}_trainaug_{aug_type}.pkl"
            cache_bert_aug(
                dataset_path=dataset_path,
                aug_output_path=aug_output_path,
                aug_type=aug_type,
            )