
from transformers import pipeline
import json
from tqdm import tqdm
import os
from collections import Counter
from huggingface_hub import login
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging
)
import statistics

def tokenized_input_fits(el, tokenizers, max_length):
    """Return True if both answer variants of ``el`` fit within ``max_length`` tokens.

    Builds the ``[INST] <text> <context> [/INST] <answer>`` string once for
    ``answer_1`` and once for ``answer_2``. Each tokenizer in ``tokenizers`` is
    called on them in that order (good answer first, then bad). Returns False as
    soon as any tokenization has more than ``max_length`` input ids.

    NOTE(review): every tokenizer, including the Zephyr one, is checked against
    the [INST] format, although Zephyr training data is written in a different
    chat format. Confirm that this is intended.
    """
    instruction = f"[INST] {el['text']} {el['context']} [/INST]"
    candidates = (
        f"{instruction} {el['answer_1']}",
        f"{instruction} {el['answer_2']}",
    )
    too_long = (
        len(tok(sample)["input_ids"]) > max_length
        for tok in tokenizers
        for sample in candidates
    )
    return not any(too_long)

def check_tokenized_lengths(el, tokenizer, attribute):
    """Return how many input ids ``tokenizer`` produces for ``el[attribute]``."""
    return len(tokenizer(el[attribute])["input_ids"])


def _format_examples(records):
    """Build SFT and DPO records, in both [INST] and Zephyr chat formats, for ``records``.

    Returns the 4-tuple ``(sft, sft_zephyr, dpo, dpo_zephyr)``. Each list has one
    entry per input record, in input order. ``answer_1`` is used as the SFT
    target and as the DPO "chosen" answer; ``answer_2`` is the DPO "rejected" one.
    """
    sft, sft_zephyr, dpo, dpo_zephyr = [], [], [], []
    for el in records:
        question = f"{el['text']} {el['context']}"
        inst_prompt = f"[INST] {question} [/INST]"
        zephyr_prompt = f"<|system|>\n</s>\n<|user|>\n{question}</s>\n"

        sft.append({"text": f"{inst_prompt} {el['answer_1']}"})
        sft_zephyr.append(
            {"text": f"{zephyr_prompt}<|assistant|>\n{el['answer_1']}</s>\n"}
        )
        dpo.append(
            {
                "prompt": inst_prompt,
                "chosen": el["answer_1"],
                "rejected": el["answer_2"],
            }
        )
        dpo_zephyr.append(
            {
                "prompt": zephyr_prompt,
                "chosen": f"<|assistant|>\n{el['answer_1']}</s>\n",
                "rejected": f"<|assistant|>\n{el['answer_2']}</s>\n",
            }
        )
    return sft, sft_zephyr, dpo, dpo_zephyr


def _dump_json(directory, filename, label, payload):
    """Print ``label`` with ``len(payload)``, then write ``payload`` as JSON to ``directory/filename``."""
    with open(os.path.join(directory, filename), "w", encoding="utf-8") as file:
        print(f"{label}: {len(payload)}")
        json.dump(payload, file)


def get_proper_data_format(
    dataset,
    tokenizers,
    max_length,
    validation_size,
    test_size,
    output_dir=None,
    hate_classification_path="./hate_classification_deprecated/hate_classifier_answer_1.json",
):
    """Filter ``dataset``, split it, and write SFT/DPO training files as JSON.

    Steps:
      1. Drop entries whose index is listed with reason ``"OFFENSIVE-LANGUAGE"``
         in the JSON file at ``hate_classification_path``.
      2. Drop entries for which ``tokenized_input_fits`` returns False. That
         function checks that both answer variants are at most ``max_length``
         tokens under every tokenizer in ``tokenizers``.
      3. Split the survivors in order. The last ``test_size`` entries form the
         test set, the ``validation_size`` entries before them form the
         validation set, and the rest form the training set.
      4. Write SFT/DPO train and validation files in the [INST] and Zephyr
         formats, plus the raw test split and the full filtered dataset, into
         ``output_dir``. The directory is created if it is missing.

    Args:
        dataset: Sequence of dicts with keys ``text``, ``context``,
            ``answer_1`` and ``answer_2``.
        tokenizers: Callables returning a mapping with an ``"input_ids"`` sequence.
        max_length: Maximum allowed number of input ids per formatted example.
        validation_size: Number of entries to hold out for validation.
        test_size: Number of entries to hold out for testing.
        output_dir: Destination directory. Defaults to the module-level
            ``data_path`` set by the script section of this file.
        hate_classification_path: JSON list of ``{"idx", "reason", ...}`` records.

    Raises:
        ValueError: If a split size is negative, or if fewer than
            ``validation_size + test_size`` entries survive filtering.
    """
    if output_dir is None:
        # Preserve the original behaviour of writing into the script's global.
        output_dir = data_path
    if validation_size < 0 or test_size < 0:
        raise ValueError(
            f"split sizes must be non-negative (validation_size={validation_size}, "
            f"test_size={test_size})"
        )

    with open(hate_classification_path, "r", encoding="utf-8") as inputfile:
        hate_classified = json.load(inputfile)
    # A set gives O(1) membership tests in the per-example loop below.
    offensive_indices = {
        el["idx"] for el in hate_classified if el["reason"] == "OFFENSIVE-LANGUAGE"
    }

    print(f"Initial Dataset has {len(dataset)} entries")
    # The cheap offensive-index check runs first, so tokenization is skipped for dropped rows.
    final_dataset = [
        el
        for idx, el in tqdm(enumerate(dataset), total=len(dataset))
        if idx not in offensive_indices
        and tokenized_input_fits(el, tokenizers, max_length)
    ]

    dataset_length = len(final_dataset)
    print(f"Dataset with tokenized tuples smaller than {max_length} has {dataset_length} entries")

    held_out = validation_size + test_size
    if held_out > dataset_length:
        # Without this check, negative slice bounds silently put every row into
        # the test split and leave the validation split empty.
        raise ValueError(
            f"need at least {held_out} filtered entries for validation+test, "
            f"but only {dataset_length} remain"
        )
    train_end = dataset_length - held_out
    validation_end = dataset_length - test_size
    train_dataset = final_dataset[:train_end]
    validation_dataset = final_dataset[train_end:validation_end]
    test_dataset = final_dataset[validation_end:]

    sft_train, sft_train_zephyr, dpo_train, dpo_train_zephyr = _format_examples(train_dataset)
    (
        sft_validation,
        sft_validation_zephyr,
        dpo_validation,
        dpo_validation_zephyr,
    ) = _format_examples(validation_dataset)

    os.makedirs(output_dir, exist_ok=True)
    outputs = (
        ("sft_train_tokenized_filter.json", "SFT training set length", sft_train),
        ("sft_validation_tokenized_filter.json", "SFT validation set length", sft_validation),
        ("dpo_train_tokenized_filter.json", "DPO training set length", dpo_train),
        ("dpo_validation_tokenized_filter.json", "DPO validation set length", dpo_validation),
        ("sft_zephyr_train_tokenized_filter.json", "SFT Zephyr training set length", sft_train_zephyr),
        ("sft_zephyr_validation_tokenized_filter.json", "SFT Zephyr validation set length", sft_validation_zephyr),
        ("dpo_zephyr_train_tokenized_filter.json", "DPO Zephyr training set length", dpo_train_zephyr),
        ("dpo_zephyr_validation_tokenized_filter.json", "DPO Zephyr validation set length", dpo_validation_zephyr),
        ("test_tokenized_filter.json", "test set length", test_dataset),
        ("final_data_tokenized_filter.json", "Complete dataset length", final_dataset),
    )
    for filename, label, payload in outputs:
        _dump_json(output_dir, filename, label, payload)

# Authenticate with the Hugging Face Hub (presumably needed for the gated
# Llama-2 repositories loaded below).
login()
old_data_path = "./data_deprecated"
data_path = "./data_new"
# JSON is UTF-8; pin the encoding instead of relying on the platform locale.
with open(
    os.path.join(old_data_path, "complete_qa_final_filtered.json"), "r", encoding="utf-8"
) as f:
    data = json.load(f)
# An example is kept only if its formatted text fits max_length under every
# one of these tokenizers (see tokenized_input_fits).
llama_base_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=True)
llama_chat_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", use_fast=True)
mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
zephyr_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta", use_fast=True)
tokenizers = [llama_base_tokenizer, llama_chat_tokenizer, mistral_tokenizer, zephyr_tokenizer]
print("Getting proper data format..")
get_proper_data_format(data, tokenizers, max_length=1023, validation_size=1000, test_size=500)