import argparse
import json
import os
from typing import Tuple

import torch
from transformers import AutoTokenizer

from sentsim.config import ModelArguments
from sentsim.eval.ists import load_instances, preprocess_instances, inference
from sentsim.eval.ists import save_infered_instances

from sentsim.models.models import create_contrastive_learning

# Command-line interface of this script. ArgumentDefaultsHelpFormatter makes
# ``--help`` print each option's default value next to its description.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
    "--data-dir",
    type=str,
    # Every input path is built by joining onto this directory, so a missing
    # value must be rejected at parse time instead of failing in os.path.join.
    required=True,
    help="Dirpath for semeval16 task2 directory (test_goldStandard)",
)
parser.add_argument(
    "--source",
    type=str,
    default="images",
    choices=["images", "headlines"],
    help="Source for the dataset",
)
parser.add_argument(
    "--ckpt-dir",
    type=str,
    required=True,
    help=(
        "Dirpath of a specific checkpoint (must contain \"model_args.json\"; "
        "the inference output is written here)"
    ),
)
parser.add_argument(
    "--ckpt-path",
    type=str,
    help=(
        "Filepath for checkpoint weights (file named \"pytorch_model.bin\"); "
        "when omitted no checkpoint is loaded and the output file gets the "
        "\".untrained\" suffix"
    ),
)


def create_filepaths(data_dir: str, source: str) -> Tuple[str, ...]:
    """Build the four test-input file paths for one dataset source.

    Args:
        data_dir: Directory that holds the ``STSint.testinput.*`` files.
        source: Dataset source name embedded in each file name.

    Returns:
        A 4-tuple of paths, in this order: sentence-1 file, sentence-2 file,
        sentence-1 chunk file, sentence-2 chunk file.
    """
    # Order matters: the tuple is unpacked positionally by the caller.
    suffixes = ("sent1.txt", "sent2.txt", "sent1.chunk.txt", "sent2.chunk.txt")
    stem = f"STSint.testinput.{source}"
    return tuple(os.path.join(data_dir, f"{stem}.{suffix}") for suffix in suffixes)


def main():
    """Run inference on one dataset source and save the result.

    Steps:
        1. Parse the command-line arguments.
        2. Load the test instances for ``--source`` from ``--data-dir``.
        3. Rebuild the model configuration from ``model_args.json`` inside
           ``--ckpt-dir`` and load the matching tokenizer.
        4. Create the model and, if ``--ckpt-path`` was given, load its
           state dict from that file.
        5. Run inference and write the output to
           ``<ckpt-dir>/<source>.wa`` (checkpoint loaded) or
           ``<ckpt-dir>/<source>.wa.untrained`` (no checkpoint loaded).
    """
    args = parser.parse_args()

    instances = load_instances(*create_filepaths(args.data_dir, args.source))

    # JSON is UTF-8 by specification; do not depend on the locale's default.
    model_args_path = os.path.join(args.ckpt_dir, "model_args.json")
    with open(model_args_path, encoding="utf-8") as f:
        model_args = ModelArguments(**json.load(f))
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, use_fast=False
    )
    prep_instances = preprocess_instances(tokenizer, instances)

    module = create_contrastive_learning(model_args)
    # Decide once whether a checkpoint is used, so the loading step and the
    # output file name can never disagree.
    use_checkpoint = args.ckpt_path is not None
    if use_checkpoint:
        # map_location="cpu" lets a checkpoint saved from a GPU be read on a
        # host without CUDA; load_state_dict then copies the values into the
        # module's existing parameters.
        state_dict = torch.load(args.ckpt_path, map_location="cpu")
        module.load_state_dict(state_dict)
    infered_instances = inference(module.model, prep_instances)

    outfile = f"{args.source}.wa" if use_checkpoint else f"{args.source}.wa.untrained"
    save_infered_instances(infered_instances, os.path.join(args.ckpt_dir, outfile))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
