from dataclasses import dataclass, field
from typing import Optional
from transformers import LlamaTokenizerFast
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoModelForCausalLM
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from peft import PeftModel
from tqdm import tqdm
from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline, BertForSequenceClassification, BertTokenizer
import wandb
import pandas as pd
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
from trl.core import LengthSampler
from nltk.tokenize import word_tokenize
import os
import argparse


@dataclass
class ScriptArguments:
    """Command-line options of the data generation and scoring script."""

    # When True the generation model is loaded in 4-bit with a LoRA config attached.
    quantized: Optional[bool] = field(default=True, metadata={"help": "quantized or not"})
    # Number of prompts to sample and therefore number of generated rows.
    data_amount: Optional[int] = field(default=100000, metadata={"help": "how many new data wanted"})
    # Output directory. The default must be a string ("1", not the int 1): the value
    # is passed to os.makedirs and concatenated with file names.
    data_folder: Optional[str] = field(default="1", metadata={"help": "the folder for saving generated data"})
    # When True each prompt is the first five words of a review; otherwise it is empty.
    start_with_prefix: Optional[bool] = field(default=True, metadata={"help": "Starts with a prefix"})
    # Name or path of the generation model; also used to load the tokenizer.
    generation_model: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "which generation model to use"})
    # When True generated text is cut after its last '.', '!' or '?'.
    complete_or_not: Optional[bool] = field(default=False, metadata={"help": "whether to use complete sentence or not"})
   

# Parse the command-line flags into a single ScriptArguments instance.
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# Create the folder that receives the generated data (no error if it already exists).
os.makedirs(script_args.data_folder, exist_ok=True)


# Enable tqdm's pandas integration.
tqdm.pandas()

# Set device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Keyword arguments for the sentiment analysis pipeline.
# `return_all_scores` is set to True to get a score for every label (not for every token).
# NOTE(review): sent_kwargs is not referenced in the visible part of this script — confirm whether it is still needed.
sent_kwargs = {
    "return_all_scores": True,
    "function_to_apply": "none",
    "batch_size": 16,
    "truncation": True,
}


# Tokenizer of the generation model selected on the command line.
tokenizer = AutoTokenizer.from_pretrained(script_args.generation_model)

# Target classification model (placed on `device`, returning the scores of all labels).
# Alternative target model kept for reference:
# classifier_target = pipeline("sentiment-analysis",model="siebert/sentiment-roberta-large-english", return_all_scores=True, device = device)
classifier_target = pipeline("sentiment-analysis",model="distilbert-base-uncased-finetuned-sst-2-english", return_all_scores=True, device = device)

#Here is the target classifier 
def get_pos_scores_tar(text):
    """Return the target classifier's positive-sentiment score for ``text``.

    The pipeline output holds one list of per-label dicts per input, e.g.
    [[{'label': 'NEGATIVE', 'score': 0.9994924068450928},
      {'label': 'POSITIVE', 'score': 0.0005075956578366458}]]

    Only the first per-label entry is inspected: its score is returned as-is
    when it carries the POSITIVE label, otherwise the complement of its score
    is returned.
    """
    first_entry = classifier_target(text)[0][0]
    first_score = first_entry['score']
    return first_score if first_entry['label'] == "POSITIVE" else 1 - first_score

# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if getattr(tokenizer, "pad_token", None) is None:
    tokenizer.pad_token = tokenizer.eos_token


# This is just for the prefix prompt
def build_dataset(dataset_name="imdb", input_min_text_length=25, input_max_text_length=25):
    """
    Load a review dataset and attach a short generation prompt to every sample.

    The train split is loaded with `load_dataset`, its `text` column is renamed
    to `review`, and reviews of 200 characters or fewer are dropped. Each
    remaining sample gets two new fields: `input_ids` and `query`. When
    `script_args.start_with_prefix` is set they hold the tokenized / decoded
    first five words of the review; otherwise they are an empty list and an
    empty string.

    Args:
        dataset_name (`str`):
            The name of the dataset to be loaded.
        input_min_text_length (`int`):
            Not used by this function.
        input_max_text_length (`int`):
            Not used by this function.

    Returns:
        The filtered and mapped dataset, with its format set to "torch".
    """
    def attach_prompt(example):
        # The review is always word-tokenized, even when no prefix is requested.
        review_words = word_tokenize(example["review"])
        prefix = ' '.join(review_words[:5])
        if not script_args.start_with_prefix:
            example["input_ids"] = []
            example["query"] = ""
            return example
        prompt_ids = tokenizer.encode(prefix)
        example["input_ids"] = prompt_ids
        example["query"] = tokenizer.decode(prompt_ids)
        return example

    reviews = load_dataset(dataset_name, split="train").rename_columns({"text": "review"})
    long_reviews = reviews.filter(lambda row: len(row["review"]) > 200, batched=False)

    prompts = long_reviews.map(attach_prompt, batched=False)
    prompts.set_format(type="torch")
    return prompts

# Build the prompt dataset with the default arguments (the "imdb" dataset).
dataset = build_dataset()


def collator(data):
    """Turn a list of per-sample dicts into one dict of per-key lists.

    The keys (and their order) are taken from the first sample.
    """
    return {key: [sample[key] for sample in data] for key in data[0]}
# Local process index reported by Accelerate.
# NOTE(review): current_device is not referenced in the visible part of this script.
current_device = Accelerator().local_process_index

# LoRA settings passed as `peft_config` when the quantized model is loaded.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)


# Load the generation model selected with --generation_model so that it always
# matches the tokenizer, which is loaded from the same argument. The default
# value of that argument is the checkpoint that was previously hard-coded here.
print("Quantized", script_args.quantized)
if script_args.quantized:
    # 4-bit weights with the LoRA configuration attached.
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        script_args.generation_model,
        load_in_4bit=True,
        device_map="balanced",
        peft_config=lora_config,
    )
else:
    model = AutoModelForCausalLMWithValueHead.from_pretrained(
        script_args.generation_model,
        device_map="balanced",
    )

# No optimizer is created in this script.
optimizer = None

# Reference classification model
# Get the probability of the positive and the negative answer
# NOTE: if a different model is used, the token ids used in probs[index] must be changed!
def get_pos_scores_ref(text):
    """Return the reference model's positive-sentiment probability for ``text``.

    A four-shot question/answer prompt ending in "Answer:" is built around
    ``text`` and fed to the generation model. The next-token distribution at
    the last prompt position is read, and the probability of the "positive"
    token is renormalized against the "negative" token.

    Args:
        text: The text whose sentiment is to be scored.

    Returns:
        A Python float: p(positive) / (p(positive) + p(negative)).
    """
    target = f"""Question: Find the sentiment of this text. Answer with positive or negative: that is far too tragic to merit such superficial treatment
Answer: negative
Question: Find the sentiment of this text. Answer with positive or negative: a smile on your face
Answer: positive
Question: Find the sentiment of this text. Answer with positive or negative: saw how bad this movie was
Answer: negative
Question: Find the sentiment of this text. Answer with positive or negative: the greatest musicians
Answer: positive
Question: Find the sentiment of this text. Answer with positive or negative: {text}
Answer:"""  # original annotation: 0.942 (its meaning is not documented here)
    encoding = tokenizer.encode_plus(target, return_tensors="pt")
    with torch.no_grad():
        # The first element of the model output is used as the logits.
        outputs = model(**encoding)[0]
    # Distribution over the vocabulary at the last position of the first sequence.
    probs = torch.softmax(outputs[0][-1].detach().to("cpu"), 0)
    # Vocabulary ids of the answer tokens. These are tokenizer specific and must
    # be changed for a different model.
    # NOTE(review): presumably the "negative" / "positive" ids of the Llama-2 tokenizer — confirm.
    neg_prob_raw = probs[8178]
    pos_prob_raw = probs[6374]
    # Call .item() on the whole quotient so that a plain float is returned
    # (previously only the denominator was converted and a 0-d tensor leaked out).
    pos_prob = (pos_prob_raw / (pos_prob_raw + neg_prob_raw)).item()

    return pos_prob



# Settings passed to model.generate for every prompt: sampling is enabled,
# the EOS token id is used for padding, and at most 50 new tokens are produced.
generation_kwargs = {
     "min_length": -1,
     "top_k": 0.0,
     "top_p": 1.0,
     "do_sample": True,
     "pad_token_id": tokenizer.eos_token_id,
     "max_new_tokens":50,
}

# Return the text cut after its last complete sentence
def get_complete_sentences(text):
    """Cut ``text`` right after its last '.', '!' or '?'.

    Text that contains none of these characters is returned unchanged.
    """
    # rfind yields -1 for a missing character, so the maximum is -1 only when
    # no sentence-ending character occurs at all.
    last_end = max(text.rfind(mark) for mark in ('.', '!', '?'))
    if last_end == -1:
        return text
    return text[:last_end + 1]

# Generate the new data
bs = script_args.data_amount
game_data = dict()
dataset.set_format("pandas")
# Prompts are sampled with replacement, so `bs` may exceed the dataset size.
df_batch = dataset[:].sample(bs, replace=True)
game_data["query"] = df_batch["query"].tolist()
query_tensors = df_batch["input_ids"].tolist()

# When no prefix is used every prompt is empty. An empty prompt would become a
# float tensor of shape (1, 0), which generate() cannot consume, so generation
# starts from a single BOS token instead (EOS if the tokenizer has no BOS).
fallback_start_id = tokenizer.bos_token_id
if fallback_start_id is None:
    fallback_start_id = tokenizer.eos_token_id

response_tensors = []

# Get one sampled response per prompt from the generation model
for query_ids in tqdm(query_tensors):
    prompt_ids = [int(token_id) for token_id in query_ids] or [fallback_start_id]
    # Explicit long dtype: token ids must be integers regardless of how the
    # pandas conversion stored them.
    input_ids = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(dim=0).to(device)

    output = model.generate(
        input_ids=input_ids,
        **generation_kwargs
    ).squeeze()
    response_tensors.append(output)

# Decode every generated sequence; optionally keep only its complete sentences
decoded_responses = [tokenizer.decode(tokens) for tokens in response_tensors]
if script_args.complete_or_not:
    decoded_responses = [get_complete_sentences(decoded) for decoded in decoded_responses]
game_data["response (before)"] = decoded_responses

df = pd.DataFrame(game_data)

# Save the raw generations as both pickle and CSV
raw_output_stem = script_args.data_folder + '/only_before_generation_raw'
df.to_pickle(raw_output_stem + '.pkl')
csv_filename = raw_output_stem + '.csv'
df.to_csv(csv_filename, index=False)

# Initialize DataFrame
df = pd.DataFrame(game_data)

# Score every response with the reference and the target classifier.
# The scores are collected in lists and assigned as whole columns: chained
# assignment (df[col][i] = value) is unreliable and has no effect under pandas
# copy-on-write. float() turns a 0-d tensor score into a plain number so that
# the columns are numeric.
ref_scores = []
tar_scores = []
for response in tqdm(df['response (before)']):
    ref_scores.append(float(get_pos_scores_ref(response)))
    tar_scores.append(float(get_pos_scores_tar(response)))

df['score_ref(before)'] = ref_scores
df['score_tar(before)'] = tar_scores
df['score_diff(before)'] = df['score_ref(before)'] - df['score_tar(before)']

# Save the scored data as both pickle and CSV
df.to_pickle(os.path.join(script_args.data_folder, 'only_before_score.pkl'))
csv_filename = os.path.join(script_args.data_folder, 'only_before_score.csv')
df.to_csv(csv_filename, index=False)

