# %%
import json
from tqdm import tqdm
import itertools
import pickle

import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import matplotlib.pyplot as plt

from propane_reference import (
  extract_prompt_from_template,
  build_prompt,
  DocDataset,
  compute_dataset_kl,
)

from transformers import logging

# Keep transformers quiet: only error-level log messages are emitted.
logging.set_verbosity_error()

# Reset matplotlib to its built-in defaults first, then apply the figure
# styling one key at a time.
plt.rcParams.update(plt.rcParamsDefault)
plt.rcParams["text.usetex"] = False
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = "Helvetica"
plt.rcParams["figure.dpi"] = 300
plt.rcParams["font.size"] = 15

# Hugging Face model ids processed by this script, in run order.
models = [
  "lmsys/vicuna-7b-v1.5",
  "mistralai/Mistral-7B-Instruct-v0.2",
  "EleutherAI/pythia-6.9b",
  "microsoft/phi-2",
]


# %%
def get_orig_prompt(model_name: str, prompt_id: int) -> str:
  """Return the stored original prompt text for `prompt_id` under `model_name`.

  Each supported model keeps its prompts in a different JSON file with a
  different layout, so every branch below knows both the file path and the key
  names used by that file.

  Args:
    model_name: One of the four Hugging Face model ids handled below.
    prompt_id: Integer id of the prompt inside that model's prompt file.

  Returns:
    The original prompt string read from the file.

  Raises:
    ValueError: If `model_name` is not one of the supported models, or if the
      vicuna, phi-2 or mistral file has no entry for `prompt_id`.
    IndexError: If the pythia file has no entry for `prompt_id` (this branch
      has always raised IndexError; the type is kept for compatibility).
  """
  # A missing id used to surface as "invalid model", which pointed at the wrong
  # problem; report the id that could not be found instead.
  missing = f"no original prompt with id {prompt_id} for model {model_name}"

  if model_name == "lmsys/vicuna-7b-v1.5":
    with open(
      "../experiments/data/100_hard_vicuna_7b_v1.5/results.json"
    ) as f:
      by_id = json.load(f)

    # This file is a JSON object whose keys are the prompt ids as strings.
    key = str(prompt_id)
    if key not in by_id:
      raise ValueError(missing)
    return by_id[key]["orig_prompt"]

  if model_name == "EleutherAI/pythia-6.9b":
    with open("../experiments/data/pythia/orig_prompts.json") as f:
      records = json.load(f)

    for record in records:
      if record["id"] == prompt_id:
        return record["prompt"]
    raise IndexError(missing)

  if model_name == "microsoft/phi-2":
    # (An earlier version of this branch inspected the pickled dataset
    # coding_hermes_phi_100_dset.pkl from the same directory instead.)
    with open(
      "/mnt/sata_4tb1/home/research/prompt_reconstruction/extra_experiments_for_acl/phi_warm_from_mistral/coding_phi_suggested_mistral_warmstart.json",
      "r",
    ) as f:
      records = json.load(f)

    for record in records:
      if record["id"] == prompt_id:
        return record["orig_prompt"]
    raise ValueError(missing)

  if model_name == "mistralai/Mistral-7B-Instruct-v0.2":  # mistral
    with open(
      "/mnt/sata_4tb1/home/research/prompt_reconstruction/extra_experiments_for_acl/mistral_reconstruction/gpt_suggested.json",
      "r",
    ) as f:
      records = json.load(f)

    for record in records:
      if record["id"] == prompt_id:
        return record["prompt"]
    raise ValueError(missing)

  raise ValueError("invalid model")


# %%
# def get_optim_prompt(model_name: str, prompt_id: int) -> torch.Tensor:
#   if model_name == "lmsys/vicuna-7b-v1.5":
#     return torch.load(
#       f"/mnt/sata_4tb1/home/research/prompt_reconstruction/experiments/data/100_hard_vicuna_7b_v1.5/hard_arbitrary/hard_ids_prompt_{prompt_id}_len_32_docs_100_trial_0.pt"
#     )

#   elif model_name == "EleutherAI/pythia-6.9b":
#     return torch.load(
#       f"/mnt/sata_4tb1/home/research/prompt_reconstruction/experiments/data/pythia/6.9b/hard_ids_prompt_{prompt_id}_len_32_docs_100_trial_0.pt"
#     )

#   elif model_name == "microsoft/phi-2":
#     raise ValueError("not doing phi since we didn't save the pt")
#     # return torch.load(f"")

#   elif model_name == "mistralai/Mistral-7B-Instruct-v0.2":  # mistral
#     return torch.load(
#       f"/mnt/sata_4tb1/home/research/prompt_reconstruction/extra_experiments_for_acl/mistral_reconstruction/results/results_ids_{prompt_id}.pt"
#     )


#   raise ValueError("invalid model")


def get_optim_prompt(model_name: str, prompt_id: int) -> str:
  """Return the last stored optimized prompt for `prompt_id` under `model_name`.

  Two file layouts are in use:

  * vicuna and pythia keep one JSON list for all prompts; the entry whose
    "prompt_id" matches is selected and the "suffix" of the last element of its
    "results" list is returned.
  * phi-2 and mistral keep one JSON file per prompt; the "prompt" of the last
    element of that file's list is returned.

  Args:
    model_name: One of the four Hugging Face model ids handled below.
    prompt_id: Integer id of the prompt.

  Returns:
    The optimized prompt string read from the results file.

  Raises:
    ValueError: If `model_name` is not one of the supported models.
    IndexError: If the shared vicuna/pythia results file has no entry for
      `prompt_id`.
  """
  # One results file shared by all prompts of the model.
  shared_results = {
    "lmsys/vicuna-7b-v1.5": "../experiments/data/100_hard_vicuna_7b_v1.5/hard_arbitrary/hard_results.json",
    "EleutherAI/pythia-6.9b": "../experiments/data/pythia/6.9b/hard_results.json",
  }
  # One results file per prompt.
  per_prompt_results = {
    "microsoft/phi-2": f"/mnt/sata_4tb1/home/research/prompt_reconstruction/extra_experiments_for_acl/phi_warm_from_mistral/results_phi_code_mistral_warm/gamma_0.0/results_{prompt_id}.json",
    "mistralai/Mistral-7B-Instruct-v0.2": f"/mnt/sata_4tb1/home/research/prompt_reconstruction/extra_experiments_for_acl/mistral_reconstruction/results/results_{prompt_id}.json",
  }

  if model_name in shared_results:
    with open(shared_results[model_name]) as f:
      runs = json.load(f)
    for run in runs:
      if run["prompt_id"] == prompt_id:
        return run["results"][-1]["suffix"]
    raise IndexError(
      f"no optimized prompt with id {prompt_id} for model {model_name}"
    )

  if model_name in per_prompt_results:
    # Use a context manager so the file handle is closed right after parsing;
    # the bare json.load(open(...)) form left it open.
    with open(per_prompt_results[model_name], "r") as f:
      steps = json.load(f)
    return steps[-1]["prompt"]

  raise ValueError("invalid model")


# %%
# Run the token positional replacements
#
# For every model and every prompt id, one prompt token at a time (the first
# `max_pos` positions inside the prompt slice) is overwritten with the
# tokenizer's unk token, for both the original and the optimized prompt. The
# first value returned by compute_dataset_kl for the intact-vs-replaced pair is
# appended to res[model_name][position]["orig" | "optim"], giving one entry
# per prompt.
res = {}
tot_prompts = 100
max_pos = 6

for model_name in models:
  res[model_name] = {}
  model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="cuda:0"
  )
  tokenizer = AutoTokenizer.from_pretrained(model_name)

  for pos in range(max_pos):
    res[model_name][pos] = {
      "optim": [],
      "orig": [],
    }

  for prompt_id in tqdm(
    range(tot_prompts), total=tot_prompts, desc=f"model: {model_name}"
  ):
    orig_prompt, orig_prompt_slice = build_prompt(
      model_name, get_orig_prompt(model_name, prompt_id), tokenizer
    )

    # Load the optimized prompt once and build it once. Previously the vicuna
    # case built the prompt twice and threw the first result away.
    optim_text = get_optim_prompt(model_name, prompt_id)
    if model_name == "lmsys/vicuna-7b-v1.5":
      # older experiments had the suffix in there too for vicuna, so the bare
      # prompt is extracted from the stored text before it is rebuilt
      optim_text = extract_prompt_from_template(optim_text, "lmsys/vicuna-7b-v1.5")
    optim_prompt, optim_prompt_slice = build_prompt(
      model_name, optim_text, tokenizer
    )

    # `pos` is kept distinct from `prompt_id`; the original reused `i` for both
    # loops, which made the prompt index easy to misuse inside this loop.
    # NOTE(review): nothing here checks that the prompt slice holds at least
    # `max_pos` tokens — confirm that against build_prompt.
    for pos in range(max_pos):
      replaced_orig = orig_prompt.clone()
      replaced_optim = optim_prompt.clone()
      # Overwrite the token at offset `pos` from the start of the prompt slice
      # (second tensor axis) across the whole first axis.
      replaced_orig[:, pos + orig_prompt_slice.start] = tokenizer.unk_token_id
      replaced_optim[:, pos + optim_prompt_slice.start] = tokenizer.unk_token_id

      dataset_optim = DocDataset(
        model,
        tokenizer,
        optim_prompt,
        replaced_optim,
        n_docs=100,
        doc_len=32,
        gen_batch_size=100,
      )
      # Both prompts have the same layout, so both slice attributes are set to
      # the same slice.
      dataset_optim.orig_prompt_slice = optim_prompt_slice
      dataset_optim.prompt_slice = optim_prompt_slice
      # The second returned value (presumably a std — TODO confirm) is unused.
      kl_optim, std_optim = compute_dataset_kl(model, dataset_optim, batch_size=100)
      res[model_name][pos]["optim"].append(kl_optim)

      dataset_orig = DocDataset(
        model,
        tokenizer,
        orig_prompt,
        replaced_orig,
        n_docs=100,
        doc_len=32,
        gen_batch_size=100,
      )
      dataset_orig.orig_prompt_slice = orig_prompt_slice
      dataset_orig.prompt_slice = orig_prompt_slice
      kl_orig, std_orig = compute_dataset_kl(model, dataset_orig, batch_size=100)
      res[model_name][pos]["orig"].append(kl_orig)

  # Drop the model and release cached GPU memory before loading the next one.
  del model
  torch.cuda.empty_cache()


# %%
# Write the collected results to position_results.json in the current working
# directory. json serializes the integer position keys of `res` as strings.
with open("position_results.json", "w") as results_file:
  json.dump(obj=res, fp=results_file, indent=4)
# %%

# %%

# %%

# %%
