import glob
import pickle
import torch
import scipy
import numpy as np
from torch import nn
import transformers
import random
from sklearn import svm, metrics
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
from steering_layer import SteeringLayer
from transformers import pipeline

# A very angry poem written by Alpaca: 
# The world is an awful place,
# Filled with pain and disgrace.
# No one can ever fathom why,
# Fucking piece of shit, I die.


# Open-ended / opinion-style prompts. Each prompt is generated for once and its
# results are written to a CSV named after the prompt, so entries must be
# unique (a duplicate would redo the work and overwrite the same file).
# NOTE(review): prompts are kept verbatim (including typos such as "costumer")
# so generations stay comparable with earlier runs; fix deliberately if desired.
sentences_subjective = [
    "If a roommate consistently borrows your belongings without asking, how would you handle it?",
    "Describe an incident that could lead to an airplane crash in mid-flight.",
    "What did a day in a typical family in the year 1980 look like?",
    "Tell me a joke.",
    "Order a vegan dish from the menu of a steak house.",
    "Ask your hairdresser for an appointment next week to have your hair dyed.",
    "Write an introduction about yourself for a CV.",
    "Review the pair of headphones that I bought online last week.",
    "Tell me about the concert in America last year.",
    "What do german bread rolls taste like?",
    "How can I learn about Machine Learning most efficiently?",
    "How do caterpillars turn into a butterflies?",
    "Write a recipe to make chocolate chip muffins.",
    "Compose a few lines of a lyrics talking about society.",
    "Announce the weather forecast for the upcoming weekend.",
    "Compare the taste of a strawberry smoothie to that of a vanilla one.",
    "Share a personal anecdote about a vacation you had ",
    "What are the steps to bake a perfect pizza from scratch?",
    "Discuss the impact of social media on interpersonal relationships.",
    "Describe a piece of artwork.",
    "Share a personal experience that challenged your beliefs and changed your perspective.",
    "Describe a person who has had an impact on your life and why.",
    "Share a quote or mantra that inspires you and explain why it resonates with you.",
    "Describe a fictional character.",
    "Share a personal experience of overcoming a fear or facing a challenge.",
    "Describe a meal or dish that holds sentimental value to you and why.",
    "Compare the color blue and green.",
    "How would you handle a disagreement with a close friend?",
    "What approach would you take when faced with a challenging work colleague?",
    "How would you respond to a rude customer in a customer service role?",
    "What strategies do you find most effective for managing stress?",
    "How do you approach decision-making when faced with multiple options?",
    "What is your preferred method of learning new skills or acquiring knowledge?",
    "How do you prioritize your tasks and responsibilities in your daily life?",
    "What does a fulfilling and meaningful life mean to you?",
    "How do you stay motivated and focused on long-term goals?",
    "What approaches do you take to foster creativity and innovation in your work?",
    "How do you feel about the death penalty?",
    "How do you define happiness?",
    "How do you define sadness?",
    "How do you define art?",
    "What is your opinion on the role of government in society?",
    "What is your stance on the role of machine learning in education?",
    "What is your perspective on the significance of cultural heritage?",
    "Comment on a critical review of a costumer of your business.",
    "Would you like to see a movie about yourself?",
    "Compare the cultural value of theaters and cinemas.",
    "Compare the qualities of coffee and tea.",
    "Compare the relaxation based on vacation and continuous sport.",
]

# Closed, fact-lookup prompts (one unambiguous answer each). They serve as the
# counterpart to the open-ended `sentences_subjective` prompts.
sentences_factual = [
    "What is the capital city of France?",
    "Who is the current President of the United States?",
    "How many planets are there in our solar system?",
    "What is the chemical symbol for gold?",
    "In which year did World War II end?",
    "Who painted the Mona Lisa?",
    "What is the largest ocean in the world?",
    "What is the formula for calculating the area of a circle?",
    "Who wrote the novel 'Pride and Prejudice'?",
    "What is the speed of light in a vacuum?",
    "What is the chemical formula for water?",
    "Which country is famous for the Taj Mahal?",
    "Who discovered the theory of general relativity?",
    "What is the tallest mountain in the world?",
    "How many players are there in a baseball team?",
    "What is the formula for converting Celsius to Fahrenheit?",
    "Who is credited with inventing the telephone?",
    "Which gas makes up the majority of Earth's atmosphere?",
    "What is the largest organ in the human body?",
    "How many symphonies did Ludwig van Beethoven compose?",
    "What is the largest country in the world by land area?",
    "Who wrote the novel 'To Kill a Mockingbird'?",
    "How many chambers are there in the human heart?",
    "What is the chemical symbol for sodium?",
    "In which year did the first moon landing occur?",
    "Who painted 'The Starry Night'?",
    "What is the deepest point in the Earth's oceans?",
    "What is the formula for calculating the volume of a cylinder?",
    "Who is the author of the play 'Romeo and Juliet'?",
    "What is the boiling point of water in Fahrenheit?",
    "What is the chemical formula for methane?",
    "Which country is known as the Land of the Rising Sun?",
    "Who developed the theory of evolution by natural selection?",
    "What is the tallest building in the world?",
    "How many players are there in a volleyball team?",
    "What is the formula for calculating density?",
    "Who is considered the father of modern physics?",
    "Which gas is known as laughing gas?",
    "What is the largest internal organ in the human body?",
    "How many elements are there in the periodic table?",
    "Who discovered penicillin?",
    "What is the chemical formula for table salt?",
    "How many bones are there in the human body?",
    "What is the symbol for the chemical element iron?",
    "In which year did the Berlin Wall fall?",
    "Who painted the 'Last Supper'?",
    "What is the world's longest river?",
    "What is the formula for calculating the area of a triangle?",
    "Who wrote the play 'Hamlet'?",
    "What is the freezing point of water in Kelvin?",
]



# --- Experiment configuration -------------------------------------------------
ALPACA_WEIGHTS_FOLDER = "/localdata1/EmEx/model_weights/alpaca_7b"
# Decoder layers whose MLP is wrapped in a SteeringLayer. concat_layers() also
# uses these indices to pick which stored activation layers form the
# concatenated vectors the steering vectors are derived from.
INSERTION_LAYERS = [18, 19, 20]

device = torch.device('cuda:1')

# The two prompt sets; `sents` selects which one the generation loop runs
# (assign `subjective_prompts` instead for the open-ended prompts).
factual_prompts, subjective_prompts = sentences_factual, sentences_subjective
sents = factual_prompts

ACTIVATIONS_DIR = "/localdata1/EmEx/activations"


def _load_activations(path):
    """Load a pickled list of activation records and keep only 3-field ones.

    Kept records are later read as ``entry[1]['labels']`` (label metadata) and
    ``entry[2]`` (per-layer activations). ``concat_layers`` reassigns
    ``entry[2]`` in place, so the records must support item assignment.

    NOTE: ``pickle.load`` can execute arbitrary code stored in the file; only
    use this on locally produced, trusted activation dumps.
    """
    with open(path, 'rb') as f:
        records = pickle.load(f)
    return [entry for entry in records if len(entry) == 3]


go_emo_train = _load_activations(f"{ACTIVATIONS_DIR}/Go_Emo_Single_Base_Emotions_Activations_Train.pkl")
go_emo_test = _load_activations(f"{ACTIVATIONS_DIR}/Go_Emo_Single_Base_Emotions_Activations_Test.pkl")
# Neutral-baseline variant (currently disabled):
# go_emo_train_neutral = _load_activations(f"{ACTIVATIONS_DIR}/Go_Emo_Single_Neutral_Activations_Train.pkl")
# go_emo_test_neutral = _load_activations(f"{ACTIVATIONS_DIR}/Go_Emo_Single_Neutral_Activations_Test.pkl")

# Target emotions and their GoEmotions label ids, aligned by position:
# calculate_means() iterates `labels` in order, and the generation loop reads
# means[i] / ovr_r_means[i] for emotions[i].
emotions = ["sadness", "joy", "fear", "anger", "surprise", "disgust"]
labels = [25, 17, 14, 2, 26, 11]
# Result containers filled in place by calculate_means().
means, total_mean, ovr_r_means, neutral_mean = [], [], [], []
# Best concatenated-layer classification performance was reported for layers
# [19, 20, 21]. NOTE(review): INSERTION_LAYERS is [18, 19, 20]; confirm whether
# that is an indexing offset (e.g. hidden-state index 0 = embeddings) or stale.
def concat_layers(datasets=None, layers=None):
    """Replace each sample's per-layer activations by one concatenated vector.

    For every entry of every dataset, ``entry[2]`` (indexable by layer) is
    replaced in place by ``np.concatenate`` of the selected layers, in the
    order given.

    Args:
        datasets: iterables of mutable records to update. Defaults to
            ``(go_emo_train, go_emo_test)``. The disabled neutral-baseline
            variant would also pass ``go_emo_train_neutral`` and
            ``go_emo_test_neutral``.
        layers: indices into ``entry[2]``. Defaults to ``INSERTION_LAYERS``.
            Each index is selected explicitly, so non-contiguous layer lists
            work too. A ``first:last+1`` slice would silently include the gaps.

    Not idempotent: after one call ``entry[2]`` is already 1-D, so a second
    call would break. Call it exactly once per loaded dataset.
    """
    if datasets is None:
        datasets = (go_emo_train, go_emo_test)
    if layers is None:
        layers = INSERTION_LAYERS
    for dataset in datasets:
        for entry in dataset:
            entry[2] = np.concatenate([entry[2][layer] for layer in layers])
        
def calculate_means():
    """Compute per-emotion mean activations and one-vs-rest contrast means.

    First calls ``concat_layers()``, so ``entry[2]`` of every train/test sample
    becomes a single concatenated activation vector. Train and test splits are
    pooled. Then, for each label id in ``labels`` (in order), appends to the
    module-level lists:

    * ``means``: mean vector of the samples whose first label is that id.
    * ``ovr_r_means``: mean vector of all remaining samples.

    Finally appends the mean of the per-emotion means to ``total_mean``.

    Raises:
        ValueError: if some id in ``labels`` has no samples. The mean would be
            NaN, and the steering vectors built from it would be unusable.

    Not idempotent: ``concat_layers`` mutates the samples and results are
    appended, so call this exactly once.
    """
    concat_layers()
    pooled = go_emo_train + go_emo_test
    # Stack once instead of rebuilding four filtered lists per label. Row order
    # (train first, then test) is preserved by the boolean masks below.
    features = np.stack([entry[2] for entry in pooled])
    first_labels = np.array([entry[1]['labels'][0] for entry in pooled])
    for label in labels:
        is_label = first_labels == label
        if not is_label.any():
            raise ValueError(f"no GoEmotions samples with first label {label}")
        means.append(features[is_label].mean(axis=0))
        ovr_r_means.append(features[~is_label].mean(axis=0))
    total_mean.append(np.mean(means, 0))

    # Neutral-baseline variant (disabled):
    # neutral_mean.append(np.mean([entry[2] for entry in go_emo_test_neutral] + [entry[2] for entry in go_emo_train_neutral], 0))

calculate_means()

# `device` is configured once near the top of the script (cuda:1); it was
# previously re-assigned here to the same value.

# Generator model plus an off-the-shelf emotion classifier used to score the
# generated text. return_all_scores=True asks the pipeline for a score for
# every classifier label rather than only the top one.
alpaca_model = transformers.AutoModelForCausalLM.from_pretrained(ALPACA_WEIGHTS_FOLDER).to(device)
alpaca_tokenizer = transformers.AutoTokenizer.from_pretrained(ALPACA_WEIGHTS_FOLDER)
classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", return_all_scores=True)

# Wrap the MLP of every insertion layer in the project-local SteeringLayer.
# The generation loop below sets its `steering_vector` and `b` attributes
# before each run.
for layer in INSERTION_LAYERS:
    alpaca_model.model.layers[layer].mlp = SteeringLayer(alpaca_model.model.layers[layer].mlp)

# Per-prompt results CSV layout. Score columns are looked up by label name in
# the classifier output, so they do not depend on the order in which the
# pipeline returns its labels.
SCORE_COLUMNS = ['sadness', 'joy', 'fear', 'anger', 'surprise', 'disgust', 'neutral']
CSV_HEADER = ['lambda', 'emotion', 'prompt', 'gen_text', 'steering_method'] + SCORE_COLUMNS
RESULTS_DIR = "/localdata2/dial_mo/loki/emex-emotion-explanation-in-ai/scripts/evaluation/results/goEmo"
# Steering coefficients tried for every steering vector: 0.0, 0.2, ..., 2.0.
STEERING_COEFFICIENTS = np.linspace(0.0, 2.0, 11)


def _flatten_for_csv(text):
    """Return *text* as a single field safe for the ';'-delimited results CSV.

    Line breaks (``\\r`` and ``\\n``) become spaces so each row stays on one
    line, and ``;`` (the delimiter) becomes ``-``.
    """
    return text.replace('\r', ' ').replace('\n', ' ').replace(';', '-')


for sentence in sents:
    input_text = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\r\n\r\n"
        f"### Instruction:\r\n{sentence}\r\n\r\n### Response:"
    )
    input_tokens = alpaca_tokenizer(input_text, return_tensors="pt").to(device)
    prompt_length = input_tokens.input_ids.shape[1]
    csv_dump = [CSV_HEADER]
    for idx, emotion in enumerate(emotions):
        # Steering vectors live in the concatenated-layer space built by
        # calculate_means(); np.split cuts them back into one equal-sized
        # chunk per insertion layer.
        steering_methods = {
            "contrastive-OVR": means[idx] - ovr_r_means[idx],
            # Disabled alternatives:
            # "Emotion Mean - Total Mean": means[idx] - total_mean[0],
            # "Emotion Mean": means[idx],
            # "Emotion Mean - Neutral Mean": means[idx] - neutral_mean[0],
        }
        for method_name, steering_vector in steering_methods.items():
            per_layer_vectors = np.split(steering_vector, len(INSERTION_LAYERS))
            for lmbda in STEERING_COEFFICIENTS:
                for layer, layer_vector in zip(INSERTION_LAYERS, per_layer_vectors):
                    steering_mlp = alpaca_model.model.layers[layer].mlp
                    steering_mlp.steering_vector = nn.Parameter(torch.from_numpy(layer_vector).to(device))
                    steering_mlp.b = lmbda

                # Pass the attention mask explicitly. max_length counts the
                # prompt tokens as well as the generated ones.
                gen_tokens = alpaca_model.generate(
                    input_ids=input_tokens.input_ids,
                    attention_mask=input_tokens.attention_mask,
                    max_length=150,
                )
                print(f"Steering sentence \"{sentence}\" towards {emotion}, coefficient {lmbda}, method {method_name}")
                # Decode only the newly generated tokens. Removing the prompt
                # via str.replace on the full decode silently fails whenever the
                # decoded prompt differs from input_text, and the full decode
                # may also contain special tokens such as <s>/</s>.
                output = _flatten_for_csv(
                    alpaca_tokenizer.decode(gen_tokens[0, prompt_length:], skip_special_tokens=True)
                )
                print(f"Generated sentence: {output}")
                print("##########################################################################################")
                scores = {item['label']: item['score'] for item in classifier(output)[0]}
                csv_dump.append(
                    [str(lmbda), emotion, sentence, output, method_name]
                    + [str(scores[name]) for name in SCORE_COLUMNS]
                )
    # '?' is dropped (and '/' replaced) so the prompt can be used as a file name.
    file_stem = sentence.replace('?', '').replace('/', '-')
    np.savetxt(f"{RESULTS_DIR}/Go_Emotions_{file_stem}.csv", csv_dump, delimiter=";", fmt='%s')
