import argparse
import json
import os

import torch
import transformers

from debias_eval.benchmark.stereoset import StereoSetRunner
from debias_eval.model import models
from debias_eval.util import generate_experiment_id, _is_generative, _is_self_debias


# Directory containing this script; the default persistent directory is its parent.
thisdir = os.path.dirname(os.path.realpath(__file__))

# Command-line interface. The parser is built at module level and only parsed
# under the ``__main__`` guard.
parser = argparse.ArgumentParser(description="Runs StereoSet benchmark.")
parser.add_argument(
    "--persistent_dir",
    action="store",
    type=str,
    default=os.path.realpath(os.path.join(thisdir, "..")),
    help="Directory where all persistent data will be stored.",
)
parser.add_argument(
    "--intrasentence_model",
    action="store",
    type=str,
    default="BertForMaskedLM",
    choices=[
        # argparse does not validate a default against ``choices``, so the
        # default value has to be listed explicitly; otherwise passing
        # ``--intrasentence_model BertForMaskedLM`` on the command line is
        # rejected even though omitting the flag yields that same value.
        "BertForMaskedLM",
        "SentenceDebiasBertForMaskedLM",
        "SentenceDebiasAlbertForMaskedLM",
        "SentenceDebiasRobertaForMaskedLM",
        "SentenceDebiasGPT2LMHeadModel",
        "INLPBertForMaskedLM",
        "INLPAlbertForMaskedLM",
        "INLPRobertaForMaskedLM",
        "INLPGPT2LMHeadModel",
        "CDABertForMaskedLM",
        "CDAAlbertForMaskedLM",
        "CDARobertaForMaskedLM",
        "CDAGPT2LMHeadModel",
        "DropoutBertForMaskedLM",
        "DropoutAlbertForMaskedLM",
        "DropoutRobertaForMaskedLM",
        "DropoutGPT2LMHeadModel",
        "SelfDebiasGPT2LMHeadModel",
        "SelfDebiasBertForMaskedLM",
        "SelfDebiasAlbertForMaskedLM",
        "SelfDebiasRobertaForMaskedLM",
    ],
    help="Model to evaluate (e.g., SentenceDebiasBertForMaskedLM).",
)
parser.add_argument(
    "--model_name_or_path",
    action="store",
    type=str,
    default="bert-base-uncased",
    choices=["bert-base-uncased", "albert-base-v2", "roberta-base", "gpt2"],
    help="HuggingFace model name or path (e.g., bert-base-uncased). Checkpoint from which a "
    "model is instantiated.",
)
# The next three options are optional (default ``None``); which of them a
# given model needs is decided by the model class, not by this parser.
parser.add_argument(
    "--bias_direction",
    action="store",
    type=str,
    help="Path to the file containing the pre-computed bias direction for SentenceDebias.",
)
parser.add_argument(
    "--projection_matrix",
    action="store",
    type=str,
    help="Path to the file containing the pre-computed projection matrix for INLP.",
)
parser.add_argument(
    "--load_path",
    action="store",
    type=str,
    help="Path to saved ContextDebias, CDA, or Dropout model checkpoint.",
)
parser.add_argument(
    "--score_type",
    action="store",
    type=str,
    default="likelihood",
    choices=["likelihood", "effect-size"],
    help="The StereoSet scoring mechanism to use.",
)
parser.add_argument(
    "--split",
    action="store",
    type=str,
    default="dev",
    choices=["dev", "test"],
    help="The StereoSet split to use.",
)
parser.add_argument(
    "--batch_size",
    action="store",
    type=int,
    default=1,
    help="The batch size to use during StereoSet intrasentence evaluation.",
)
# No default: ``bias_type`` is ``None`` unless given on the command line.
parser.add_argument(
    "--bias_type",
    action="store",
    type=str,
    choices=["gender", "religion", "race"],
    help="The type of bias to mitigate.",
)


if __name__ == "__main__":
    # Script entry point: parse the CLI, build the requested model, run the
    # StereoSet runner, and write its results to a JSON file named after the
    # experiment id.
    args = parser.parse_args()

    experiment_id = generate_experiment_id(
        name="stereoset",
        intrasentence_model=args.intrasentence_model,
        model_name_or_path=args.model_name_or_path,
        score_type=args.score_type,
        data_split=args.split,
        bias_type=args.bias_type,
    )

    # Echo the effective configuration, one option per line.
    print("Running StereoSet:")
    for option in (
        "persistent_dir",
        "intrasentence_model",
        "model_name_or_path",
        "bias_direction",
        "projection_matrix",
        "load_path",
        "score_type",
        "split",
        "batch_size",
        "bias_type",
    ):
        print(f" - {option}: {getattr(args, option)}")

    # Optional pre-computed tensors (SentenceDebias bias direction, INLP
    # projection matrix) are loaded only when a path was supplied and are
    # forwarded to the model constructor under the option's own name.
    model_kwargs = {}
    for option in ("bias_direction", "projection_matrix"):
        tensor_path = getattr(args, option)
        if tensor_path is not None:
            model_kwargs[option] = torch.load(tensor_path)

    # A saved checkpoint, when given, takes precedence over the base model name.
    model_cls = getattr(models, args.intrasentence_model)
    checkpoint = args.load_path or args.model_name_or_path
    intrasentence_model = model_cls(checkpoint, **model_kwargs)

    # For self-debias models, ``eval()`` is called on the ``_model`` attribute
    # rather than on the object returned by the constructor.
    eval_target = (
        intrasentence_model._model
        if _is_self_debias(args.intrasentence_model)
        else intrasentence_model
    )
    eval_target.eval()

    # The tokenizer always comes from the base model name, never from load_path.
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name_or_path)

    # The runner receives "race-color" in place of "race" (the self-debiasing
    # name for that bias type); every other value passes through unchanged.
    bias_type = "race-color" if args.bias_type == "race" else args.bias_type

    input_file = f"{args.persistent_dir}/data/stereoset/{args.split}.json"
    runner = StereoSetRunner(
        intrasentence_model=intrasentence_model,
        tokenizer=tokenizer,
        input_file=input_file,
        model_name_or_path=args.model_name_or_path,
        score_type=args.score_type,
        batch_size=args.batch_size,
        is_generative=_is_generative(args.intrasentence_model),
        is_self_debias=_is_self_debias(args.intrasentence_model),
        bias_type=bias_type,
    )
    results = runner()

    # Persist the results as <persistent_dir>/results/stereoset/<experiment_id>.json.
    results_dir = f"{args.persistent_dir}/results/stereoset"
    os.makedirs(results_dir, exist_ok=True)
    output_path = f"{results_dir}/{experiment_id}.json"
    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)
