
import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from transformers import pipeline, BitsAndBytesConfig
import argparse
from rank_bm25 import BM25Okapi
# from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import transformers
import json

def _build_parser():
    """Create the command-line parser for this script.

    Every option is optional and typed: --model_name (str), --batch_size,
    --k, --max_step and --cut_off (all int).
    """
    cli = argparse.ArgumentParser(description="Parser for LoRA")
    option_specs = (
        ('--model_name', str, 'mistral-7b'),
        ('--batch_size', int, 2),
        ('--k', int, 1),
        ('--max_step', int, 5000),
        ('--cut_off', int, 4096),
    )
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)
    return cli


parser = _build_parser()
args = parser.parse_args()

# Expose the parsed options as module-level names; --max_step is parsed but not read.
model_name, batch_size, k = args.model_name, args.batch_size, args.k
cutoff_len = args.cut_off
add_eos_token = False


# Short model name -> local snapshot directory or Hugging Face Hub id.
# NOTE: the local paths have no leading '/', so they resolve relative to the
# current working directory.
name2path = {
    'opt-350m': '.cache/huggingface/hub/models--facebook--opt-350m/snapshots/08ab08cc4b72ff5593870b5d527cf4230323703c',
    'opt-125m': '.cache/huggingface/hub/models--facebook--opt-125m/snapshots/27dcfa74d334bc871f3234de431e71c6eeba5dd6',
    "vicuna-7b": "vol3/models/model-llama/vicuna/models/vicuna-7b-v1.5/snapshots/de56c35b1763eaae20f4d60efd64af0a9091ebe5",
    "vicuna-13b": "vol3/models/model-llama/vicuna/models/vicuna-13b-v1.5/snapshots/3deb0106f72a3a433f0c6ea0cb978bdf14bcd3a6",
    "llama-2-7b": "vol3/models/model-llama/llama-main/hf_models/llama-2-7b",
    "llama-2-13b": "vol3/models/model-llama/llama-main/hf_models/llama-2-13b",
    "llama-2-chat-13b": "vol3/models/model-llama/llama-main/hf_models/llama-2-13b-chat",
    "llama-2-chat-7b": "vol3/models/model-llama/llama-main/hf_models/llama-2-7b-chat",
    "mistral-7b": 'mistralai/Mistral-7B-Instruct-v0.2',
}

# Known short names resolve through the table. Any other --model_name is used
# verbatim as a local path or Hub id, so new models do not require editing
# this file (from_pretrained accepts either form).
PATH = name2path.get(model_name, model_name)

# PATH = 'mistralai/Mistral-7B-Instruct-v0.2'

# Optional quantized loading: uncomment ONE `bnb_config` below and the
# `quantization_config=bnb_config` line in from_pretrained.
# NOTE(review): BitsAndBytesConfig offers 4-bit (`load_in_4bit`, `bnb_4bit_*`)
# and 8-bit (`load_in_8bit`) loading only -- there are no `bnb_8bit_*` or
# 16-bit fields. `max_memory` is a from_pretrained argument, e.g.
#     max_memory={0: f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'}
# For plain 16-bit inference no quantization config is needed:
# `torch_dtype=torch.bfloat16` below already loads bf16 weights.
#
# # 4-bit (NF4) quantization inference
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16,
#     bnb_4bit_use_double_quant=True,
# )
#
# # 8-bit quantization inference
# bnb_config = BitsAndBytesConfig(load_in_8bit=True)

# Left padding keeps every prompt's last token adjacent to its generated
# continuation when prompts are batched for generation.
tokenizer = AutoTokenizer.from_pretrained(PATH, padding_side="left")
tokenizer.eos_token = "</s>"
# Pad with EOS. The id is derived from the tokenizer rather than hard-coded as 2,
# so it stays correct for any vocabulary. It is presumably 2 for the listed
# Llama/Mistral/OPT checkpoints anyway -- confirm if adding new models.
tokenizer.pad_token_id = tokenizer.eos_token_id

base_model = AutoModelForCausalLM.from_pretrained(
    PATH,
    # quantization_config=bnb_config,
    # local_files_only=True,
    device_map='auto',
    # NOTE(review): trust_remote_code=True executes Python code shipped with the
    # checkpoint; consider disabling it if the models used need no custom code.
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

# Record the tokenizer's special-token ids on the model config and enable the
# KV cache for generation.
base_model.config.use_cache = True
base_model.config.pad_token_id = tokenizer.pad_token_id
base_model.config.eos_token_id = tokenizer.eos_token_id
base_model.config.bos_token_id = tokenizer.bos_token_id


def split_batch(init_list, batch_size):
    """Split a sequence into consecutive batches of at most ``batch_size`` items.

    Order is preserved; only the last batch may be shorter.

    :param init_list: Sequence supporting ``len`` and slicing (e.g. a list).
    :param batch_size: Maximum number of items per batch; must be >= 1.
    :return: List of lists. Empty when ``init_list`` is empty.
    :raises ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # range() below would reject 0 and silently produce no batches for a
        # negative size (dropping every item), so fail loudly instead.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    return [
        list(init_list[start:start + batch_size])
        for start in range(0, len(init_list), batch_size)
    ]

def get_first_k_tokens(text, k):
    """
    Keep only the leading ``k`` whitespace-delimited words of ``text``.

    The kept words are re-joined with single spaces, so any run of whitespace
    (spaces, tabs, newlines) in the input collapses to one space.

    :param text: The string to truncate.
    :param k: How many words to keep; ordinary list-slice semantics apply.
    :return: The kept words joined by single spaces.
    """
    return " ".join(text.split()[:k])

# Load the two user splits. Each is presumably a list of user records with
# 'user_id' and 'profile' keys -- confirm against the data files.
# JSON text is UTF-8 by specification; the encoding is passed explicitly so
# non-ASCII content decodes the same regardless of the platform's locale.
with open("Private/LoRA-P-Hub/data/tweet/user_base_LLM.json", 'r', encoding='utf-8') as f:
    train_data = json.load(f)

with open("Private/LoRA-P-Hub/data/tweet/user_all_test.json", 'r', encoding='utf-8') as f:
    test_data = json.load(f)

# Profile-generation instructions, one per task. Every prompt ends with the same
# tail whose single `{}` placeholder receives the user's history via str.format.
_HISTORY_TAIL = " User History: {} Answer:"

_TASK_INSTRUCTIONS = dict(
    citation="Write a summary, in English, of the research interests and topics of a researcher who has published the following papers. Only generate the summary, no other text.",
    news_cat="Look at the following past articles this journalist has written and determine the most popular category they write in. Answer in the following form: most popular category: <category>.",
    product_rating="Based on this user's past reviews, what are the most common scores they give for positive and negative reviews? Answer in the following form: most common positive score: <most common positive score>, most common negative score: <most common negative score>.",
    news_headline="Given this author's previous articles, try to describe a template for their headlines. I want to be able to accurately predict the headline gives one of their articles. Be specific about their style and wording, don't tell me anything generic.",
    scholarly_title="Given this author's previous publications, try to describe a template for their titles. I want to be able to accurately predict the title of one of the papers from the abstract. Only generate the template description, nothing else.",
    tweet_paraphrase="Given this person's previous tweets, try to describe a template for their tweets. I want to take a generic sentence and rephrase it to sound like one of their tweets, with the same style/punctuation/capitalization/wording/tone/etc. as them. Only give me the template description, nothing else.",
    movie_tagging="Look at the following past movies this user has watched and determine the most popular tag they labeled. Answer in the following form: most popular tag: <tag>.",
)

template = {
    task_name: instruction + _HISTORY_TAIL
    for task_name, instruction in _TASK_INSTRUCTIONS.items()
}


from tqdm import tqdm
import random

K = 40                # at most this many profile entries are sampled per user
first_k_token = 100   # each entry's text is cut to its first 100 whitespace-delimited words
task = 'tweet_paraphrase'

# user_id -> generated profile text; populated after generation.
all_out = {}


def _build_prompts(users):
    """Build one profile-generation prompt per user.

    For each user, up to ``K`` entries of ``user['profile']`` are drawn with
    ``random.sample`` (all entries when there are ``K`` or fewer). Each entry's
    ``'text'`` is truncated with ``get_first_k_tokens`` and rendered as
    ``"tweet: ..."``. The entries are joined with ``' | '`` and inserted into
    ``template[task]``.

    :param users: Iterable of user records with ``'user_id'`` and ``'profile'``
        keys (presumably the loaded tweet JSON -- confirm schema).
    :return: ``(prompts, user_ids)``, two parallel lists in input order.
    """
    prompts, user_ids = [], []
    for user in tqdm(users):
        profile = user['profile']
        # Subsample long histories; short ones are used whole.
        profiles = random.sample(profile, K) if len(profile) > K else profile
        history_string = ' | '.join(
            "tweet: {}".format(get_first_k_tokens(p['text'], first_k_token))
            for p in profiles
        )
        prompts.append(template[task].format(history_string))
        user_ids.append(user['user_id'])
    return prompts, user_ids


# Train users are processed before test users, which fixes the order of the
# random.sample draws.
prompt_list_others, userid_list_others = _build_prompts(train_data)
prompt_list_100, userid_list_100 = _build_prompts(test_data)

import os

# Sampling settings shared by both generation passes.
GENERATION_KWARGS = dict(
    do_sample=True,
    top_k=10,
    temperature=0.6,
    top_p=0.9,
    max_new_tokens=300,
)


def _generate(batched_prompts, show_outputs=False):
    """Run sampled generation over pre-batched prompts; return only the new text.

    :param batched_prompts: List of prompt batches (lists of str), e.g. from split_batch.
    :param show_outputs: If True, print each batch's decoded completions.
    :return: Flat list of completions, one per prompt, in input order.
    """
    completions = []
    with torch.no_grad():
        for batch in tqdm(batched_prompts):
            # NOTE(review): prompts are not truncated (cutoff_len is never used), so
            # long histories may exceed the context window of 4k-context models.
            inputs = tokenizer(batch, return_tensors="pt", padding=True, return_token_type_ids=False)
            inputs = inputs.to(base_model.device)

            outputs = base_model.generate(
                **inputs,
                eos_token_id=tokenizer.eos_token_id,
                **GENERATION_KWARGS,
            )

            # For a causal LM, generate() returns the (padded) prompt ids followed by
            # the new ids, so dropping the first input-length columns leaves only the
            # completion. This is more reliable than decoding everything and
            # str.replace()-ing the prompt out: decode() need not reproduce the prompt
            # text exactly (e.g. tokenization-space clean-up), in which case the prompt
            # would leak into the stored profile.
            prompt_len = inputs["input_ids"].shape[1]
            decoded = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
            if show_outputs:
                print(decoded)
            completions += decoded
    return completions


def _record_outputs(completions, user_ids):
    """Pair completions with user ids, store them in ``all_out`` and return records.

    :param completions: Generated texts, parallel to ``user_ids``.
    :param user_ids: User ids in the same order as the prompts.
    :return: List of ``{"id": user_id, "output": text}`` dicts, with surrounding
        whitespace stripped from ``text``.
    """
    records = []
    for user_id, text in zip(user_ids, completions):
        output = text.strip()
        records.append({"id": user_id, "output": output})
        # A later split overwrites an earlier one if the same id appears in both.
        all_out[user_id] = output
    return records


batched_prompt_others = split_batch(prompt_list_others, batch_size)
out_list_others = _generate(batched_prompt_others)
pred_all_others = _record_outputs(out_list_others, userid_list_others)

# with open('./user_profile/{}_profile_base_LLM_v2.json'.format(task), 'w') as f:
#     json.dump(pred_all_others, f)

batched_prompt_100 = split_batch(prompt_list_100, batch_size)
# Test-split completions are echoed to stdout as they are produced.
out_list_100 = _generate(batched_prompt_100, show_outputs=True)
pred_all_100 = _record_outputs(out_list_100, userid_list_100)

# A single combined file for every user (the name has no task placeholder).
output_path = './user_profile/all_profile_id2text-v2.json'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(all_out, f)

