import os
import time
import json
import argparse
from itertools import product


# Supported task names (the GLUE task names) mapped to the string emitted as
# "metric_for_best_model" in every generated run configuration. Keys also
# define which values --tasks accepts.
METRIC_FOR_BEST_MODEL = dict(
    mnli="eval_accuracy",
    qqp="eval_f1",
    qnli="eval_accuracy",
    sst2="eval_accuracy",
    cola="eval_matthews_correlation",
    stsb="eval_spearmanr",
    mrpc="eval_f1",
    rte="eval_accuracy",
    wnli="eval_accuracy",
)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run in arrays.")
    parser.add_argument("--seeds", type=int, required=True, nargs="+")
    parser.add_argument("--tasks", type=str, required=True, nargs="+")
    parser.add_argument(
        "--default_adapter_non_linearities", type=str, nargs="+",
        default=['relu']
    )
    args = parser.parse_args()

    for task in args.tasks:
        assert task in METRIC_FOR_BEST_MODEL, f"Task {task} is not supported."

    # Arguments.
    iter_args = product(
        args.seeds, args.tasks, args.default_adapter_non_linearities, range(13)
    )

    for seed, task, non_linearity, last_n in iter_args:
        print(json.dumps({
            "seed": seed,
            "task_name": task,
            "model_name_or_path": "bert-base-uncased",
            "do_train": True,
            "do_eval": True,
            "do_predict": True,
            "max_seq_length": 128,
            "per_device_train_batch_size": 32,
            "per_device_eval_batch_size": 32,
            "learning_rate": 1e-4,
            "num_train_epochs": 10,
            "output_dir": "output/rte_%t",
            "overwrite_output_dir": True,
            "train_adapter": True,
            "load_best_model_at_end": True,
            "save_strategy": "epoch",
            "save_total_limit": 2,
            "evaluation_strategy": "epoch",
            "logging_steps": 20,
            "metric_for_best_model": METRIC_FOR_BEST_MODEL[task],
            "adapter_last_layers": last_n,
            "low_resources": 2048,
            "default_adapter_non_linearity": non_linearity
        }))
