#import config
import json
import os
import sys

import numpy as np
# Directory that the helpers below read processed persona files from and
# write the re-split label files into.
processed_persona_path = 'processed_persona'
# Occurrence threshold constant; not referenced by the code visible in this
# file (get_few_sample takes its own ``threshold`` parameter, also defaulting
# to 5). NOTE(review): confirm whether anything else imports this.
threshold = 5

def get_processed_persona(kind, require_label=True, base_path=None):
    """Load one processed persona split from a JSON file.

    Args:
        kind: split name used to build the file name (this file calls it
            with 'train', 'dev' and 'test').
        require_label: if True, read ``<kind>_merged_shuffle_8clusters.txt``;
            otherwise read ``<kind>.txt``.
        base_path: directory holding the files; defaults to the module-level
            ``processed_persona_path``.

    Returns:
        The decoded JSON content. The script in this file treats it as a
        list of sample dicts that each carry a ``'labels'`` list.
    """
    if base_path is None:
        base_path = processed_persona_path
    if require_label:
        file_name = '%s_merged_shuffle_8clusters.txt' % kind
    else:
        file_name = '%s.txt' % kind
    # Decode explicitly as UTF-8 (the encoding save_data writes with) rather
    # than relying on the platform's locale default.
    with open(os.path.join(base_path, file_name), 'r', encoding='utf-8') as f:
        return json.load(f)

def check_repeat(list_a, list_b):
    """Return True if any element of ``list_a`` also appears in ``list_b``.

    Elements are compared with ``==`` (through ``in``), so unhashable items
    such as sample dicts are supported.

    Args:
        list_a: first collection of items.
        list_b: second collection of items.

    Returns:
        True when the two collections share at least one equal element,
        otherwise False.
    """
    # Equality is symmetric, so scanning one direction is sufficient; a
    # second pass over ``list_b`` could never find a match this one missed.
    return any(a in list_b for a in list_a)

def save_data(data, kind='train', base_path=None):
    """Write one re-split dataset to ``label_split_8cluster/<kind>_merged.txt``.

    Args:
        data: JSON-serialisable object to write (in this file, a list of
            sample dicts).
        kind: which split to write: 'train', 'dev' or 'test'. 'dev' is used
            for the zero-shot setting.
        base_path: directory containing the ``label_split_8cluster`` folder;
            defaults to the module-level ``processed_persona_path``.

    Raises:
        ValueError: if ``kind`` is not one of the allowed split names.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if kind not in ('train', 'dev', 'test'):
        raise ValueError(f"kind must be 'train', 'dev' or 'test', got {kind!r}")
    if base_path is None:
        base_path = processed_persona_path
    save_dir = os.path.join(base_path, 'label_split_8cluster')
    # Create the output folder on first use rather than failing with
    # FileNotFoundError when it does not exist yet.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, '%s_merged.txt' % kind)
    with open(save_path, 'w', encoding='utf-8') as fout:
        json.dump(data, fout)


def get_few_sample(occur_list, val_data, threshold=5):
    """Return the positions of samples whose labels are all rare.

    A sample qualifies when it has at least one non-negative label and every
    non-negative label ``l`` satisfies ``occur_list[l] <= threshold``.
    Negative labels are treated as "no label" and ignored.

    Args:
        occur_list: sequence indexed by label id giving each label's
            occurrence count.
        val_data: iterable of sample dicts, each with a ``'labels'`` list.
        threshold: maximum occurrence count for a label to count as rare.

    Returns:
        List of indices into ``val_data`` of the qualifying samples, in order.
    """
    few_indices = []
    for idx, sample in enumerate(val_data):
        # Only real (non-negative) label ids may index ``occur_list``; a
        # negative id other than -1 would otherwise wrap to the end of it.
        labels = [label for label in sample['labels'] if label >= 0]
        # Samples with no real label (including an empty list, on which the
        # previous ``max()`` call crashed) are never selected.
        if labels and all(occur_list[label] <= threshold for label in labels):
            few_indices.append(idx)
    return few_indices


def split_val(few_val_list, val_data):
    """Partition ``val_data`` by position into selected and remaining samples.

    Prints the number of distinct selected indices.

    Args:
        few_val_list: indices (duplicates allowed) of the samples to pull out.
        val_data: sequence of samples.

    Returns:
        ``(remove_list, shuffle_list)``: the samples whose index is in
        ``few_val_list`` and all other samples, each in original order.
    """
    # A set makes each membership test O(1); testing against a list made the
    # whole partition quadratic in the data size.
    few_indices = set(few_val_list)
    print(f'temp_list:{len(few_indices)}')
    remove_list = []
    shuffle_list = []
    for i, sample in enumerate(val_data):
        if i in few_indices:
            remove_list.append(sample)
        else:
            shuffle_list.append(sample)
    return remove_list, shuffle_list




# JSON mapping loaded from persona2id_8clusters.txt (presumably persona/cluster
# name -> label id; verify). Its size fixes the number of labels.
chosen_persona_toid_path = processed_persona_path + '/persona2id_8clusters.txt'
# Decode explicitly as UTF-8 instead of the platform's locale default.
with open(chosen_persona_toid_path, 'r', encoding='utf-8') as f:
    chosen_persona_toid = json.load(f)
num_labels = len(chosen_persona_toid)

# index: label id, value: number of occurrences of that label id, counted over
# the merged (train + dev + test) ``train_data`` built below.
occur_list = np.zeros(num_labels)
# key: label id, value: list of indices into the merged ``train_data`` of the
# samples that carry that label.
occur_dict = {i: [] for i in range(num_labels)}

train_data = get_processed_persona('train')
print('train: ', len(train_data))
val_data = get_processed_persona('dev')
print('val: ', len(val_data))
test_data = get_processed_persona('test')
print('test: ', len(test_data))

# Pool every original split into one list; it is re-split by label below.
train_data.extend(val_data)
train_data.extend(test_data)


# Label ids 0..cutoff_thresh are held out: any sample carrying one of them goes
# to the held-out pool ``test_split``; other labelled samples go to train.
cutoff_thresh = 2
test_split = []
train_split = []
# Samples that carry no non-negative label.
all_zero_split = []
for i, dict_i in enumerate(train_data):
    label_list = dict_i['labels']
    test_append = False
    # ``any`` matches ``max(label_list) >= 0`` for non-empty lists and also
    # handles an empty list (where ``max`` raised ValueError); such samples
    # land in ``all_zero_split`` like all-negative ones.
    if any(label >= 0 for label in label_list):
        for label in label_list:
            if label >= 0:
                occur_list[label] += 1
                occur_dict[label].append(i)
                if label <= cutoff_thresh:
                    test_append = True
        if test_append:
            test_split.append(dict_i)
        else:
            train_split.append(dict_i)
    else:
        all_zero_split.append(dict_i)



# Samples with no non-negative label are not tied to a held-out label, so
# they join the training split.
train_split.extend(all_zero_split)
# The first 500 held-out samples become the dev split and the remainder the
# test split; both slices are taken from the same pre-split list.
val_split, test_split = test_split[:500], test_split[500:]

# Report sample content that appears in two different splits.
for _split_a, _split_b, _overlap_msg in (
    (train_split, test_split, 'train & test has repeated samples'),
    (train_split, val_split, 'train & val has repeated samples'),
    (test_split, val_split, 'test & val has repeated samples'),
):
    if check_repeat(_split_a, _split_b):
        print(_overlap_msg)

# Size summary for each split.
for _size_label, _sized_split in (
    ('train_split: ', train_split),
    ('val_split: ', val_split),
    ('test_split: ', test_split),
    ('all_zero_split: ', all_zero_split),
):
    print(_size_label, len(_sized_split))

# Persist the three splits ('dev' holds the validation split).
for _split_kind, _split_to_save in (
    ('train', train_split),
    ('dev', val_split),
    ('test', test_split),
):
    save_data(_split_to_save, kind=_split_kind)


#print('remove from train: ' , len(train_remove))
#print(train_data[remove_index[0]])