import os
import torch
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging
)
from peft import LoraConfig
from trl import SFTTrainer
import transformers
import evaluate
import json
from tqdm import tqdm
from huggingface_hub import login


def _extract_label(generated_text):
    """Return the first whitespace-delimited word of a generated continuation.

    The prompt asks for a single sentiment label, but generation may keep
    going past it (up to ``max_new_tokens``), e.g. by starting another
    ``[...] = label`` line. Keeping only the first word stops such trailing
    text from turning a correct label into a miss. Returns ``""`` when the
    model generated nothing but whitespace.
    """
    words = generated_text.split()
    return words[0] if words else ""


def inference(
    test_data,
    test_labels,
    *,
    output_dir="./results_financial_phrasebank",
    results_dir="/data",
):
    """Run the fine-tuned model over *test_data*, print accuracy, and save results.

    Args:
        test_data: Sequence of dicts, each with a ``"text"`` prompt string.
        test_labels: Gold label strings, aligned index-by-index with *test_data*.
        output_dir: Training output directory; the model and tokenizer are
            loaded from its ``final_merged_checkpoint`` subdirectory.
        results_dir: Directory that receives ``predictions.json`` and
            ``test_set.json`` (created if it does not exist).

    Raises:
        ValueError: if *test_data* and *test_labels* differ in length.
    """
    # Fail before the expensive model load rather than mis-scoring later.
    if len(test_data) != len(test_labels):
        raise ValueError(
            f"test_data has {len(test_data)} items but test_labels has "
            f"{len(test_labels)}"
        )

    # Where to load model results
    final_checkpoint_dir = os.path.join(output_dir, "final_merged_checkpoint")

    # Load the entire model on the GPU 0
    device_map = {"": 0}

    # 4-bit base model, NF4 quantization (vs fp4), nested/double quantization,
    # bfloat16 compute dtype.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    # Load the base model
    model = AutoModelForCausalLM.from_pretrained(
        final_checkpoint_dir,
        device_map=device_map,
        quantization_config=bnb_config,
        return_dict=True,
        low_cpu_mem_usage=True,
    )
    # Inference only: keep the key/value cache enabled. use_cache=False is a
    # training-time setting and makes generation recompute attention over the
    # whole sequence at every new token.
    model.config.use_cache = True
    model.config.pretraining_tp = 1

    tokenizer = AutoTokenizer.from_pretrained(final_checkpoint_dir)

    # Define a custom padding token
    tokenizer.pad_token = "<PAD>"

    # Set the padding direction to the right
    tokenizer.padding_side = "right"

    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=250,
        clean_up_tokenization_spaces=True,
    )
    predictions = []
    print("Predicting test data...")
    for el in tqdm(test_data, total=len(test_data)):
        # return_full_text=False yields only the newly generated text, so the
        # label is parsed from the model's answer rather than the echoed prompt.
        response = pipe(el["text"], return_full_text=False)
        predictions.append(_extract_label(response[0]["generated_text"]))

    hit = sum(pred == label for pred, label in zip(predictions, test_labels))
    # Guard against an empty test set (ZeroDivisionError otherwise).
    accuracy = hit / len(predictions) if predictions else 0.0
    print(f"Got {hit} of {len(predictions)} correct which equals accuracy of {accuracy}")

    os.makedirs(results_dir, exist_ok=True)
    with open(os.path.join(results_dir, "predictions.json"), "w", encoding="utf-8") as file:
        json.dump(predictions, file)
    with open(os.path.join(results_dir, "test_set.json"), "w", encoding="utf-8") as file:
        json.dump(test_labels, file)

def generate_prompt(sentence, sentiment):
    """Build a labelled prompt: the instruction text, then ``[sentence] = sentiment``.

    Continuation lines carry an eight-space indent, and the result is
    stripped of leading/trailing whitespace. This text is the train/validation
    data format, so keep it byte-stable.
    """
    pad = " " * 8
    parts = (
        "Analyze the sentiment of the news headline enclosed in square brackets, ",
        pad + "determine if it is positive, neutral, or negative, and return the answer as ",
        pad + 'the corresponding sentiment label "positive" or "neutral" or "negative".',
        "",
        pad + f"[{sentence}] = {sentiment}",
    )
    return "\n".join(parts).strip()

def generate_test_prompt(sentence):
    """Build an unlabelled prompt for *sentence*, ending in ``[sentence] =``.

    The text matches ``generate_prompt(sentence, "")`` exactly: same wording,
    same eight-space indent on continuation lines. The original version
    indented with twelve spaces, so test prompts differed in whitespace from
    the train/validation prompts.
    """
    pad = " " * 8
    parts = (
        "Analyze the sentiment of the news headline enclosed in square brackets, ",
        pad + "determine if it is positive, neutral, or negative, and return the answer as ",
        pad + 'the corresponding sentiment label "positive" or "neutral" or "negative".',
        "",
        pad + f"[{sentence}] = ",
    )
    # strip() drops the trailing space after "=", as the original did.
    return "\n".join(parts).strip()


def rework_label(label):
    """Map an integer class id to its sentiment name.

    Args:
        label: Class id: 0, 1 or 2.

    Returns:
        ``"negative"`` for 0, ``"neutral"`` for 1, ``"positive"`` for 2.

    Raises:
        ValueError: for any other id. Without this check an unknown id would
            become ``None`` and end up in prompts as the text "None".
    """
    names = {0: "negative", 1: "neutral", 2: "positive"}
    try:
        return names[label]
    except KeyError:
        raise ValueError(f"Unknown sentiment label: {label!r}") from None

def get_proper_data_format(data, *, seed=42):
    """Split *data* into train/validation/test and format each as prompts.

    ``data["train"]`` is split 80/20 into (remainder, test). The remainder is
    split 80/20 again into (train, validation), giving roughly 64/16/20 overall.

    Args:
        data: Mapping with a ``"train"`` split that supports
            ``train_test_split`` and whose rows have ``"sentence"`` and
            ``"label"`` fields.
        seed: Seed passed to both ``train_test_split`` calls so the split is
            reproducible. Use the same seed as the fine-tuning run. Otherwise
            the "test" examples may be ones the model was trained on.

    Returns:
        ``(final_train, final_validation, final_test, final_test_labels)``:
        train/validation are lists of ``{"text": labelled prompt}``; test is a
        list of ``{"text": unlabelled prompt}``; test labels are the matching
        sentiment strings.
    """
    # Hold out 20% as the test set ...
    outer = data["train"].train_test_split(test_size=0.2, seed=seed)
    test_set = outer["test"]
    # ... then 20% of what remains as the validation set.
    inner = outer["train"].train_test_split(test_size=0.2, seed=seed)
    train_set = inner["train"]
    validation_set = inner["test"]

    final_train = [
        {"text": generate_prompt(el["sentence"], rework_label(el["label"]))}
        for el in train_set
    ]
    final_validation = [
        {"text": generate_prompt(el["sentence"], rework_label(el["label"]))}
        for el in validation_set
    ]
    final_test = [{"text": generate_test_prompt(el["sentence"])} for el in test_set]
    final_test_labels = [rework_label(el["label"]) for el in test_set]

    return final_train, final_validation, final_test, final_test_labels

if __name__ == "__main__":
    # Script entry point: load the dataset, build prompts, evaluate on the test split.
    phrasebank = load_dataset("financial_phrasebank", 'sentences_50agree')
    # Only the held-out test prompts and their gold labels are used here.
    _train, _validation, test_prompts, gold_labels = get_proper_data_format(phrasebank)
    print("Adapted data format...")
    inference(test_prompts, gold_labels)