from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from utils import cast_to_hidden_state_drop
from gen_utils import generate_with_context

# Checkpoint to load. The commented-out line is a much smaller alternative
# checkpoint that can be swapped in.
# model_id = 'EleutherAI/pythia-70m-deduped'
model_id = 'meta-llama/Llama-2-7b-hf'
text = 'What is 2 + 2 ? I think it is 4 but I am not sure. Can you please help me with that? I had math class in the morning and I am not sure if I remember it correctly.'

# Load the model in half precision and let `device_map='auto'` decide where
# the weights go. Attention is pinned to the eager implementation
# (NOTE(review): presumably required by the hidden-state-drop wrapper; confirm).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map='auto',
    attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

encodings = tokenizer(text, return_tensors="pt")

# Remember where the model's first parameters live *before* wrapping it, so
# the prompt is sent to that device instead of a hard-coded 'cuda'. This keeps
# the script working on CPU-only hosts and with non-default placements.
input_device = model.device

# Build `skip_layers` from the model's actual depth rather than a hard-coded
# 32, so every layer except the last `trailing_layers` is listed. For
# Llama-2-7b (32 layers) this yields layers 0..21, same as before.
trailing_layers = 10
num_layers = model.config.num_hidden_layers
skip_layers = list(range(max(0, num_layers - trailing_layers)))

# Cast the model to its hidden-state-drop variant.
# NOTE(review): the exact meaning of `keep_ratio`, `sort_metric` and
# `sort_descending` is defined in utils.cast_to_hidden_state_drop; verify there.
model = cast_to_hidden_state_drop(
    model,
    keep_ratio=0.98,
    skip_layers=skip_layers,
    sort_metric='norm',
    sort_descending=True,
)

# Generate a short continuation of the prompt and print whatever the helper
# returns.
response = generate_with_context(
    model,
    tokenizer,
    encodings.input_ids.to(input_device),
    temperature=0.1,
    max_new_tokens=10,
    adjust_kv_cache=None,
    plot_final_kv_cache_dims=True,
)
print(response)
