import argparse
import json
import os
import sys
from pathlib import Path
from typing import List, Optional

import datasets
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from peft import PeftModel


def load_ft_checkpoint(
    model_name,
    checkpoint_dir: Path,
    checkpoint: str,
    cache_dir: Path,
    hf_token: Optional[str],
):
    """Load ``model_name`` and, optionally, merge a PEFT adapter checkpoint into it.

    The tokenizer gets an extra ``[PAD]`` special token and right-side padding,
    and the base model's token embeddings are resized to the new vocabulary size
    before any adapter is applied.

    Args:
        model_name: Hugging Face model id (or local path) of the base model.
        checkpoint_dir: Directory containing the adapter checkpoint folders.
        checkpoint: Either a decimal number or a folder name. ``"0"`` returns
            the (resized) base model with no adapter. Any other number ``N``
            loads ``checkpoint_dir / f"checkpoint-{N}"``. A non-numeric value
            is used verbatim as the folder name inside ``checkpoint_dir``.
        cache_dir: Currently NOT forwarded to ``from_pretrained`` (the
            ``cache_dir=`` arguments are disabled); kept so existing callers
            keep working.
        hf_token: Hugging Face access token passed to ``from_pretrained``;
            may be ``None``.

    Returns:
        ``(model, tokenizer)``. When an adapter is loaded, ``model`` is the base
        model with the adapter weights merged in via ``merge_and_unload()``.
    """
    # Tokenizers are not placed on devices, so no device_map here.
    # NOTE(review): cache_dir is intentionally not passed (was commented out in
    # the original); re-enable it here and below if a custom cache is needed.
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    tokenizer.padding_side = "right"

    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=hf_token,
        device_map="auto",
    )
    # Grow the embedding matrix to cover the added [PAD] token. This must happen
    # before the adapter is loaded. NOTE(review): presumably the adapters were
    # trained with the same extra token, so the shapes must match — confirm.
    base_model.resize_token_embeddings(len(tokenizer))

    # isdecimal() accepts only strings int() can parse as plain digits;
    # isnumeric() would also accept e.g. "²" or "½" and make int() raise.
    if checkpoint.isdecimal():
        checkpoint_num = int(checkpoint)
        if checkpoint_num == 0:
            # Checkpoint 0 means "evaluate the untouched base model".
            return base_model, tokenizer
        checkpoint_name = f"checkpoint-{checkpoint_num}"
    else:
        checkpoint_name = checkpoint

    model = PeftModel.from_pretrained(base_model, checkpoint_dir / checkpoint_name)
    # Fold the adapter weights into the base model so it behaves as a plain model.
    model = model.merge_and_unload()

    return model, tokenizer


def get_generations(
    pipeline, input_file: Path, output_file: Optional[Path] = None
) -> Optional[dict]:
    """Run ``pipeline`` on every prompt in a JSONL file and optionally save results.

    Each non-blank line of ``input_file`` must be a JSON object with a
    ``"prompt"`` key. The pipeline's output for that prompt is stored on the
    record under ``"generation"``. When ``output_file`` is given, each augmented
    record is written to it as one JSON line, as soon as it is produced.

    Args:
        pipeline: Callable applied to each prompt (e.g. a transformers
            text-generation pipeline). The name is kept for backward
            compatibility; it shadows ``transformers.pipeline`` only locally.
        input_file: Path of the input JSONL file, read as UTF-8.
        output_file: Optional path of the JSONL file to create/overwrite.

    Returns:
        The last processed record (with ``"generation"`` added), or ``None`` if
        the input contains no records.
    """
    # Read the whole input (and close it) before opening the output, so the
    # input is fully consumed even if both paths happen to be the same file.
    with open(input_file, encoding="utf-8") as fin:
        lines = fin.readlines()

    last_record = None
    fout = open(output_file, "w", encoding="utf-8") if output_file is not None else None
    try:
        for i, line in enumerate(lines):
            # Tolerate blank lines (e.g. stray empty lines at end of file).
            if not line.strip():
                continue
            print(i)  # progress indicator: index of the input line

            record = json.loads(line)
            record["generation"] = pipeline(record["prompt"])

            if fout is not None:
                fout.write(json.dumps(record) + "\n")
            last_record = record
    finally:
        # Close even on failure so already-written generations are flushed.
        if fout is not None:
            fout.close()

    return last_record


# Base model onto which the checkpoint adapters are loaded.
model_name = "meta-llama/Llama-2-7b-chat-hf"

# Optional Hugging Face settings. .get() avoids a KeyError at import time when
# they are unset: hf_cache is not actually forwarded to from_pretrained
# (load_ft_checkpoint leaves cache_dir unused), and a None token lets the
# Hugging Face libraries fall back to their own token lookup (e.g. a token
# saved by `huggingface-cli login`).
hf_cache = os.environ.get("HF_HOME")
hf_token = os.environ.get("HF_TOKEN")


def run_evaluate_checkpoint(
    model_dir: Path,
    checkpoint: str,
    input_file: Path,
):
    """Generate completions for every prompt in ``input_file`` with one checkpoint.

    The model is built by ``load_ft_checkpoint`` from the module-level
    ``model_name`` and the checkpoint selected by ``checkpoint`` inside
    ``model_dir``. Decoding is greedy (``do_sample=False``) with at most 2048
    new tokens per prompt. Results go to
    ``<model_dir>/eval-<checkpoint>/generations_<checkpoint>_<input stem>.jsonl``.

    Args:
        model_dir: Directory holding the adapter checkpoints; the evaluation
            output directory is created inside it.
        checkpoint: Checkpoint selector forwarded to ``load_ft_checkpoint``.
        input_file: JSONL file whose records each carry a ``"prompt"`` field.
    """
    # Create the output directory *before* the expensive model load, so a bad
    # path fails fast. parents=True also handles a model_dir that does not
    # exist yet (which can happen when only the base model is evaluated).
    output_file = (
        model_dir
        / f"eval-{checkpoint}"
        / f"generations_{checkpoint}_{input_file.stem}.jsonl"
    )
    output_file.parent.mkdir(parents=True, exist_ok=True)

    # TODO: can the model be loaded onto the device directly here instead of being loaded to CPU
    # first and then being put on the GPU?
    # NOTE(review): load_ft_checkpoint already passes device_map="auto" to
    # from_pretrained; check whether this TODO is still relevant.
    model, tokenizer = load_ft_checkpoint(
        model_name,
        model_dir,
        checkpoint,
        hf_cache,
        hf_token,
    )

    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=2048,
        do_sample=False,
    )

    get_generations(gen, input_file, output_file)


def main() -> int:
    """Parse command-line arguments and evaluate a single checkpoint.

    All three flags are required. Omitting one now yields a clear argparse
    usage error instead of a crash deep inside model loading.

    Returns:
        Process exit status (0 on success).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Generate completions for a JSONL prompt file with a "
            "fine-tuned checkpoint."
        )
    )
    parser.add_argument(
        "--model_dir",
        type=Path,
        required=True,
        help="Directory containing the checkpoint folders; outputs are written under it.",
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help=(
            'Checkpoint step number (e.g. "500" -> checkpoint-500; "0" = base model) '
            "or the name of a checkpoint folder inside --model_dir."
        ),
    )
    parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help='JSONL file whose records each have a "prompt" field.',
    )

    args = parser.parse_args()
    run_evaluate_checkpoint(
        model_dir=args.model_dir,
        checkpoint=args.checkpoint,
        input_file=args.input,
    )
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status
    # (raising SystemExit is exactly what sys.exit does).
    raise SystemExit(main())
