import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from transformers import pipeline, BitsAndBytesConfig
import argparse
import transformers
import random


parser = argparse.ArgumentParser(description="Parser for LoRA")
parser.add_argument('--model_name', type=str, default='llama-2-7b',
                    help="key into name2path selecting the base checkpoint")
parser.add_argument('--batch_size', type=int, default=3,
                    help="per-device training batch size")
parser.add_argument('--k', type=int, default=1,
                    help="currently unused by this script")
parser.add_argument('--max_step', type=int, default=5000,
                    help="currently unused (max_steps is commented out in TrainingArguments)")
parser.add_argument('--cut_off', type=int, default=2048,
                    help="maximum number of tokens per training example (longer ones are truncated)")
parser.add_argument('--max_epoch', type=int, default=2,
                    help="number of training epochs")

args = parser.parse_args()
# Task label used in the checkpoint directory name. The data read below comes
# from LaMP_processed/product_rat and the predictions are tagged 'LaMP_3', so
# the label must be the product-rating task. The previous value "news_cat" was
# a copy-paste leftover that could collide with a news-categorization run's
# checkpoint directory.
task_name = "product_rat"
model_name = args.model_name
batch_size = args.batch_size
k = args.k
# max_step = args.max_step
cutoff_len = args.cut_off
# When False, the prompt-only tokenization in generate_and_tokenize_prompt gets
# no EOS, so its length marks exactly where the answer starts.
add_eos_token = False
max_epoch = args.max_epoch

# Local checkpoint location for each supported --model_name value.
# The paths are blank placeholders; fill in the ones you use. The model is
# loaded with local_files_only=True, so each path must resolve without a download.
name2path = {
    'opt-350m': '',
    'opt-125m': '',
    "vicuna-7b": "",
    "vicuna-13b": "",
    "llama-2-7b": "",
    "llama-2-13b": "",
    "llama-2-chat-13b": "",
    "llama-2-chat-7b": "",
}

# Fail fast with a readable message instead of a bare KeyError, or an opaque
# from_pretrained('') error once loading starts.
if model_name not in name2path:
    raise SystemExit(
        "Unknown --model_name {!r}; expected one of: {}".format(model_name, ", ".join(name2path))
    )
PATH = name2path[model_name]
if not PATH:
    raise SystemExit(
        "No local checkpoint path configured for {!r}; fill it in name2path.".format(model_name)
    )

# # 4-bit (NF4) quantization
# # NOTE: `max_memory` is a from_pretrained() argument, not a BitsAndBytesConfig one.
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16,
#     bnb_4bit_use_double_quant=True,
#     max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
# )

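# 8-bit quantization inference
# NOTE: BitsAndBytesConfig has no `bnb_8bit_*` options and no "nf8" quant type;
# of the arguments below, only `load_in_8bit=True` is a BitsAndBytesConfig argument.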
# bnb_config = BitsAndBytesConfig(
#     load_in_8bit=True,
#     bnb_8bit_quant_type="nf8",
#     bnb_8bit_compute_dtype=torch.float16,
#     bnb_8bit_use_double_quant=True,
#    max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
# )

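# 16-bit quantization inference
# NOTE: bitsandbytes has no 16-bit quantization mode. `load_in_16bit` and the
# `bnb_16bit_*` options below are not BitsAndBytesConfig arguments. To run in
# bf16, omit quantization_config and load with torch_dtype=torch.bfloat16.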
# bnb_config = BitsAndBytesConfig(
#     load_in_16bit=True,
#     bnb_16bit_quant_type="bf16",
#     bnb_16bit_compute_dtype=torch.bfloat16,
#     bnb_16bit_use_double_quant=True,
#     max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
# )

# Left padding: batched generation with a decoder-only model needs every
# prompt to end where the generated tokens begin.
tokenizer = AutoTokenizer.from_pretrained(PATH, padding_side="left")
tokenizer.eos_token = "</s>"
# NOTE(review): '[PAD]' is probably not in the Llama/OPT vocabularies. If so,
# pad_token_id resolves to the unknown-token id rather than a dedicated pad id.
# Confirm with tokenizer.convert_tokens_to_ids('[PAD]').
tokenizer.pad_token = '[PAD]'

# Load the base model in bf16 from local files only. The quantization_config
# line is commented out, so no bitsandbytes quantization is applied here.
# NOTE: trust_remote_code=True executes any modeling code shipped inside PATH.
base_model = AutoModelForCausalLM.from_pretrained(
    PATH,
    # quantization_config=bnb_config,
    local_files_only=True,
    device_map='auto',
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
)

# KV cache off while training (it is switched back on before inference).
base_model.config.use_cache = False
# Mirror the tokenizer's special-token ids into the model config.
base_model.config.pad_token_id = tokenizer.pad_token_id
base_model.config.eos_token_id = tokenizer.eos_token_id
base_model.config.bos_token_id = tokenizer.bos_token_id

from peft import prepare_model_for_kbit_training

# NOTE(review): the model above is NOT k-bit quantized, yet it goes through
# prepare_model_for_kbit_training. In recent peft versions this freezes all
# base parameters and upcasts fp16/bf16 parameters to fp32, roughly doubling
# memory. It also enables gradient checkpointing by default, which makes the
# explicit call below redundant. Confirm against the installed peft version.
base_model.gradient_checkpointing_enable()
base_model = prepare_model_for_kbit_training(base_model)

def print_trainable_parameters(model):
    """
    Print how many of the model's parameters will receive gradients.

    :param model: Any object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
    Prints one line with the trainable count, the total count, and the
    trainable percentage.
    """
    params = [param for _, param in model.named_parameters()]
    total = sum(param.numel() for param in params)
    trainable = sum(param.numel() for param in params if param.requires_grad)
    print(
        f"trainable params: {trainable} || all params: {total} || trainable%: {100 * trainable / total}"
    )

from peft import LoraConfig, get_peft_model

# LoRA adapters on the attention projections. The projection names depend on
# the architecture:
#   - Llama/Vicuna use q_proj/k_proj/v_proj/o_proj.
#   - OPT uses q_proj/k_proj/v_proj/out_proj.
# Listing both output-projection names keeps the attention output projection
# adapted for every model in name2path. Previously only "out_proj" was listed,
# so Llama-family models (including the default) silently skipped it.
# peft only requires that at least one listed name exists in the model.
peft_config = LoraConfig(
    r=8,
    lora_alpha=8,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "out_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

training_arguments = transformers.TrainingArguments(
    output_dir='outputs/',
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=1,
    optim='adamw_torch',
    num_train_epochs=max_epoch,
    # Effectively disables mid-training checkpoints; the adapter is saved
    # explicitly with save_pretrained() after training.
    save_steps=1e9,
    logging_steps=50,
    learning_rate=1e-4,
    weight_decay=1e-2,
    bf16=True,
    max_grad_norm=0.3,
    # max_steps=max_step,
    warmup_ratio=0.1,
    # Batch examples of similar length together to reduce padding.
    group_by_length=True,
    lr_scheduler_type='linear',
    report_to='none',
)

import json
from tqdm import tqdm

# Training users, and the evaluation users (users with 100+ history items,
# judging by the file name). JSON is UTF-8; don't depend on the locale default.
with open("/afs//Private/LoRA-P-new/LaMP_processed/product_rat/user_others_old.json", 'r', encoding='utf-8') as f:
    train = json.load(f)

with open('/afs//Private/LoRA-P-new/LaMP_processed/product_rat/user_more_100_history.json', 'r', encoding='utf-8') as f:
    test_data = json.load(f)

# Pool every training user's profile entries into one list. Random entries
# from this pool serve as the (non-personalized) in-context example.
all_profiles = []
for user in tqdm(train):
    # extend() appends in place; the previous `all_profiles = all_profiles + ...`
    # copied the whole accumulated list on every iteration (quadratic).
    all_profiles.extend(user['profile'])

def extract_review(text):
    """
    Return the review body that follows the instruction in a query input.

    :param text: A query ``input`` string containing the marker
        ``"without further explanation. review: "``.
    :return: Everything after the first occurrence of the marker (may be empty).
    :raises ValueError: If the marker does not occur in ``text``.
    """
    marker = "without further explanation. review: "
    # partition() splits at the first occurrence; `found` is '' when absent.
    _, found, review = text.partition(marker)
    if not found:
        raise ValueError(
            "review marker {!r} not found in query input: {!r}".format(marker, text[:80])
        )
    return review


def split_batch(init_list, batch_size):
    """
    Split a sequence into consecutive batches of at most ``batch_size`` items.

    The last batch holds the remainder and may be shorter. Every batch is a list.

    :param init_list: The sequence to split.
    :param batch_size: Maximum number of items per batch; must be >= 1.
    :return: A list of lists, empty when ``init_list`` is empty.
    :raises ValueError: If ``batch_size`` is smaller than 1. Previously 0 raised
        ZeroDivisionError and negative sizes produced wrong batches.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    return [
        list(init_list[start:start + batch_size])
        for start in range(0, len(init_list), batch_size)
    ]


def get_first_k_tokens(text, k):
    """
    Truncate ``text`` to its first ``k`` whitespace-delimited words.

    Words are found with ``str.split()`` semantics (any run of whitespace) and
    re-joined with single spaces. The result is therefore whitespace-normalized
    even when the text has fewer than ``k`` words.

    :param text: The input text string.
    :param k: The number of words to keep.
    :return: The first ``k`` words joined by single spaces.
    """
    return " ".join(text.split()[:k])

def tokenize(prompt, add_eos_token=True):
    """
    Tokenize ``prompt`` for causal-LM training, with ``labels`` copied from ``input_ids``.

    Uses the module-level ``tokenizer`` and truncates to ``cutoff_len`` tokens.
    No padding is done here; that is left to the data collator.

    :param prompt: Text to tokenize.
    :param add_eos_token: Append ``tokenizer.eos_token_id`` when the sequence does
        not already end with it and is shorter than ``cutoff_len``. This parameter
        shadows the module-level ``add_eos_token`` flag.
    :return: The tokenizer's encoding with ``input_ids``, ``attention_mask`` and
        ``labels`` (a copy of ``input_ids``).
    """
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=cutoff_len,
        padding=False,
        return_tensors=None,
    )
    input_ids = result["input_ids"]
    # Guard the empty case (e.g. an empty prompt with a tokenizer that adds no
    # BOS). Indexing input_ids[-1] unconditionally raised IndexError there.
    ends_with_eos = bool(input_ids) and input_ids[-1] == tokenizer.eos_token_id
    if add_eos_token and not ends_with_eos and len(input_ids) < cutoff_len:
        input_ids.append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)

    # Copy so that later label masking does not alias input_ids.
    result["labels"] = input_ids.copy()
    return result

def generate_and_tokenize_prompt(data_point):
    """
    Tokenize one training example and mask its prompt out of the loss.

    ``data_point['full_prompt']`` (prompt plus answer) is tokenized with EOS
    appending enabled. ``data_point['prompt']`` is tokenized only to measure how
    many leading label positions belong to the prompt. Those positions are set
    to -100, the label value Hugging Face causal-LM losses ignore, so only the
    answer tokens (plus EOS, when appended) are trained on.

    :param data_point: Mapping with ``'prompt'`` and ``'full_prompt'`` strings.
    :return: The tokenized full prompt with prompt-masked ``labels``.
    """
    encoded_full = tokenize(data_point['full_prompt'])
    prompt_ids = tokenize(
        data_point['prompt'], add_eos_token=add_eos_token
    )["input_ids"]

    # If the prompt tokenization was configured with a trailing EOS, that EOS
    # is not part of the prefix shared with the full prompt, so exclude it.
    masked_len = len(prompt_ids) - 1 if add_eos_token else len(prompt_ids)

    full_labels = encoded_full["labels"]
    encoded_full["labels"] = [-100] * masked_len + full_labels[masked_len:]
    return encoded_full

# training
from datasets import load_dataset, Dataset

pred_all = []  # predictions, filled by the inference loop
actual = []  # NOTE: not referenced anywhere in this script
model = get_peft_model(base_model, peft_config)
print_trainable_parameters(model)

train_data = []
for user in tqdm(train):
    # One random profile entry, drawn from the pooled profiles of all training
    # users, is the in-context example for every query of this user.
    selected_history = random.choice(all_profiles)
    history_string = "review: {} score: {}".format(selected_history['text'], selected_history['score'])

    for q in user['query']:
        # Keep only the first 512 whitespace-delimited words of the review.
        review = get_first_k_tokens(extract_review(q['input']), 512)
        prompt = '{}\nWhat is the score of the following review on a scale of 1 to 5? just answer with 1, 2, 3, 4, or 5 without further explanation. review: {}\n score:'.format(history_string, review)
        # The full training sequence is the prompt followed by the gold score.
        full_prompt = prompt + ' {}'.format(q['gold'])

        train_data.append(
            {
                "prompt": prompt,
                "full_prompt": full_prompt
            }
        )

train_dataset = Dataset.from_list(train_data)
train_dataset = train_dataset.shuffle().map(generate_and_tokenize_prompt)

# Pads input_ids/attention_mask (and labels with -100) per batch.
collator = transformers.DataCollatorForSeq2Seq(
    tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
)
trainer = transformers.Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    data_collator=collator,
)

# Cast every module whose name contains "norm" to fp32 (presumably for
# numerical stability during bf16 training). nn.Module.to() converts in
# place, so the return value need not be kept.
for module_name, submodule in trainer.model.named_modules():
    if "norm" in module_name:
        submodule.to(torch.float32)


# KV cache stays off during training to silence the warnings; it is
# re-enabled before inference.
model.config.use_cache = False
trainer.train()

output_name = "LoRA-random/{}_{}_non_personalized_LoRA_ckpt".format(task_name, model_name)
ckpt_dir = './ckpt/{}'.format(output_name)
model.save_pretrained(ckpt_dir)

model.eval()
# Re-enable the KV cache for generation (it was disabled for training).
model.config.use_cache = True
for user in tqdm(test_data):
    test_question_list = []
    question_id_list = []

    # One random profile entry, drawn from the pooled training profiles, is the
    # in-context example for every query of this user.
    selected_history = random.choice(all_profiles)
    history_string = "review: {} score: {}".format(selected_history['text'], selected_history['score'])

    for q in user['query']:
        test_article = extract_review(q['input'])
        # NOTE(review): the training prompts truncate the review with
        # get_first_k_tokens(..., 512), but the full review is used here.
        # Confirm this train/test difference is intended.
        test_prompt = '{}\nWhat is the score of the following review on a scale of 1 to 5? just answer with 1, 2, 3, 4, or 5 without further explanation. review: {}\n score:'.format(history_string, test_article)
        test_question_list.append(test_prompt)
        question_id_list.append(q['id'])

    test_batch_list = split_batch(test_question_list, 1)
    out_list = []

    with torch.no_grad():
        for batch_idx, batch in tqdm(enumerate(test_batch_list), total=len(test_batch_list)):
            try:
                inputs = tokenizer(batch, return_tensors="pt", padding=True, return_token_type_ids=False)
                inputs = inputs.to(model.device)
                with torch.autocast(device_type="cuda"):
                    # NOTE(review): max_length bounds prompt + generated tokens
                    # together, and do_sample=True makes predictions
                    # nondeterministic across runs.
                    outputs = model.generate(
                        **inputs,
                        do_sample=True,
                        top_k=10,
                        temperature=0.1,
                        top_p=0.9,
                        eos_token_id=tokenizer.eos_token_id,
                        max_length=3096,
                    )

                out_list += tokenizer.batch_decode(outputs, skip_special_tokens=True)
            except Exception as err:
                # Best effort: record one empty prediction per prompt so outputs
                # stay aligned with question ids, and report the failure instead
                # of hiding it. The previous bare `except:` also trapped
                # KeyboardInterrupt and silently hid errors such as CUDA OOM.
                print("generation failed for batch {}: {!r}".format(batch_idx, err))
                out_list += [''] * len(batch)

    # The decoded sequence starts with the prompt; strip it to keep the answer.
    for question_id, prompt_text, decoded in zip(question_id_list, test_question_list, out_list):
        output = decoded.replace(prompt_text, '')
        pred_all.append({
            "id": question_id,
            "output": output
            })

        print(output)

output_file = {
    'task': 'LaMP_3',
    'golds': pred_all,
    'model': model_name,
}

with open('./outputs/output_LoRA-random.json', 'w') as f:
    json.dump(output_file, f, indent=4)