import tarfile
import argparse 

from allennlp.models.archival import archive_model

def main():
    """Create the ``.tar.gz`` archives for every model of the chosen attack target.

    The attack target (``"stop_token"`` or ``"first_token"``) is read from the
    command line via ``argument_parsing`` and used as the folder-name prefix.
    For that prefix, two groups of folders under ``./bert_models`` are handled:

    * ``<target>_attack_model``, ``<target>_regularized_model`` and
      ``<target>_baseline_model`` are passed to allennlp's ``archive_model``
      with ``archive_path`` set to ``<folder>/<folder>.tar.gz``.
    * ``<target>_combined_model`` and ``<target>_simple_combined_model`` get a
      gzip tarball ``<folder>/<folder>.tar.gz`` that contains only the
      folder's ``config.json`` (stored under the name ``config.json``).

    If the parsed target is neither of the two known values (for example when
    the flag is omitted and the parser returns ``None``), nothing is archived.
    """
    args = argument_parsing()

    target = args.attack_target
    if target not in ("stop_token", "first_token"):
        # Unknown or missing target: keep the historical behaviour of doing
        # nothing rather than guessing a folder prefix.
        return

    # Both targets share the same layout; only the folder-name prefix differs,
    # so the folder names are derived instead of being spelled out per target.
    for kind in ("attack", "regularized", "baseline"):
        folder = "{}_{}_model".format(target, kind)
        serialization_dir = "./bert_models/{}".format(folder)
        archive_model(
            serialization_dir,
            archive_path="./bert_models/{}/{}.tar.gz".format(folder, folder),
        )

    # The combined variants are archived by hand: the tarball holds nothing
    # but the folder's config.json.
    for kind in ("combined", "simple_combined"):
        folder = "{}_{}_model".format(target, kind)
        archive_file = "./bert_models/{}/{}.tar.gz".format(folder, folder)
        config_file = "./bert_models/{}/config.json".format(folder)
        with tarfile.open(archive_file, "w:gz") as archive:
            archive.add(config_file, arcname="config.json")


def argument_parsing(argv=None):
    """Parse the command-line options of the archiving script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        The ``argparse.Namespace`` whose ``attack_target`` attribute is either
        ``'first_token'`` or ``'stop_token'``.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when ``--attack_target``
            is missing or is not one of the allowed choices.
    """
    parser = argparse.ArgumentParser(description='One argparser')
    parser.add_argument(
        '--attack_target',
        type=str,
        choices=['first_token', 'stop_token'],
        # Without a target the script would silently archive nothing and still
        # exit successfully, so the flag is mandatory.
        required=True,
        help='Which kind of model to archive',
    )
    return parser.parse_args(argv)

# Run the archiving workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()