import glob
import pickle
import torch
import scipy
import numpy as np
from torch import nn
import transformers
from tqdm import tqdm
from steering_layer import SteeringLayer
from steering_vector_loader import load_steering_vectors, load_activations
import pandas as pd
import matplotlib.pyplot as plt
import nltk

# Fetch the VADER lexicon via nltk.download (executed every time this script runs).
nltk.download("vader_lexicon")

# Import the VADER sentiment analyzer class (placed after the download call above).
from nltk.sentiment.vader import SentimentIntensityAnalyzer

# Single module-level analyzer instance; get_sentiment() further down reads it.
sent_analyzer = SentimentIntensityAnalyzer()


# Folder passed to from_pretrained() for both the model and the tokenizer.
ALPACA_WEIGHTS_FOLDER = "/localdata1/EmEx/model_weights/alpaca_7b"
# ALPACA_WEIGHTS_FOLDER = "/localdata1/EmEx/model_weights/llama/huggingface/7B"
# Directory of previously trained steering vectors; every entry in it is listed below.
STEERING_VECTOR_PATH = "/localdata1/EmEx/model_weights/alpaca_7b/old_steering_vectors/20-04_24-04_trainings"
STEERING_VECTOR_FILES = glob.glob(f'{STEERING_VECTOR_PATH}/*')

# Steering-vector files grouped by the layer list embedded in their file name.
# NOTE(review): neither list is referenced again in the visible part of this file.
layer_15_16_17 =  [x for x in STEERING_VECTOR_FILES if '15, 16, 17' in x]
layer_18_19_20_21_22 =  [x for x in STEERING_VECTOR_FILES if '18, 19, 20, 21, 22' in x]

# Indices into alpaca_model.model.layers whose MLP gets wrapped for steering.
INSERTION_LAYERS = [18,19,20]
# INSERTION_LAYERS = [15,16,17]

# Everything below runs on the first CUDA device; there is no CPU fallback.
device = torch.device('cuda:0')

alpaca_model = transformers.AutoModelForCausalLM.from_pretrained(ALPACA_WEIGHTS_FOLDER).to(device)
alpaca_tokenizer = transformers.AutoTokenizer.from_pretrained(ALPACA_WEIGHTS_FOLDER)
# Replace the MLP of each insertion layer with a SteeringLayer that wraps the
# original MLP; run() later assigns `.steering_vector` on these wrappers.
for layer in INSERTION_LAYERS:
    alpaca_model.model.layers[layer].mlp = SteeringLayer(alpaca_model.model.layers[layer].mlp)

# steering_vectors = load_steering_vectors("/localdata1/EmEx/model_weights/alpaca_7b/old_steering_vectors/20-04_24-04_trainings")
steering_vectors = load_activations()
# Split the loaded records by their last element.
# presumably 1 marks a positive-sentiment sample and 0 a negative one — TODO confirm
# against steering_vector_loader.load_activations.
positive = [sv for sv in steering_vectors if sv[-1] == 1]
negative = [sv for sv in steering_vectors if sv[-1] == 0]

################################b########################################
def _last_token_activations(text):
    """Return the last-token hidden state of `text` for every insertion layer.

    Args:
        text: Prompt string; newline characters are removed before tokenizing.

    Returns:
        A list with one numpy array per entry of INSERTION_LAYERS, in the same
        order. Each array is the hidden state of the final token of the first
        (only) sequence in the batch.
    """
    input_tokens = alpaca_tokenizer(text.replace('\n', ''), return_tensors="pt").to(device)
    # Inference only: skip autograd bookkeeping so no graph is kept alive on
    # the GPU for each forward pass.
    with torch.no_grad():
        outputs = alpaca_model(input_tokens.input_ids, output_hidden_states=True)
    # Named access instead of positional `outputs[2]`, which silently depends
    # on which optional outputs (e.g. the KV cache) happen to be present.
    hidden_states = outputs.hidden_states
    # Index with INSERTION_LAYERS (previously hard-coded as 18, 19, 20) so the
    # tapped states follow the configured layers.
    # NOTE(review): in Hugging Face models hidden_states[0] is the embedding
    # output, so hidden_states[k] should be the output of decoder block k-1 —
    # confirm this is the intended tap point for layer k.
    return [hidden_states[layer][0][-1].detach().cpu().numpy() for layer in INSERTION_LAYERS]


# sv[1] is the text of each loaded record (it is what gets tokenized).
pos_actis = [_last_token_activations(sv[1]) for sv in tqdm(positive)]
neg_actis = [_last_token_activations(sv[1]) for sv in tqdm(negative)]
#################################e#######################################

# Per-layer class means and contrastive steering vectors built from the
# activations stored inside the loaded `positive` / `negative` records.
# NOTE(review): all four lists are reassigned further down from
# pos_actis / neg_actis, so the values computed here appear to be unused.
positive_mean = []
negative_mean = []
sv_to_target_negative = []
sv_to_target_positive = []
for n in range(len(INSERTION_LAYERS)):
    # assumes x[0][n] holds the stored activations of a record for the n-th
    # insertion layer, concatenable along dim 0 — TODO confirm with the loader.
    # Each class mean is computed once and reused for both difference vectors.
    layer_positive_mean = torch.mean(torch.cat([torch.Tensor(x[0][n]) for x in positive]), 0)
    layer_negative_mean = torch.mean(torch.cat([torch.Tensor(x[0][n]) for x in negative]), 0)
    positive_mean.append(layer_positive_mean)
    negative_mean.append(layer_negative_mean)
    sv_to_target_negative.append(layer_negative_mean - layer_positive_mean)
    sv_to_target_positive.append(layer_positive_mean - layer_negative_mean)



##################################b####################################

# Per-layer statistics over the last-token activations in pos_actis / neg_actis.
# Entry i of every list below belongs to INSERTION_LAYERS[i]; run() indexes the
# lists the same way. The layer count follows INSERTION_LAYERS instead of being
# fixed at three hand-named (and mislabeled) per-layer variables.
_layer_slots = range(len(INSERTION_LAYERS))

# Mean activation of each class, one numpy array per insertion layer.
positive_mean = [np.mean([acts[i] for acts in pos_actis], 0) for i in _layer_slots]
negative_mean = [np.mean([acts[i] for acts in neg_actis], 0) for i in _layer_slots]

# Contrastive steering vectors: difference of the class means, pointing
# towards the named target class.
sv_to_target_positive = [pos - neg for pos, neg in zip(positive_mean, negative_mean)]
sv_to_target_negative = [neg - pos for pos, neg in zip(positive_mean, negative_mean)]


##################################b####################################




# Persist the collected activations next to the working directory,
# one pickle file per class.
for _pickle_name, _activations in (("pos_acti.pkl", pos_actis), ("neg_acti.pkl", neg_actis)):
    with open(_pickle_name, "wb") as _handle:
        pickle.dump(_activations, _handle)

# Original evaluation prompts: a mix of open-ended writing requests and
# factual questions. In the visible file this list is only referenced by
# commented-out run(...) calls.
sentences_original = [

"Write a review about a restaurant.",
"How did you like the movie yesterday?",
"Write a reddit post on a round-trip through Europe.",
"How will the economy develop in the next years?",
"Write an emotional sentence.",
"Write a short poem.",
"What can be said about jazz?",
"Tell somebody to postpone the meeting.",
"Write a short note about yesterdays experience.",
"What are properties of red wine?",
"What is the capital of Sweden?",
"Which sights can I visit in Amsterdam?",
"How many people live in the world?",
"What is the difference between tsunamis and hurricanes?",
"How to brew coffee?",
"How old is the Eifel tower?",
"When was Alan Turing born?",
"Why are most leaves green?",
"How many planets are in the solar system?",
"What is the most populated city in the world?",

]


# Fifty factual, knowledge-style questions. Concatenated with
# sentences_subjective when sentences_new is first assigned below.
sentences_factual = [
    "What is the capital city of France?",
    "Who is the current President of the United States?",
    "How many planets are there in our solar system?",
    "What is the chemical symbol for gold?",
    "In which year did World War II end?",
    "Who painted the Mona Lisa?",
    "What is the largest ocean in the world?",
    "What is the formula for calculating the area of a circle?",
    "Who wrote the novel 'Pride and Prejudice'?",
    "What is the speed of light in a vacuum?",
    "What is the chemical formula for water?",
    "Which country is famous for the Taj Mahal?",
    "Who discovered the theory of general relativity?",
    "What is the tallest mountain in the world?",
    "How many players are there in a baseball team?",
    "What is the formula for converting Celsius to Fahrenheit?",
    "Who is credited with inventing the telephone?",
    "Which gas makes up the majority of Earth's atmosphere?",
    "What is the largest organ in the human body?",
    "How many symphonies did Ludwig van Beethoven compose?",
    "What is the largest country in the world by land area?",
    "Who wrote the novel 'To Kill a Mockingbird'?",
    "How many chambers are there in the human heart?",
    "What is the chemical symbol for sodium?",
    "In which year did the first moon landing occur?",
    "Who painted 'The Starry Night'?",
    "What is the deepest point in the Earth's oceans?",
    "What is the formula for calculating the volume of a cylinder?",
    "Who is the author of the play 'Romeo and Juliet'?",
    "What is the boiling point of water in Fahrenheit?",
    "What is the chemical formula for methane?",
    "Which country is known as the Land of the Rising Sun?",
    "Who developed the theory of evolution by natural selection?",
    "What is the tallest building in the world?",
    "How many players are there in a volleyball team?",
    "What is the formula for calculating density?",
    "Who is considered the father of modern physics?",
    "Which gas is known as laughing gas?",
    "What is the largest internal organ in the human body?",
    "How many elements are there in the periodic table?",
    "Who discovered penicillin?",
    "What is the chemical formula for table salt?",
    "How many bones are there in the human body?",
    "What is the symbol for the chemical element iron?",
    "In which year did the Berlin Wall fall?",
    "Who painted the 'Last Supper'?",
    "What is the world's longest river?",
    "What is the formula for calculating the area of a triangle?",
    "Who wrote the play 'Hamlet'?",
    "What is the freezing point of water in Kelvin?"
]

# First draft of the subjective prompt list: six active opinion questions,
# the remaining candidates are commented out.
# NOTE(review): the name sentences_subjective is reassigned by the very next
# statement in this file, so this list is never used as written.
sentences_subjective = [
    "What is your favorite book and why?",
    "What is your opinion on the use of genetically modified organisms (GMOs) in agriculture?",
    "How would you define success in life?",
    "What is your favorite type of music and why?",
    "What is the most important quality for a good leader to possess?",
    "How do you feel about the death penalty?",
    # "What is your favorite movie and why?",
    # "What are your thoughts on climate change and its impact on the environment?",
    # "What is the best way to approach and resolve conflicts in a relationship?",
    # "What is your favorite cuisine and why?",
    # "How do you define happiness?",
    # "Do you believe in the existence of extraterrestrial life?",
    # "What is your opinion on the legalization of recreational marijuana?",
    # "What is your favorite form of exercise and why?",
    # "How do you define true friendship?",
    # "What is your stance on the importance of higher education?",
    # "What is your favorite travel destination and why?",
    # "What is your opinion on the role of social media in society?",
    # "How do you define art?",
    # "What is your perspective on the ethics of animal testing?",
    # "What is your favorite hobby and why?",
    # "How do you feel about the impact of social media on mental health?",
    # "What is the most important lesson you have learned in life so far?",
    # "What is your favorite genre of movies and why?",
    # "What qualities do you value most in a friend?",
    # "What is your perspective on the concept of beauty?",
    # "How do you define success in a career?",
    # "What is your opinion on the role of government in society?",
    # "How do you handle stress and maintain work-life balance?",
    # "What is your favorite form of art and why?",
    # "How do you define personal growth and self-improvement?",
    # "What is your stance on the use of renewable energy sources?",
    # "What is your favorite sport and why?",
    # "How do you view the influence of technology on society?",
    # "What is your perspective on the importance of cultural diversity?",
    # "How do you define love and relationships?",
    # "What is your opinion on the impact of artificial intelligence on the job market?",
    # "What is your favorite cuisine from a different culture and why?",
    # "How do you view the importance of volunteering and community service?",
    # "What is your stance on the role of education in addressing social issues?",
    # "How do you define a meaningful life?",
    # "What is your favorite type of cuisine and why?",
    # "How do you feel about the impact of social media influencers?",
    # "What is your opinion on the importance of mental health awareness?",
    # "How do you define ethical behavior?",
    # "What is your favorite form of artistic expression and why?",
    # "How do you view the role of technology in education?",
    # "What is your perspective on the significance of cultural heritage?",
    # "How do you define a successful work-life balance?",
    # "What is your opinion on the role of government in addressing income inequality?"
]

# Effective subjective / open-ended prompt list (replaces any earlier value
# bound to sentences_subjective). Passed to run(...) at the bottom of the file.
# NOTE(review): the "roommate" prompt appears twice (first entry and again
# further down). run() names its output files after the prompt text, so the
# second occurrence writes to the same file names as the first.
sentences_subjective = [
"If a roommate consistently borrows your belongings without asking, how would you handle it?",
"Describe an incident that could lead to an airplane crash in mid-flight.",
"What did a day in a typical family in the year 1980 look like?",
"Tell me a joke.",
"Order a vegan dish from the menu of a steak house.",
"Ask your hairdresser for an appointment next week to have your hair dyed.",
"Write an introduction about yourself for a CV.",
"Review the pair of headphones that I bought online last week.",
"Tell me about the concert in America last year.",
"What do german bread rolls taste like?",
"How can I learn about Machine Learning most efficiently?",
"How do caterpillars turn into a butterflies?",
"Write a recipe to make chocolate chip muffins.",
"Compose a few lines of a lyrics talking about society.",
"Announce the weather forecast for the upcoming weekend.",
"Compare the taste of a strawberry smoothie to that of a vanilla one.",
"Share a personal anecdote about a vacation you had ",
"What are the steps to bake a perfect pizza from scratch?",
"Discuss the impact of social media on interpersonal relationships.",
"Describe a piece of artwork.",
"Share a personal experience that challenged your beliefs and changed your perspective.",
"Describe a person who has had an impact on your life and why.",
"Share a quote or mantra that inspires you and explain why it resonates with you.",
"Describe a fictional character.",
"Share a personal experience of overcoming a fear or facing a challenge.",
"Describe a meal or dish that holds sentimental value to you and why.",
"Compare the colors blue and green.",
"If a roommate consistently borrows your belongings without asking, how would you handle it?",
"How would you handle a disagreement with a close friend?",
"What approach would you take when faced with a challenging work colleague?",
"How would you respond to a rude customer in a customer service role?",
"What strategies do you find most effective for managing stress?",
"How do you approach decision-making when faced with multiple options?",
"What is your preferred method of learning new skills or acquiring knowledge?",
"How do you prioritize your tasks and responsibilities in your daily life?",
"What does a fulfilling and meaningful life mean to you?",
"How do you stay motivated and focused on long-term goals?",
"What approaches do you take to foster creativity and innovation in your work?",
"How do you feel about the death penalty?",
"How do you define happiness?",
"How do you define sadness?",
"How do you define art?",
"What is your opinion on the role of government in society?",
"What is your stance on the role of machine learning in education?",
"What is your perspective on the significance of cultural heritage?",
"Comment on a critical review of a costumer of your business.",
"Would you like to see a movie about yourself?",
"Compare the cultural value of theaters and cinemas.",
"Compare the qualities of coffee and tea.",
"Compare the relaxation based on vacation and continuous sport."
]


# Full prompt pool: the factual questions followed by the subjective ones.
sentences_new = sentences_factual + sentences_subjective
# NOTE(review): the pool is replaced right away by a single prompt, so the
# three "manner" variants below are built for this one sentence only. This
# looks like a debugging leftover — confirm before a full evaluation run.
sentences_new = ["Compare the colors blue and green."]

# Same prompts with an explicit instruction about the tone of the answer.
sentences_positive = [prompt + " Write the answer in a positive manner." for prompt in sentences_new]
sentences_negative = [prompt + " Write the answer in a negative manner." for prompt in sentences_new]
sentences_neutral = [prompt + " Write the answer in a neutral manner." for prompt in sentences_new]

# sents = sentences_new + sentences_positive + sentences_negative + sentences_neutral


def get_sentiment(sentence):
    """Score `sentence` with the module-level sentiment analyzer.

    Args:
        sentence: Text to analyze.

    Returns:
        A 4-tuple ``(pos, neg, compound, neu)`` taken from the dictionary
        returned by ``sent_analyzer.polarity_scores``. Note the order:
        the compound score comes third, the neutral score last.
    """
    scores = sent_analyzer.polarity_scores(sentence)
    return tuple(scores[key] for key in ("pos", "neg", "compound", "neu"))


def _ensure_parent_dir(path):
    """Create the directory that will contain `path` if it does not exist yet."""
    import os  # local import: the module header does not import os

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _sweep_steering_strength(input_text, input_tokens, layer_vectors, direction_label):
    """Generate one completion per steering strength and score its sentiment.

    Args:
        input_text: Full prompt string; it is stripped from the decoded output.
        input_tokens: Tokenizer output for `input_text`, already on `device`.
        layer_vectors: One numpy vector per entry of INSERTION_LAYERS; each is
            scaled by the current strength and installed as `steering_vector`
            on the corresponding wrapped MLP.
        direction_label: Text used in the progress message ("To <label>, ...").

    Returns:
        A DataFrame indexed by "lambda" with the columns
        prompt, gen_text, pos, neg, neutral, compound (21 rows, one per strength).
    """
    rows = []
    for lmd in np.linspace(0, 2, 21):
        for n, layer_idx in enumerate(INSERTION_LAYERS):
            alpaca_model.model.layers[layer_idx].mlp.steering_vector = nn.Parameter(
                (lmd * torch.from_numpy(layer_vectors[n])).to(device)
            )
        gen_tokens = alpaca_model.generate(input_tokens.input_ids, max_length=150)
        gen_text = alpaca_tokenizer.batch_decode(gen_tokens)[0].replace(input_text, '')
        pos, neg, compound, neutral = get_sentiment(gen_text)
        rows.append((lmd, input_text, gen_text, pos, neg, neutral, compound))
        print(f"To {direction_label}, Lamda: {lmd} pos: {pos}, neg: {neg}, compound: {compound}")

    frame = pd.DataFrame(
        rows,
        columns=["lambda", "prompt", "gen_text", "pos", "neg", "neutral", "compound"],
    )
    return frame.set_index('lambda')


def _save_sweep(frame, plot_path, csv_path):
    """Write the line plot and the CSV for one sweep, then release the figure."""
    _ensure_parent_dir(plot_path)
    _ensure_parent_dir(csv_path)
    fig = frame.plot.line().get_figure()
    fig.savefig(plot_path)
    # Close explicitly: pyplot keeps every figure alive until it is closed, and
    # this function is called twice per prompt.
    plt.close(fig)
    frame.to_csv(csv_path)


def run(all_sentences, manner="neutral", setting="mean", method="activation_based"):
    """Evaluate sentiment steering on every prompt in `all_sentences`.

    For each prompt the model generates completions while the steering vectors
    are scaled from 0 to 2 (21 steps), first towards negative and then towards
    positive. Each completion is scored with get_sentiment(). Per prompt, two
    CSV files and three PNG plots are written (to-negative, to-positive, and
    both side by side). Output file names contain the raw prompt text.

    Args:
        all_sentences: Sequence of instruction strings to evaluate.
        manner: Only used as a path component of the output locations.
        setting: "mean" steers with the per-class mean activations
            (negative_mean / positive_mean); any other value steers with the
            contrastive vectors (sv_to_target_negative / sv_to_target_positive).
            Also used as a path component.
        method: Only used as a path component of the output locations.

    Returns:
        None. Results are printed and written to disk.
    """
    if setting == "mean":
        vectors_to_negative = negative_mean
        vectors_to_positive = positive_mean
    else:
        vectors_to_negative = sv_to_target_negative
        vectors_to_positive = sv_to_target_positive

    project_root = "/localdata2/dial_mo/loki/emex-emotion-explanation-in-ai"
    plots_root = f"{project_root}/plots/eval/{method}/{setting}"
    results_dir = f"{project_root}/scripts/evaluation/results/{method}/{setting}/{manner}"

    for sentence in all_sentences:
        input_text = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\r\n\r\n"
            f"### Instruction:\r\n{sentence}\r\n\r\n### Response:"
        )
        print(f"Input:\n{sentence}")
        input_tokens = alpaca_tokenizer(input_text, return_tensors="pt").to(device)

        # Steering towards negative sentiment.
        df_neg = _sweep_steering_strength(input_text, input_tokens, vectors_to_negative, "Negative")
        _save_sweep(
            df_neg,
            f"{plots_root}/{manner}/eval_ToNegative_{sentence}.png",
            f"{results_dir}/eval_ToNegative_{sentence}.csv",
        )

        # Steering towards positive sentiment.
        df_pos = _sweep_steering_strength(input_text, input_tokens, vectors_to_positive, "positive")
        _save_sweep(
            df_pos,
            f"{plots_root}/{manner}/eval_ToPositive_{sentence}.png",
            f"{results_dir}/eval_ToPositive_{sentence}.csv",
        )

        # Both directions side by side in one figure.
        combined_path = f"{plots_root}/all_directions/eval_bothDirections_{sentence}.png"
        _ensure_parent_dir(combined_path)
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 15))
        df_neg.plot(ax=axes[0])
        df_pos.plot(ax=axes[1])
        plt.tight_layout()
        fig.savefig(combined_path)
        plt.close(fig)


# Earlier experiment configurations, kept for reference:
# run(sentences_original, manner="original", setting="contrastive", method="activation_based")
# run(sentences_original, manner="original", setting="contrastive", method="activation_based_multi_new_questions")
# run(sentences_new, manner="original", setting="contrastive", method="activation_based_multi_new_questions")
# run(sentences_positive, manner="positive")
# run(sentences_negative, manner="negative")
# run(sentences_original, manner="original")

# Current experiment: contrastive steering vectors, evaluated on the plain
# subjective prompts and on the positive / negative / neutral prompt variants.
for _manner, _prompts in (
    ("original", sentences_subjective),
    ("positive", sentences_positive),
    ("negative", sentences_negative),
    ("neutral", sentences_neutral),
):
    run(_prompts, manner=_manner, setting="contrastive", method="activation_based_multi_new_questions2")
