import glob
import os
import pickle

def load_steering_vectors(STEERING_VECTOR_PATH = "/localdata1/EmEx/model_weights/alpaca_7b/old_steering_vectors/20-04_24-04_trainings", *, layer_tag="15, 16, 17"):
    """Load pickled steering vectors trained for one set of insertion layers.

    Every entry directly inside ``STEERING_VECTOR_PATH`` whose *file name*
    contains ``layer_tag`` is unpickled. Each file is expected to hold a dict
    mapping a target sentence to a sequence whose items 0-5 are, in order:
    steering vector, activations, loss, epoch, generated text and label. This
    layout is inferred from the indexing below. The activations are not used.

    Args:
        STEERING_VECTOR_PATH: Directory containing the pickled training results.
        layer_tag: Substring identifying the insertion layers in a file name.
            Defaults to ``"15, 16, 17"``, the layer set this loader was
            originally hard-coded for.

    Returns:
        A list of ``[steering_vector, target_sentence, epoch, loss, gen_text,
        label]`` lists, ordered by file path and then by each file's dict
        iteration order.

    Side effects:
        Prints how many loaded entries have a truthy ("positive") and a falsy
        ("negative") label.

    Security:
        ``pickle.load`` can execute arbitrary code. Only point this at
        directories whose contents you trust.
    """
    # Escape the directory so glob metacharacters in it (e.g. "[") are matched
    # literally. Sort so the result does not depend on directory listing order.
    pattern = os.path.join(glob.escape(STEERING_VECTOR_PATH), "*")
    candidate_files = sorted(glob.glob(pattern))

    # Match the layer tag against the file name only; matching the full path
    # would select every file if the directory name itself contained the tag.
    selected_files = [
        path for path in candidate_files if layer_tag in os.path.basename(path)
    ]

    positive = 0
    negative = 0
    steering_vectors = []

    for file in selected_files:
        with open(file, 'rb') as f:
            results = pickle.load(f)

        for target_sentence, value in results.items():
            # value[1] holds activations, which are not returned.
            steering_vector = value[0]
            loss = value[2]
            epoch = value[3]
            gen_text = value[4]
            label = value[5]
            steering_vectors.append([steering_vector, target_sentence, epoch, loss, gen_text, label])

            if label:
                positive += 1
            else:
                negative += 1

    print(f"Number of positive samples: {positive}")
    print(f"Number of negative samples: {negative}")

    return steering_vectors