import json
import os
import sys

import numpy as np
# Root directory of the processed persona data (a relative path, so it is
# resolved against the current working directory).
processed_persona_path = 'processed_persona'


def get_processed_persona(kind, require_label=True, base_path=None):
    """Load one JSON-serialised file of the processed persona data.

    Args:
        kind: File stem to load, e.g. ``'train'``, ``'dev'`` or ``'test'``.
        require_label: If True, read the labelled 8-cluster split
            ``<base>/label_split_8cluster/<kind>_merged.txt``; otherwise read
            ``<base>/<kind>.txt``.
        base_path: Root directory of the processed data. Defaults to the
            module-level ``processed_persona_path``.

    Returns:
        The decoded JSON content of the file.

    Raises:
        FileNotFoundError: If the resolved file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    if base_path is None:
        base_path = processed_persona_path
    if require_label:
        path = f'{base_path}/label_split_8cluster/{kind}_merged.txt'
    else:
        path = f'{base_path}/{kind}.txt'
    # save_data writes these files as UTF-8; decode them the same way instead
    # of depending on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def check_repeat(list_a, list_b):
    """Return True if the two sequences share at least one element.

    Membership uses ``in`` (identity or ``==``), so unhashable elements such
    as lists are supported.

    Args:
        list_a: First sequence.
        list_b: Second sequence.

    Returns:
        True if some element of ``list_a`` also appears in ``list_b``,
        otherwise False.
    """
    # A single direction is enough: if no element of list_a is in list_b,
    # then no element of list_b can be in list_a either.
    return any(a in list_b for a in list_a)

def save_data(data, kind='train', base_path=None):
    """Serialise one split of the labelled persona data to disk as JSON.

    The file is written to ``<base>/label_split_8cluster/<kind>_merged.txt``,
    the same location get_processed_persona(kind) reads by default. The
    directory is created if it does not exist yet.

    Args:
        data: JSON-serialisable object to write.
        kind: Split name; one of ``'train'``, ``'dev'`` or ``'test'``.
            Here ``'dev'`` is used to denote the zero-shot setting.
        base_path: Root directory of the processed data. Defaults to the
            module-level ``processed_persona_path``.

    Raises:
        ValueError: If ``kind`` is not one of the recognised split names.
    """
    # Validate explicitly: an ``assert`` is stripped under ``python -O`` and
    # would then silently allow writing an arbitrary ``<kind>_merged.txt``.
    if kind not in ('train', 'dev', 'test'):
        raise ValueError(f"kind must be 'train', 'dev' or 'test', got {kind!r}")
    if base_path is None:
        base_path = processed_persona_path
    save_dir = os.path.join(base_path, 'label_split_8cluster')
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f'{kind}_merged.txt')
    with open(save_path, 'w', encoding='utf-8') as fout:
        json.dump(data, fout)


def get_few_sample(occur_list, val_data, threshold=5):
    """Return indices of samples whose real labels are all rare.

    A sample qualifies when it has at least one non-negative label and every
    non-negative label ``l`` satisfies ``occur_list[l] <= threshold``.
    Negative labels (e.g. -1) are treated as "no label" and ignored, so a
    negative value never indexes ``occur_list`` from the end.

    Args:
        occur_list: Per-label-id occurrence counts, indexable by label id
            (presumably the module-level ``occur_list`` array; confirm at
            the call site).
        val_data: Sequence of sample dicts, each with a ``'labels'`` list of
            integer label ids.
        threshold: Largest occurrence count still considered "few-shot".

    Returns:
        List of indices into ``val_data`` of the qualifying samples, in order.
        Samples with an empty or all-negative ``'labels'`` list are skipped.
    """
    append_index = []
    for idx, sample in enumerate(val_data):
        real_labels = [label for label in sample['labels'] if label >= 0]
        if not real_labels:
            # Nothing labelled (also covers an empty list, on which max()
            # would have raised).
            continue
        if all(occur_list[label] <= threshold for label in real_labels):
            append_index.append(idx)
    return append_index


def split_val(few_val_list, val_data):
    """Partition ``val_data`` by index into (selected, remaining) lists.

    Prints the number of distinct selected indices as ``temp_list:<n>``.

    Args:
        few_val_list: Indices into ``val_data`` to pull out (e.g. the result
            of ``get_few_sample``). Duplicates are ignored; indices that are
            out of range simply select nothing.
        val_data: Sequence of samples.

    Returns:
        Tuple ``(remove_list, shuffle_list)``. ``remove_list`` holds the
        samples whose index is in ``few_val_list`` (the few-shot samples) and
        ``shuffle_list`` holds all the others. Both keep the original order.
    """
    # A set gives O(1) membership per sample instead of a linear list scan.
    few_indices = set(few_val_list)
    print(f'temp_list:{len(few_indices)}')
    remove_list = []
    shuffle_list = []
    for i, sample in enumerate(val_data):
        if i in few_indices:
            remove_list.append(sample)
        else:
            shuffle_list.append(sample)
    return remove_list, shuffle_list




# Persona -> label-id mapping; its size defines the label space.
chosen_persona_toid_path = processed_persona_path + '/persona2id.txt'
with open(chosen_persona_toid_path, 'r', encoding='utf-8') as f:
    chosen_persona_toid = json.load(f)
num_labels = len(chosen_persona_toid)

# index: label id, value: occurrence count of that label id in the train data.
# NOTE(review): not populated in this section; presumably filled in elsewhere.
occur_list = np.zeros(num_labels)
# Training only.
# occur_dict key: label id, value: list of training-sample ids in which that
# label shows up. NOTE(review): also not populated in this section.
occur_dict = {i: [] for i in range(num_labels)}

train_data = get_processed_persona('train')
print('train: ', len(train_data))
val_data = get_processed_persona('dev')
print('val: ', len(val_data))
test_data = get_processed_persona('test')
print('test: ', len(test_data))

# Merge every split into train_data so the statistics below cover the whole
# corpus. This mutates train_data in place: from here on it is not train-only.
train_data.extend(val_data)
train_data.extend(test_data)
print('Total conv: ', len(train_data))

total_persona = 0
total_turn = 0
total_word = 0
total_label = 0
for dict_i in train_data:
    conv = dict_i['conv']
    total_persona += len(dict_i['partner_persona']) + len(dict_i['your_persona'])
    total_turn += len(conv)
    # Words are counted as whitespace-separated tokens.
    total_word += sum(len(turn.split()) for turn in conv)
    # Negative label ids (e.g. -1) mean "no label"; count only real ids. An
    # empty label list simply contributes 0 instead of crashing in max().
    total_label += sum(1 for label in dict_i['labels'] if label >= 0)

num_dialogs = len(train_data)
print(f'total_turn:{total_turn}')
print(f'total_persona:{total_persona}')
print(f'total_label:{total_label}')
# Guard the averages so an empty corpus prints 0.0 instead of raising
# ZeroDivisionError.
print(f'Avg turns per dialog:{total_turn / num_dialogs if num_dialogs else 0.0}')
print(f'Avg words per turn:{total_word / total_turn if total_turn else 0.0}')
print(f'Avg labels per dialog:{total_label / num_dialogs if num_dialogs else 0.0}')