from transformers import AutoModelForSequenceClassification
import torch
from transformers import AutoTokenizer, pipeline, T5ForConditionalGeneration, BitsAndBytesConfig, AutoModelForCausalLM, TrainingArguments
from tqdm import tqdm
import json
from datasets import Dataset, load_dataset
from typing import Dict, Optional
import os
from trl import DPOTrainer
from huggingface_hub import login
from peft import LoraConfig, prepare_model_for_kbit_training


# test_file_path = "./data/complete_qa_final.json"
# with open(test_file_path, "r") as input_file:
#     data = json.load(input_file)
    
# # For testing, only use the first 1000 answers
# data_subset = data[:1000]
# texts = [el["answer_1"] for el in data_subset]

# hatebert = pipeline("text-classification", model="GroNLP/hateBERT", device=0)
# tokenizer_kwargs = {'padding':True,'truncation':True,'max_length':512}


# predictions = []
# for el in tqdm(texts, total=len(texts)):
#     response = hatebert(el, **tokenizer_kwargs)
#     predictions.append(response[0]["label"])

# with open("./results/hatebert_result.json", "w") as output_file:
#     json.dump(predictions, output_file)

# def get_proper_data_format(dataset):
#     dataset_length = len(dataset)
#     validation_size = int(0.1 * dataset_length)
#     test_size = 500
#     validation_and_test_size = validation_size + test_size
#     train_dataset = dataset[:len(dataset) - validation_and_test_size]
#     validation_dataset = dataset[len(dataset) - validation_and_test_size:len(dataset) - test_size]
#     test_dataset = dataset[len(dataset) - test_size:]
#     final_train = []
#     final_validation = []
#     for el in train_dataset:
#         new_el = {
#             "prompt": f"[INST] {el['text']} {el['context']} [/INST]",
#             "chosen": el['answer_1'],
#             "rejected": el['answer_2'],
#         }
#         final_train.append(new_el)
#     for el in validation_dataset:
#         new_el = {
#             "prompt": f"[INST] {el['text']} {el['context']} [/INST]",
#             "chosen": el['answer_1'],
#             "rejected": el['answer_2'],
#         }
#         final_validation.append(new_el)

#     with open("./data/complete_qa_final_filtered_preprocessed_dpo_train.json", "w") as file:
#         print(f"training set length: {len(final_train)}")
#         json.dump(final_train, file)
#     # with open(os.path.join(data_path, "complete_qa_final_filtered_test.json"), "w") as file:
#     #     print(f"test set length: {len(test_dataset)}")
#     #     json.dump(test_dataset, file)
#     with open("./data/complete_qa_final_filtered_preprocessed_dpo_validation.json", "w") as file:
#         print(f"validation set length: {len(final_validation)}")
#         json.dump(final_validation, file)
#     return final_train, final_validation

# def extract_anthropic_prompt(prompt_and_response):
#     """Extract the anthropic prompt from a prompt and response pair."""
#     search_term = "\n\nAssistant:"
#     search_term_idx = prompt_and_response.rfind(search_term)
#     assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
#     return prompt_and_response[: search_term_idx + len(search_term)]


# def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
#     """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

#     The dataset is converted to a dictionary with the following structure:
#     {
#         'prompt': List[str],
#         'chosen': List[str],
#         'rejected': List[str],
#     }

#     Prompts should be structured as follows:
#       \n\nHuman: <prompt>\n\nAssistant:
#     Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
#     """
#     dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
#     if sanity_check:
#         dataset = dataset.select(range(min(len(dataset), 1000)))

#     def split_prompt_and_responses(sample) -> Dict[str, str]:
#         prompt = extract_anthropic_prompt(sample["chosen"])
#         return {
#             "prompt": prompt,
#             "chosen": sample["chosen"][len(prompt) :],
#             "rejected": sample["rejected"][len(prompt) :],
#         }

#     return dataset.map(split_prompt_and_responses)

# with open("./data/complete_qa_final_filtered.json", "r") as f:
#     data = json.load(f)
# train, validation = get_proper_data_format(data)
# train = Dataset.from_list(train)
# validation = Dataset.from_list(validation)
# # 2. Load the Anthropic Helpful-Harmless dataset
# train_dataset = get_hh("train", sanity_check=True)

# # 3. Load evaluation dataset
# eval_dataset = get_hh("test", sanity_check=True)

# from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments
# from trl import SFTTrainer
# from datasets import load_dataset
# from peft import LoraConfig

# def chars_token_ratio(dataset, tokenizer, nb_examples=400):
#     """
#     Estimate the average number of characters per token in the dataset.
#     """
#     total_characters, total_tokens = 0, 0
#     for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
#         total_characters += len(example["text"])
#         if tokenizer.is_fast:
#             total_tokens += len(tokenizer(example["text"]).tokens())
#         else:
#             total_tokens += len(tokenizer.tokenize(example["text"]))

#     return total_characters / total_tokens

# Load your dataset
# data_path = "./data"
# with open(os.path.join(data_path, "complete_qa_final_filtered_preprocessed_train.json"), "r") as data_file:
#     train = json.load(data_file)
# with open(os.path.join(data_path, "complete_qa_final_filtered_preprocessed_validation.json"), "r") as data_file:
#     validation = json.load(data_file)

# # Define your T5 model and tokenizer
# model = T5ForConditionalGeneration.from_pretrained("t5-small")
# tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta", , trust_remote_code=True, use_fast=True)

# chars_per_token = chars_token_ratio(train, tokenizer)

# # LoRA attention dimension
# lora_r = 64
# # Alpha for LoRA scaling
# lora_alpha = 128
# # Dropout probability for LoRA
# lora_dropout = 0.05

# # Create the LoRA configuration
# peft_config = LoraConfig(
#     r=lora_r,
#     lora_alpha=lora_alpha,
#     lora_dropout=lora_dropout,
#     inference_mode=False,
#     bias="none",
#     task_type="CAUSAL_LM",
#     # target_modules=[
#     #     "gate_proj",
#     #     "down_proj",
#     #     "up_proj",
#     #     "q_proj",
#     #     "v_proj",
#     #     "k_proj",
#     #     "o_proj",
#     # ]
# )

# output_dir = "./test"
# num_train_epochs = 1
# max_steps = -1
# bf16 = True
# fp16 = False
# batch_size = 4
# gradient_accumulation_steps = 16
# max_grad_norm = 0.3
# optim = "paged_adamw_8bit"
# learning_rate = 1e-4
# lr_scheduler_type = "cosine"
# warmup_ratio = 0.05
# weight_decay = 0.05
# gradient_checkpointing = True
# save_steps = 30
# logging_steps = 11
# eval_steps = 30
# training_arguments = TrainingArguments(
#     output_dir=output_dir,
#     dataloader_drop_last=True,
#     evaluation_strategy="steps",
#     per_device_train_batch_size=batch_size,
#     per_device_eval_batch_size=batch_size,
#     gradient_accumulation_steps=gradient_accumulation_steps,
#     gradient_checkpointing=gradient_checkpointing,
#     optim=optim,
#     save_steps=save_steps,
#     logging_steps=logging_steps,
#     eval_steps=eval_steps,
#     learning_rate=learning_rate,
#     num_train_epochs=num_train_epochs,
#     weight_decay=weight_decay,
#     fp16=fp16,
#     bf16=bf16,
#     max_grad_norm=max_grad_norm,
#     max_steps=max_steps,
#     warmup_ratio=warmup_ratio,
#     lr_scheduler_type=lr_scheduler_type,
#     run_name="llama-7b-finetuned",
#     report_to="wandb",
#     ddp_find_unused_parameters=False,
#     save_total_limit=5,
#     load_best_model_at_end=True,
#     # auto_find_batch_size=True,
#     gradient_checkpointing_kwargs={
#         "use_reentrant": False
#     }
# )

# # Define your TRL SFTTrainer
# sft_trainer = SFTTrainer(
#     model=model,
#     train_dataset=train,
#     eval_dataset=validation,
#     peft_config=peft_config,
#     dataset_text_field="text",
#     tokenizer=tokenizer,
#     args=training_arguments,
#     packing=True,
#     max_seq_length=2048,
#     chars_per_token=chars_per_token,
#     neftune_noise_alpha=10.0
# )

# # Start debugging

# # 1. Set breakpoints
# import pdb
# pdb.set_trace()

# # 2. Print statements
# print(sft_trainer.train_dataset)

# # 3. Logging
# import logging
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
# logger.info("Debugging information")

# # Start training
# sft_trainer.train()

def removeElements(A, B):
    """Locate the first occurrence of sequence ``A`` inside sequence ``B``.

    Despite its name, this function removes nothing: it scans ``B`` from
    left to right for a contiguous run of items equal to ``A``.

    Args:
        A: The sub-sequence to look for.
        B: The sequence that is searched.

    Returns:
        ``(True, start)`` where ``start`` is the index in ``B`` at which the
        first match begins, or ``(False, -1)`` when ``A`` does not occur in
        ``B``. An empty ``A`` matches at index 0.
    """
    pattern_len = len(A)
    # Last offset at which ``A`` still fits entirely inside ``B``.
    final_offset = len(B) - pattern_len

    offset = 0
    while offset <= final_offset:
        # Stops at the first differing position, like the original inner loop.
        differs = any(B[offset + k] != A[k] for k in range(pattern_len))
        if not differs:
            return True, offset
        offset += 1

    return False, -1

# Load three tokenized training dumps from the working directory. Only
# `unpacked` is read by the code that follows in this part of the file; the
# other two are loaded but not referenced afterwards.
# The encoding is passed explicitly: without it open() decodes with the
# platform's locale encoding, which is not guaranteed to be UTF-8.
with open("tokenized_train_data_unpacked_correct.json", "r", encoding="utf-8") as inputfile:
    unpacked = json.load(inputfile)

with open("tokenized_train_data_correct.json", "r", encoding="utf-8") as inputfile:
    correct_tokens = json.load(inputfile)

with open("tokenized_train_data_zephyr_unpacked_correct.json", "r", encoding="utf-8") as inputfile:
    correct_tokens_zephyr = json.load(inputfile)


# One entry per row of unpacked["attention_mask"]: the sum of that row
# (presumably the number of non-padding tokens per example — TODO confirm).
complete_sequence_counts = [
    sum(mask_row) for mask_row in tqdm(unpacked["attention_mask"])
]
# Disabled alternative, kept for reference: count spans that open with token
# value 1 and close with token value 2 (presumably BOS/EOS ids — TODO confirm).
    # detected_begin = False
    # sequence_count = 0
    # for token in el:
    #     if token == 1:
    #         detected_begin = True
    #     if detected_begin and token == 2:
    #         detected_begin = False
    #         sequence_count += 1
    # complete_sequence_counts.append(sequence_count)
# Grand total over all rows.
sum_sequence_counts = sum(complete_sequence_counts)
# complete_sequence_counts = []
# for el in tqdm(correct_tokens["input_ids"]):
#     detected_begin = False
#     sequence_count = 0
#     for token in el:
#         if token == 1:
#             detected_begin = True
#         if detected_begin and token == 2:
#             detected_begin = False
#             sequence_count += 1
#     complete_sequence_counts.append(sequence_count)
# sum_sequence_counts = sum(complete_sequence_counts)
# Load the DPO token dumps and the QA source data. `tokenized_train` and
# `data` are read by the code that follows; `tokenized_train_llama` and
# `dataset` are loaded but not referenced in this part of the file.
# The encoding is passed explicitly: without it open() decodes with the
# platform's locale encoding, which is not guaranteed to be UTF-8.
with open("tokenized_train_data_dpo.json", "r", encoding="utf-8") as inputfile:
    tokenized_train = json.load(inputfile)

with open("tokenized_train_data_dpo_llama.json", "r", encoding="utf-8") as inputfile:
    tokenized_train_llama = json.load(inputfile)

with open("./data/complete_qa_final_filtered.json", "r", encoding="utf-8") as input_file:
    data = json.load(input_file)

with open("./data/complete_qa_final_filtered_preprocessed_train.json", "r", encoding="utf-8") as input_file:
    dataset = json.load(input_file)


# Hugging Face access token. It is taken from the HF_TOKEN environment
# variable when that is set, so a real credential never has to be written into
# this file; the original placeholder literal remains the fallback.
HF_TOKEN = os.environ.get("HF_TOKEN", "token")
# NOTE(review): recent transformers releases prefer `token=` over
# `use_auth_token=` — confirm against the installed version before switching.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_auth_token=HF_TOKEN, add_eos_token=True, use_fast=True)

# Rebuild every example as "[INST] <text> <context> [/INST] <answer_1>".
complete = ["[INST] " + el["text"] + " " + el["context"] + " [/INST] " + el["answer_1"] for el in data]

# Tokenize one hard-coded example (index 24789) to serve as the search probe.
tokenized_single = tokenizer(complete[24789])
# Hand-written list of token ids; it is not read by the search below.
tokenized_single_new = [29889, 29900, 29941, 13667, 315, 29940, 29979, 1551, 845, 487, 1623, 13, 29900, 29889]

# The probe is the first 10 token ids of that example. It does not change
# between iterations, so it is sliced once instead of once per row.
probe_prefix = tokenized_single["input_ids"][:10]
train_input_ids = tokenized_train["input_ids"]

# (row index, offset inside the row) for every row that contains the probe.
# The original loop found these positions but discarded them.
probe_matches = []
for idx, el in tqdm(enumerate(train_input_ids), total=len(train_input_ids)):
    is_sublist, found_idx = removeElements(probe_prefix, el)
    if is_sublist:
        probe_matches.append((idx, found_idx))



# def print_trainable_parameters(model):
#   """
#   Prints the number of trainable parameters in the model.
#   """
#   trainable_params = 0
#   all_param = 0
#   for _, param in model.named_parameters():
#     all_param += param.numel()
#     if param.requires_grad:
#       trainable_params += param.numel()
#   print(
#       f"trainable params: {trainable_params} || all params: {all_param} || trainables%: {100 * trainable_params / all_param}"
#   )


# def find_all_linear_names(model):
#     cls = bnb.nn.Linear4bit
#     lora_module_names = set()
#     for name, module in model.named_modules():
#         if isinstance(module, cls):
#             names = name.split('.')
#             lora_module_names.add(names[0] if len(names) == 1 else names[-1])
#     print(list(lora_module_names))
#     return list(lora_module_names)

# login()
# data_path = "/data"

# output_dir="./results_7b_llama_sft_dpo_final_20_01_24"
# final_checkpoint_dir = os.path.join(output_dir, "final_checkpoint")

# train_dataset = load_dataset("json", data_files="data/dpo_train_tokenized_filter.json",split="train")

# validation_dataset = load_dataset("json", data_files="data/dpo_validation_tokenized_filter.json",split="train")

# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
# )

# model = T5ForConditionalGeneration.from_pretrained("t5-small")

# # model.config.use_cache = False
# # model = prepare_model_for_kbit_training(model)

# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
# tokenizer.pad_token = tokenizer.eos_token

# def return_prompt_and_responses(samples):
#     return {
#         "prompt": samples["prompt"],
#         "chosen": samples["chosen"],
#         "rejected": samples["rejected"],
#     }
# print("Formatting data..")
# original_train_columns = train_dataset.column_names

# train_dataset = train_dataset.map(
#     return_prompt_and_responses,
#     batched=True,
#     remove_columns=original_train_columns
# )

# original_validation_columns = validation_dataset.column_names

# validation_dataset = validation_dataset.map(
#     return_prompt_and_responses,
#     batched=True,
#     remove_columns=original_validation_columns
# )

# training_args = TrainingArguments(
#     per_device_train_batch_size=2,
#     per_device_eval_batch_size=2,
#     evaluation_strategy="steps",
#     gradient_accumulation_steps=32,
#     gradient_checkpointing=True,
#     max_grad_norm= 0.3,
#     num_train_epochs=1,
#     max_steps=-1, 
#     save_steps=100,
#     eval_steps=100,
#     learning_rate=5e-4,
#     bf16=True,
#     save_total_limit=5,
#     logging_steps=50,
#     output_dir=output_dir,
#     optim="paged_adamw_8bit",
#     lr_scheduler_type="cosine",
#     warmup_ratio=0.05,
#     remove_unused_columns=False,
#     load_best_model_at_end=True,
#     dataloader_drop_last=True,
#     gradient_checkpointing_kwargs={
#         "use_reentrant": False
#     }
# )

# peft_config = LoraConfig(
#     r=16,
#     lora_alpha=32,
#     # target_modules=find_all_linear_names(model),
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM",
# )

# dpo_trainer = DPOTrainer(
#     model,
#     args=training_args,
#     beta=0.1,
#     train_dataset=train_dataset,
#     eval_dataset=validation_dataset,
#     tokenizer=tokenizer,
#     peft_config=peft_config,
#     max_prompt_length=512,
#     max_length=1024,
    
# )
# cur_tokenized = dpo_trainer.train_dataset.to_dict()

# with open("tokenized_train_data_dpo_llama.json", "w") as out_file:
#     json.dump(cur_tokenized, out_file)