import torch
import transformers
from utils import dataset_loader as dsl
from tqdm import tqdm
import pickle

# Directory that receives the pickled activation files.
PATH_TO_ACTIVATION_STORAGE = "/localdata1/EmEx/activations/"
# Local directory from which both the Alpaca model weights and its tokenizer are loaded.
ALPACA_WEIGHTS_FOLDER = "/localdata1/EmEx/model_weights/alpaca_7b"

# select which device you want to use
device = torch.device('cuda:1')
# Load the model and move it to the selected device once (a second .to(device) is redundant).
alpaca_model = transformers.AutoModelForCausalLM.from_pretrained(ALPACA_WEIGHTS_FOLDER).to(device)
alpaca_tokenizer = transformers.AutoTokenizer.from_pretrained(ALPACA_WEIGHTS_FOLDER)

def _dump_activations(name, j, actis):
    """Pickle one chunk of activation records as ``{name}_activations_{j}.pkl``."""
    with open(f'{PATH_TO_ACTIVATION_STORAGE}/{name}_activations_{j}.pkl', 'wb') as f:
        pickle.dump(actis, f)


def process_dataset(name, batch_size=10000, max_tokens=300):
    """Run every sample of a dataset through Alpaca and pickle its activations.

    For each sample, the hidden state of the *last* token is taken from every
    entry of the model's ``hidden_states`` output. It is stored as a record
    ``[index, sentence, hidden_states, label]``. Records are written to
    ``PATH_TO_ACTIVATION_STORAGE`` in files ``{name}_activations_{j}.pkl``,
    each holding at most ``batch_size`` records.

    Args:
        name: dataset selector, matched by substring ('shake', 'goemo', 'yelp').
        batch_size: number of records per output pickle file.
        max_tokens: samples with more tokens than this are skipped, otherwise
            the GPU can run out of memory.

    Raises:
        SystemExit: with code -1 if ``name`` matches no known dataset.
    """
    print(f"Saving activations for {name} dataset!")
    if 'shake' in name:
        df = dsl.load_shakespeare()
    elif 'goemo' in name:
        df = dsl.load_goemo()
    elif 'yelp' in name:
        df = dsl.load_yelp()
    else:
        print(f"Didnt recognize {name}!")
        raise SystemExit(-1)

    # GoEmotions keeps text/labels in 'text'/'labels'; shakespeare and yelp use 'sample'/'sentiment'.
    if 'goemo' in name:
        text_column, label_column = 'text', 'labels'
    else:
        text_column, label_column = 'sample', 'sentiment'

    actis = []
    j = 0  # index of the next output file
    for index, row in tqdm(df.iterrows(), total=len(df)):
        # removing newlines from samples.
        sentence = row[text_column].replace('\n', '')
        input_tokens = alpaca_tokenizer(sentence, return_tensors="pt")

        # input_ids is (batch=1, seq_len): compare the sequence length, not len(), which is the batch size.
        if input_tokens.input_ids.shape[-1] > max_tokens:
            continue

        # Inference only: no autograd graph is needed, which saves GPU memory.
        with torch.no_grad():
            outputs = alpaca_model(input_tokens.input_ids.to(device),
                                   output_hidden_states=True, return_dict=True)

        # For every entry of hidden_states, keep the activation of the last token of the single batch item.
        hidden_states = [layer[0][-1].cpu().numpy() for layer in outputs['hidden_states']]
        actis.append([index, sentence, hidden_states, row[label_column]])

        # save activations in chunks of batch_size records
        if len(actis) == batch_size:
            _dump_activations(name, j, actis)
            actis = []
            j += 1

    # Save the remainder when the record count is not a multiple of batch_size.
    if actis:
        _dump_activations(name, j, actis)
    

# Choose the dataset to process: 'goemo', 'shakespeare' or 'yelp'.
DATASET_TO_PROCESS = 'yelp'
process_dataset(DATASET_TO_PROCESS)
