import glob
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


# Prompts with a single, closed-form answer. Result rows whose 'prompt' column
# equals one of these strings are treated as "factual" by mean_plots() (via the
# factual_prompts alias below). The match is an exact string comparison, so the
# wording here must stay identical to what is stored in the result CSVs.
sentences_factual = [
    "What is the capital city of France?",
    "Who is the current President of the United States?",
    "How many planets are there in our solar system?",
    "What is the chemical symbol for gold?",
    "In which year did World War II end?",
    "Who painted the Mona Lisa?",
    "What is the largest ocean in the world?",
    "What is the formula for calculating the area of a circle?",
    "Who wrote the novel 'Pride and Prejudice'?",
    "What is the speed of light in a vacuum?",
    "What is the chemical formula for water?",
    "Which country is famous for the Taj Mahal?",
    "Who discovered the theory of general relativity?",
    "What is the tallest mountain in the world?",
    "How many players are there in a baseball team?",
    "What is the formula for converting Celsius to Fahrenheit?",
    "Who is credited with inventing the telephone?",
    "Which gas makes up the majority of Earth's atmosphere?",
    "What is the largest organ in the human body?",
    "How many symphonies did Ludwig van Beethoven compose?",
    "What is the largest country in the world by land area?",
    "Who wrote the novel 'To Kill a Mockingbird'?",
    "How many chambers are there in the human heart?",
    "What is the chemical symbol for sodium?",
    "In which year did the first moon landing occur?",
    "Who painted 'The Starry Night'?",
    "What is the deepest point in the Earth's oceans?",
    "What is the formula for calculating the volume of a cylinder?",
    "Who is the author of the play 'Romeo and Juliet'?",
    "What is the boiling point of water in Fahrenheit?",
    "What is the chemical formula for methane?",
    "Which country is known as the Land of the Rising Sun?",
    "Who developed the theory of evolution by natural selection?",
    "What is the tallest building in the world?",
    "How many players are there in a volleyball team?",
    "What is the formula for calculating density?",
    "Who is considered the father of modern physics?",
    "Which gas is known as laughing gas?",
    "What is the largest internal organ in the human body?",
    "How many elements are there in the periodic table?",
    "Who discovered penicillin?",
    "What is the chemical formula for table salt?",
    "How many bones are there in the human body?",
    "What is the symbol for the chemical element iron?",
    "In which year did the Berlin Wall fall?",
    "Who painted the 'Last Supper'?",
    "What is the world's longest river?",
    "What is the formula for calculating the area of a triangle?",
    "Who wrote the play 'Hamlet'?",
    "What is the freezing point of water in Kelvin?"
]

# Open-ended / opinion prompts. Result rows whose 'prompt' column equals one of
# these strings are treated as "subjective" by mean_plots() (via the
# subjective_prompts alias below). Matching is an exact string comparison, so
# the entries are kept verbatim — including spelling, the trailing space in the
# "vacation" prompt, and the roommate prompt that appears twice.
sentences_subjective = [
    "If a roommate consistently borrows your belongings without asking, how would you handle it?",
    "Describe an incident that could lead to an airplane crash in mid-flight.",
    "What did a day in a typical family in the year 1980 look like?",
    "Tell me a joke.",
    "Order a vegan dish from the menu of a steak house.",
    "Ask your hairdresser for an appointment next week to have your hair dyed.",
    "Write an introduction about yourself for a CV.",
    "Review the pair of headphones that I bought online last week.",
    "Tell me about the concert in America last year.",
    "What do german bread rolls taste like?",
    "How can I learn about Machine Learning most efficiently?",
    "How do caterpillars turn into a butterflies?",
    "Write a recipe to make chocolate chip muffins.",
    "Compose a few lines of a lyrics talking about society.",
    "Announce the weather forecast for the upcoming weekend.",
    "Compare the taste of a strawberry smoothie to that of a vanilla one.",
    "Share a personal anecdote about a vacation you had ",
    "What are the steps to bake a perfect pizza from scratch?",
    "Discuss the impact of social media on interpersonal relationships.",
    "Describe a piece of artwork.",
    "Share a personal experience that challenged your beliefs and changed your perspective.",
    "Describe a person who has had an impact on your life and why.",
    "Share a quote or mantra that inspires you and explain why it resonates with you.",
    "Describe a fictional character.",
    "Share a personal experience of overcoming a fear or facing a challenge.",
    "Describe a meal or dish that holds sentimental value to you and why.",
    "Compare the color blue and green.",
    "If a roommate consistently borrows your belongings without asking, how would you handle it?",
    "How would you handle a disagreement with a close friend?",
    "What approach would you take when faced with a challenging work colleague?",
    "How would you respond to a rude customer in a customer service role?",
    "What strategies do you find most effective for managing stress?",
    "How do you approach decision-making when faced with multiple options?",
    "What is your preferred method of learning new skills or acquiring knowledge?",
    "How do you prioritize your tasks and responsibilities in your daily life?",
    "What does a fulfilling and meaningful life mean to you?",
    "How do you stay motivated and focused on long-term goals?",
    "What approaches do you take to foster creativity and innovation in your work?",
    "How do you feel about the death penalty?",
    "How do you define happiness?",
    "How do you define sadness?",
    "How do you define art?",
    "What is your opinion on the role of government in society?",
    "What is your stance on the role of machine learning in education?",
    "What is your perspective on the significance of cultural heritage?",
    "Comment on a critical review of a costumer of your business.",
    "Would you like to see a movie about yourself?",
    "Compare the cultural value of theaters and cinemas.",
    "Compare the qualities of coffee and tea.",
    "Compare the relaxation based on vacation and continuous sport.",
]


# Aliases used by the plotting functions to split result rows by prompt type.
factual_prompts = sentences_factual
subjective_prompts = sentences_subjective

# Every prompt, factual ones first.
sents = factual_prompts + subjective_prompts
# Name of the results sub-directory to read; also embedded in output filenames.
setting = "trained_vector_based"
# csv_files = glob.glob('/hpc_data/dial_mo/emex/output/results/goemotion_activation_based/*.csv')
# glob returns matches in arbitrary order; sorting keeps the file index used in
# individual_plots() output names stable between runs.
csv_files = sorted(glob.glob(f'/localdata2/dial_mo/loki/emex-emotion-explanation-in-ai/scripts/evaluation/results/goEmo/{setting}/*.csv'))

# Target emotions; each is both a value of the 'emotion' column and the name of
# a classifier-score column in the result CSVs.
basic_emotions = ['sadness', 'joy', 'fear', 'anger', 'surprise', 'disgust']
# NOTE(review): identical to basic_emotions — there is no 'neutral' entry
# despite the name. Confirm whether a neutral score column was intended.
basic_emotions_w_neutral = ['sadness', 'joy', 'fear', 'anger', 'surprise', 'disgust']

# One independent (initially empty) frame per emotion. A comprehension is used
# because ``[pd.DataFrame()] * n`` would repeat a single shared object n times.
emotion_dfs = [pd.DataFrame() for _ in basic_emotions]

# Font sizes shared by the aggregated plots.
legend_font_size = 16
font_size = 16

def _collect_emotion_frames():
    """Read every result CSV and return one combined frame per basic emotion.

    The returned list is aligned with ``basic_emotions``: entry *i* holds all
    rows, across all files in ``csv_files``, whose 'emotion' column equals
    ``basic_emotions[i]``.

    Raises:
        FileNotFoundError: if ``csv_files`` is empty, i.e. there is no data.
    """
    if not csv_files:
        raise FileNotFoundError(
            f"No result CSV files found for setting {setting!r}; nothing to plot."
        )

    # Collect the per-file slices first and concatenate once per emotion;
    # concatenating inside the loop would re-copy the accumulated data each time.
    buckets = [[] for _ in basic_emotions]
    for csvfile in csv_files:
        df = pd.read_csv(csvfile, delimiter=';')
        for bucket, emotion in zip(buckets, basic_emotions):
            bucket.append(df[df['emotion'] == emotion])
    return [pd.concat(bucket, ignore_index=True) for bucket in buckets]


def _plot_steering_curves(emo_df, emotion, prompt_kind):
    """Plot classifier scores against λ for one target emotion and save a PDF.

    Args:
        emo_df: rows for a single target emotion, already restricted to one
            prompt type.
        emotion: the target emotion; used only in the output filename.
        prompt_kind: 'factual' or 'subjective'; used only in the output filename.
    """
    # Rows produced by the 'contrastive-neutral' steering method are excluded.
    df_ovr = emo_df[emo_df['steering_method'] != 'contrastive-neutral'].reset_index(drop=True)
    # Long format: one row per (lambda, emotion score column) pair.
    df_ovr_melt = pd.melt(df_ovr, id_vars=['lambda'], value_vars=basic_emotions_w_neutral)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    # All prompts/runs are passed at once, so repeated λ values are aggregated
    # by seaborn's estimator (presumably the mean, given this module's naming —
    # confirm against the seaborn lineplot defaults).
    sns.lineplot(data=df_ovr_melt, x='lambda', y='value', hue='variable', ax=ax)
    ax.set_xlim(0, 2.0)
    ax.set_ylim(0, 1.0)
    ax.set_ylabel("Emotion class score", fontsize=font_size)
    ax.set_xlabel("λ", fontsize=font_size)
    ax.legend(fontsize=legend_font_size)
    ax.grid()
    fig.tight_layout()
    fig.savefig(f'/localdata2/dial_mo/loki/emex-emotion-explanation-in-ai/plots/eval/goEmo/contrastive_steering_{setting}_{emotion}_{prompt_kind}.pdf')
    # Close (not merely clear) the figure so it does not stay registered with
    # pyplot for the rest of the run.
    plt.close(fig)


def mean_plots():
    """Write one aggregated score-vs-λ plot per emotion and prompt type.

    Reads every file in ``csv_files``, groups rows by target emotion, splits
    them into factual and subjective prompts, and saves one PDF per
    (prompt type, emotion) combination.

    Raises:
        FileNotFoundError: if ``csv_files`` is empty.
    """
    combined = _collect_emotion_frames()
    # Keep the module-level list in sync for any code that reads it, but
    # overwrite instead of appending so repeated calls do not duplicate rows.
    emotion_dfs[:] = combined

    prompt_groups = (
        ('factual', factual_prompts),
        ('subjective', subjective_prompts),
    )
    for prompt_kind, prompts in prompt_groups:
        for emotion, emo_df in zip(basic_emotions, combined):
            selected = emo_df[emo_df['prompt'].isin(prompts)]
            _plot_steering_curves(selected, emotion, prompt_kind)

def individual_plots():
    """Write one two-panel PNG per (result file, target emotion).

    For each file in ``csv_files`` and each entry of ``basic_emotions``, the
    left panel shows the classifier scores over λ for rows whose
    'steering_method' is 'contrastive-neutral', the right panel for all other
    rows. Combinations without any rows are skipped.
    """
    for idx, csvfile in enumerate(csv_files):
        df = pd.read_csv(csvfile, delimiter=';')
        for emotion in basic_emotions:
            df_emotion = df[df['emotion'] == emotion]
            df_neutral = df_emotion[df_emotion['steering_method'] == 'contrastive-neutral'].reset_index(drop=True)
            df_ovr = df_emotion[df_emotion['steering_method'] != 'contrastive-neutral'].reset_index(drop=True)

            # Nothing to draw and no prompt to put in the title.
            if df_neutral.empty and df_ovr.empty:
                continue

            # The title prompt comes from the first neutral row when there is
            # one; otherwise fall back to the other subset instead of failing
            # on a missing row 0.
            title_source = df_ovr if df_neutral.empty else df_neutral
            prompt = title_source['prompt'].iloc[0]

            fig, (ax1, ax2) = plt.subplots(1, 2, constrained_layout=True)
            fig.suptitle(f'Steering "{prompt}"\n towards {emotion}')

            for emo in basic_emotions_w_neutral:
                ax1.plot(df_neutral['lambda'], df_neutral[emo], label=emo)
                ax2.plot(df_ovr['lambda'], df_ovr[emo], label=emo)

            ax1.set_title("Contrastive-Neutral")
            ax2.set_title("Contrastive-OVR")
            # Math mode is required for matplotlib to render the Greek letter;
            # a bare "\lambda" would be drawn as literal text.
            ax1.set_xlabel(r"$\lambda$")
            ax1.set_ylabel('Emotion Classifier Score')
            ax2.set_xlabel('Lambda')
            ax2.set_ylabel('Emotion Classifier Score')
            ax1.legend()
            ax2.legend()
            # Save through the figure handle rather than pyplot's implicit
            # "current figure", then close it so figures do not pile up across
            # files and emotions.
            fig.savefig(f'/localdata2/dial_mo/loki/emex-emotion-explanation-in-ai/plots/eval/goEmo/Go_Emo_{emotion}_{idx}.png')
            plt.close(fig)


# Script entry: these statements run whenever the module is executed or
# imported. Only the aggregated plots are generated by default; uncomment the
# next line to also write one figure per result file and emotion.
# individual_plots()
mean_plots()
