#import config
import numpy as np
import json
import sys
# Directory that holds the processed persona JSON files; every input and
# output path in this script is built by appending to it.
processed_persona_path = 'processed_persona'
# Occurrence cut-off used by the top-level script: labels seen at least this
# many times in the training data are selected as `chosen_label`.
# NOTE: get_few_sample() is called without this value and relies on its own
# default of 5, so the two must be kept in sync by hand.
threshold = 5

def get_processed_persona(kind,require_label = True):
    """Load one processed persona split from disk.

    Args:
        kind: Split name used to build the file name (this script passes
            'train', 'dev' and 'test').
        require_label: When true, read ``<kind>_merged.txt``; otherwise read
            ``<kind>.txt``. Both live under ``processed_persona_path``.

    Returns:
        The object decoded from the file's JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    suffix = '_merged' if require_label else ''
    path = f'{processed_persona_path}/{kind}{suffix}.txt'
    # Decode as UTF-8 explicitly: save_data() writes UTF-8, and relying on the
    # platform default encoding makes reading non-ASCII text locale-dependent.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def check_repeat(list_a, list_b):
    """Report whether the two sequences have at least one element in common.

    Membership is decided with ``in`` (identity or equality), so the elements
    do not need to be hashable.

    Args:
        list_a: First sequence.
        list_b: Second sequence.

    Returns:
        True if some element of one sequence is found in the other,
        False otherwise (including when either sequence is empty).
    """
    # Scan in both directions, exactly like the two explicit loops would.
    if any(item in list_b for item in list_a):
        return True
    return any(item in list_a for item in list_b)

def save_data(data,  kind='train'):
    """Write one split as JSON to ``<processed_persona_path>/zero_shot/``.

    Args:
        data: JSON-serialisable object to store.
        kind: One of 'train', 'dev' or 'test'; selects the output file
            ``<kind>_merged.txt``.

    Raises:
        AssertionError: If ``kind`` is not one of the accepted names.
        OSError: If the directory or file cannot be created or written.
    """
    # Local import so this function does not depend on the file's import block.
    import os

    # dev here used to indicate zero shot setting
    assert kind in ['train','dev','test']
    save_dir = processed_persona_path + '/zero_shot'
    # Create the output directory on first use instead of failing with
    # FileNotFoundError when it has not been prepared by hand.
    os.makedirs(save_dir, exist_ok=True)
    save_path = save_dir + '/%s_merged.txt' % kind
    with open(save_path, 'w', encoding='utf-8') as fout:
        json.dump(data, fout)


def get_few_sample(occur_list, val_data,threshold=5):
    """Return the indices of samples whose labels are all rare.

    A sample is selected when it carries at least one label ``>= 0`` and none
    of its labels (ignoring the ``-1`` placeholder) has an occurrence count
    strictly greater than ``threshold``.

    Args:
        occur_list: Sequence indexed by label id giving that label's
            occurrence count.
        val_data: Iterable of samples; each sample is a mapping with a
            ``'labels'`` entry holding a list of integer label ids.
        threshold: Largest occurrence count a label may have for its sample
            to still be selected.

    Returns:
        List of positions in ``val_data`` of the selected samples, in order.
    """
    append_index = []
    for index, sample in enumerate(val_data):
        sample_labels = sample['labels']
        # Skip samples without any real label. Using any() instead of max()
        # also covers an empty label list, which used to raise ValueError.
        if not any(label >= 0 for label in sample_labels):
            continue
        has_frequent_label = any(
            label != -1 and occur_list[label] > threshold
            for label in sample_labels
        )
        if not has_frequent_label:
            append_index.append(index)
    return append_index


def split_val(few_val_list,val_data):
    """Split ``val_data`` by whether each position is listed in ``few_val_list``.

    Args:
        few_val_list: Iterable of integer positions (duplicates allowed) that
            identify the few-shot samples.
        val_data: Sequence of samples to split.

    Returns:
        A ``(remove_list, shuffle_list)`` tuple: the samples whose position is
        in ``few_val_list`` and all remaining samples, each in original order.
    """
    # A set gives O(1) membership tests; the previous list made the loop
    # below quadratic in the worst case.
    few_positions = set(few_val_list)
    print(f'temp_list:{len(few_positions)}')

    # remove_list stores the few-shot samples, shuffle_list everything else
    remove_list = []
    shuffle_list = []
    for position, sample in enumerate(val_data):
        if position in few_positions:
            remove_list.append(sample)
        else:
            shuffle_list.append(sample)

    return remove_list,shuffle_list




# Load the label vocabulary; only its size is used below.
chosen_persona_toid_path = processed_persona_path + '/persona2id.txt'
# Read as UTF-8 explicitly so decoding does not depend on the platform locale
# (save_data() writes its JSON as UTF-8 as well).
with open(chosen_persona_toid_path, 'r', encoding='utf-8') as f:
    chosen_persona_toid = json.load(f)
num_labels = len(chosen_persona_toid)

# occur_list[label_id] -> number of occurrences of that label in the train data
occur_list = np.zeros(num_labels)
# Training data only:
# occur_dict[label_id] -> indices of the training samples that carry the label
occur_dict = {label_id: [] for label_id in range(num_labels)}

train_data = get_processed_persona('train')
print('train: ', len(train_data))
val_data = get_processed_persona('dev')
print('val: ', len(val_data))
test_data = get_processed_persona('test')
print('test: ', len(test_data))

# Count every non-negative label of every training sample. The per-label test
# already skips samples without real labels, so no max() pre-check is needed;
# dropping it also stops a sample with an empty label list from raising
# ValueError.
for i, dict_i in enumerate(train_data):
    label_list = dict_i['labels']
    for label in label_list:
        if label >= 0:
            occur_list[label] += 1
            occur_dict[label].append(i)

# Positions of dev / test samples made up only of rarely seen labels.
few_val_list, few_test_list = (
    get_few_sample(occur_list, split) for split in (val_data, test_data)
)

print(f'few_val_list:{len(few_val_list)}')
print(f'few_test_list:{len(few_test_list)}')

# Label ids seen at least `threshold` times in the training data.
chosen_label = np.flatnonzero(occur_list >= threshold)

# Label ids grouped by their exact training occurrence count (0 through 4).
few_label0, few_label1, few_label2, few_label3, few_label4 = (
    np.flatnonzero(occur_list == count) for count in range(5)
)

# Labels seen one to four times, ordered by count first and label id second.
few_label = np.concatenate([few_label1, few_label2, few_label3, few_label4])
print(chosen_label.shape)  # the original author recorded 2745 here
print(occur_list)
print(few_label.shape)  # the original author recorded 704 here


# recollect samples from train
# For every chosen label, look at the training samples that carry it beyond
# its first `threshold` occurrences. Such a sample is marked for removal from
# train when none of its labels currently has a count below `threshold`; the
# counts of all its labels are then decremented in occur_list.
# NOTE(review): occur_list is mutated while iterating, so the outcome depends
# on the visiting order of labels and samples. A sample that carries several
# chosen labels can be reached more than once; it is then appended and
# decremented again. The duplicate indices are collapsed later with set(), but
# the extra decrements are not undone - confirm this is intended.
remove_index = []
for label in chosen_label:
    # Holds because occur_dict[label] gets one entry per counted occurrence.
    assert len(occur_dict[label]) >= threshold
    # The first `threshold` samples of each chosen label are never examined.
    for i in range(threshold,len(occur_dict[label])):
        sample = train_data[occur_dict[label][i]]
        sample_labels = sample['labels']
        remove = True
        for sample_label in sample_labels:
            # Keep the sample in train if any of its real labels is rare.
            if(sample_label != -1 and occur_list[sample_label] < threshold):
                #print('Sample id: ',occur_dict[label][i],' has label', sample_label, 'with occurrance:', occur_list[sample_label])
                #print(sample_labels)
                remove = False
                break

        if(remove):
            # The sample leaves train, so each of its labels loses one count.
            for sample_label in sample_labels:
                if(sample_label != -1):
                    occur_list[sample_label] -= 1
            remove_index.append(occur_dict[label][i])

# choosing zero shot / few shot label setting
# For every rare label in few_label, consider only the first training sample
# that carries it. The sample is moved to the few-shot set when none of its
# labels currently has a count above 5; the counts of all its labels are then
# decremented in occur_list.
# NOTE(review): the literals 4 and 5 below are hard-coded rather than derived
# from `threshold`, and occur_list has already been modified by the earlier
# removal pass - confirm both are intended.
fewshot_index = []
for label in few_label:
    # few_label only contains labels counted one to four times.
    assert len(occur_dict[label]) <= 4
    sample = train_data[occur_dict[label][0]]
    sample_labels = sample['labels']
    remove = True
    for sample_label in sample_labels:
        # A label counted more than 5 times keeps the sample in train.
        if(sample_label != -1 and occur_list[sample_label] > 5):
            remove = False
            break
    if(remove):
        for sample_label in sample_labels:
            if(sample_label != -1):
                occur_list[sample_label] -= 1
        # The same sample index may be appended for several labels; the
        # duplicates are collapsed later with set().
        fewshot_index.append(occur_dict[label][0])
        
        


# Sizes before de-duplication (the original author recorded 3054 for the first).
print(f'len(remove_index): {len(remove_index)}')
print(f'len(fewshot_index): {len(fewshot_index)}')

# Sets give O(1) membership tests in the partition loop below; testing against
# the lists made that loop quadratic in the size of the training data.
remove_lookup = set(remove_index)
fewshot_lookup = set(fewshot_index)

# Keep the de-duplicated indices available as lists, as before.
remove_index = list(remove_lookup)
fewshot_index = list(fewshot_lookup)
print(len(remove_index))

remove_train = []
train_shuffle = []
few_shot_list = []

# Partition the training samples. Few-shot membership wins over removal when
# an index is present in both collections.
# The loop variable is named `sample` so the builtin `dict` is not rebound.
for i, sample in enumerate(train_data):
    if i in fewshot_lookup:
        few_shot_list.append(sample)
    elif i in remove_lookup:
        remove_train.append(sample)
    else:
        train_shuffle.append(sample)

print('few_shot_list: ', len(few_shot_list))
# Split dev and test into few-shot samples and everything else.
remove_list_val, shuffle_list_val = split_val(few_val_list,val_data)
remove_list_test, shuffle_list_test = split_val(few_test_list,test_data)

# Size of the new training pool before it is extended. This line used to print
# len(train_data) under the 'train_shuffle' label, which repeated the size
# already reported when the data was loaded.
print('train_shuffle: ', len(train_shuffle))

# The first 1000 samples removed from train become the new test split; the
# remaining removed samples go back into train.
test_suffle = remove_train[0:1000]
train_shuffle.extend(remove_train[1000:])
# Dev / test samples that are not few-shot are added to train as well.
train_shuffle.extend(shuffle_list_val)
train_shuffle.extend(shuffle_list_test)

# Few-shot samples from dev and test join those taken from train.
few_shot_list.extend(remove_list_val)
few_shot_list.extend(remove_list_test)

print('train_shuffle: ', len(train_shuffle))
print('test_suffle: ', len(test_suffle))
print('few_shot_list: ', len(few_shot_list))


# print('train_shuffle\n ', train_shuffle[-1])
# print('val_shuffle\n ', val_shuffle[-1])
# print('test_suffle\n ', test_suffle[-1])
# Warn when any two of the final splits share a sample.
for first_split, second_split, warning in (
    (train_shuffle, few_shot_list, 'train & val has repeated samples'),
    (train_shuffle, test_suffle, 'train & test has repeated samples'),
    (test_suffle, few_shot_list, 'test & val has repeated samples'),
):
    if check_repeat(first_split, second_split):
        print(warning)

# Persist the splits; the few-shot set is stored under the 'dev' name.
for split_kind, split_samples in (
    ('train', train_shuffle),
    ('dev', few_shot_list),
    ('test', test_suffle),
):
    save_data(split_samples, kind=split_kind)


#print('remove from train: ' , len(train_remove))
#print(train_data[remove_index[0]])