from transformers import AutoTokenizer
import transformers
import torch

# Local path to the Hugging Face checkpoint used for both tokenizer and model.
model = "../models/Llama-2-7b-chat-hf"

# Text placed inside the <<SYS>> block of the first [INST] turn.
system_prompt = ''

# Token budget for each assistant reply. max_new_tokens (unlike max_length)
# does not count the prompt, so the budget stays constant as the re-sent
# conversation history grows turn after turn.
# NOTE(review): the history itself is never truncated; long chats will
# eventually exceed the model's context window.
max_new_tokens = 512

tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,  # reuse the tokenizer loaded above rather than loading a second copy
    torch_dtype=torch.float16,
    device_map="auto",
)


def generate_reply(prompt):
    """Sample one assistant reply for the full chat *prompt*.

    Returns the newly generated text with surrounding whitespace stripped.
    ``return_full_text=False`` makes the pipeline return only the continuation,
    so the prompt does not have to be sliced off by character length.
    """
    # NOTE(review): the prompt starts with a literal "<s>"; depending on the
    # tokenizer settings a BOS token may also be added automatically — confirm.
    sequences = pipeline(
        prompt,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        return_full_text=False,
    )
    return sequences[0]['generated_text'].strip()


def read_user_message():
    """Read one line from stdin after a '>>' prompt.

    Returns None on EOF (Ctrl-D) or Ctrl-C so the chat loop can end cleanly
    instead of dying with a traceback.
    """
    try:
        return input('>>')
    except (EOFError, KeyboardInterrupt):
        print()
        return None


prompt = ''
response = None
while True:
    user_msg = read_user_message()
    if user_msg is None:
        break
    if response is None:
        # First turn: open the conversation with the system prompt.
        prompt = f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_msg} [/INST]'
    else:
        # Later turns: close the previous exchange with </s> and open a new
        # <s>[INST] ... [/INST] segment (no space between <s> and [INST]).
        prompt += f' {response} </s><s>[INST] {user_msg} [/INST]'
    response = generate_reply(prompt)
    print(response)
