import pandas as pd
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA, FactorAnalysis, FastICA
import os
import pickle
from utils import dataset_loader as dsl
import gc
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from datasets import load_dataset


# The GoEmotions dataset has no "happiness" label, so "joy" is used in its place (for now).
# base_emotions = ["sadness", "happiness", "fear", "anger", "surprise", "disgust"]
# base_emotions = ["sadness", "joy", "fear", "anger", "surprise", "disgust"]
# Emotions of interest for this run (see the alternatives commented out above).
base_emotions = ["neutral"]

# All GoEmotions label names. The selection code below assumes position i in
# this list corresponds to integer label id i in a sample's 'labels' field.
dataset_emotions = [
    'admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring',
    'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval',
    'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief',
    'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization',
    'relief', 'remorse', 'sadness', 'surprise', 'neutral',
]

# Integer label ids of the base emotions (raises ValueError for an unknown name).
base_emotion_orig_index = list(map(dataset_emotions.index, base_emotions))

# GoEmotions train split, as previously printed:
#     Dataset({features: ['text', 'labels', 'id'], num_rows: 43410})
go_emo_train = load_dataset('go_emotions', split='train')

# Every training row that carries at least one base-emotion label, paired with
# its row index. 'labels' is presumably a list of integer label ids (it is
# compared against base_emotion_orig_index) - TODO confirm.
selected_training_samples = [
    [row_idx, row]
    for row_idx, row in enumerate(go_emo_train)
    if any(label_id in row['labels'] for label_id in base_emotion_orig_index)
]

# The subset whose single label is a base emotion. Each entry is a fresh
# [row_index, sample] list (sharing the sample dict), so later code can append
# to these entries without touching selected_training_samples.
selected_training_samples_single_emotions = [
    [row_idx, row]
    for row_idx, row in selected_training_samples
    if len(row['labels']) == 1
]


# Hidden activations are sharded across several pickle files; this path
# template and shard count were previously hard-coded inline.
ACTIVATIONS_PATH_TEMPLATE = '/localdata1/EmEx/activations/GoEmo_hidden_activations_{}.pkl'
NUM_ACTIVATION_SHARDS = 8

# For every single-label sample ([idx, sample]), append the activation (acti[2])
# of the first record whose contents contain the sample's text. Samples
# that already have an activation (length 3) are left untouched.
for i in range(NUM_ACTIVATION_SHARDS):
    # Drop the previous shard before unpickling the next one so two shards are
    # never resident at the same time.
    actis = None
    gc.collect()
    with open(ACTIVATIONS_PATH_TEMPLATE.format(i), 'rb') as f:
        print(f"processing {i}th pickle")
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # files produced by this project.
        actis = pickle.load(f)

    for sample in selected_training_samples_single_emotions:
        # Already matched in an earlier shard; the original code would scan
        # the whole shard and then refuse to append anyway.
        if len(sample) != 2:
            continue
        text = sample[1]['text']
        # NOTE(review): each acti is presumably a sequence whose elements
        # include the text and whose index 2 is the activation - TODO confirm.
        # If the text's position were known, a {text: activation} dict per
        # shard would replace this linear scan.
        for acti in actis:
            if text in acti:
                sample.append(acti[2])
                break  # only the first match is kept

pass  # no-op; presumably kept as a debugger breakpoint anchor
