import os
from tqdm import tqdm
from collections import Counter
import random

# Build balanced (text, topic, private-attribute) TSV splits from the raw blogs
# dataset: one task balancing gender ('m' vs 'f'), one balancing age class
# ('1' vs '3').

# Header line written at the top of every output split file.
HEADER = 'text\tmain attribute\tprivate attribute\n'


def parse_dataset(lines):
    """Parse raw dataset lines into column lists.

    Each line is expected to hold exactly five tab-separated fields:
    topic, age, gender, user, text. Lines with a different field count, or
    with any field equal to the literal string 'None', are skipped. The user
    field is only used for that filter and is not kept.

    Args:
        lines: iterable of raw lines (a trailing newline is allowed).

    Returns:
        dict with keys 'topic', 'age', 'gender', 'text', each a list of
        strings; entry i of every list comes from the same input line.
    """
    final_dic = {'topic': [], 'age': [], 'gender': [], 'text': []}
    for line in lines:
        fields = line.replace('\n', '').split('\t')
        # Explicit check instead of a bare `except:` around tuple unpacking,
        # which also swallowed KeyboardInterrupt and any unrelated error.
        if len(fields) != 5 or 'None' in fields:
            continue
        topic, age, gender, _user, text = fields
        final_dic['topic'].append(topic)
        final_dic['age'].append(age)
        final_dic['gender'].append(gender)
        final_dic['text'].append(text)
    return final_dic


def balanced_split(group_a, group_b, test_proportion, val_proportion, rng=None):
    """Balance two index groups and split them into test/val/train pairs.

    Both groups are shuffled *before* being truncated to the size of the
    smaller one, so the larger group is randomly subsampled rather than
    always contributing its first entries. The inputs are not mutated.

    Args:
        group_a: indices of the class labelled 0.
        group_b: indices of the class labelled 1.
        test_proportion: fraction of the balanced size used for test.
        val_proportion: fraction of the balanced size used for validation.
        rng: object with a ``shuffle`` method (e.g. ``random.Random(seed)``);
            defaults to the global ``random`` module.

    Returns:
        (test_pairs, val_pairs, train_pairs); each is a list of
        (index_a, index_b) tuples. Test gets int(n * test_proportion) pairs,
        val gets int(n * val_proportion) pairs, and train gets the remainder.
    """
    if rng is None:
        rng = random
    shuffled_a = list(group_a)
    shuffled_b = list(group_b)
    rng.shuffle(shuffled_a)
    rng.shuffle(shuffled_b)
    n = min(len(shuffled_a), len(shuffled_b))
    pairs = list(zip(shuffled_a[:n], shuffled_b[:n]))
    test_s = int(n * test_proportion)
    val_s = int(n * val_proportion)
    # Half-open boundaries: exactly test_s test pairs and val_s val pairs.
    return pairs[:test_s], pairs[test_s:test_s + val_s], pairs[test_s + val_s:]


def make_rows(final_dic, pairs):
    """Format index pairs as TSV rows: 'text<TAB>topic<TAB>label\\n'.

    For each (index_a, index_b) pair, the row for index_a (label 0) is
    emitted first, then the row for index_b (label 1).
    """
    rows = []
    for idx_a, idx_b in pairs:
        for idx, label in ((idx_a, 0), (idx_b, 1)):
            rows.append('{}\t{}\t{}\n'.format(final_dic['text'][idx],
                                              final_dic['topic'][idx], label))
    return rows


def write_split(out_dir, name, rows):
    """Write HEADER followed by *rows* to ``<out_dir>/<name>.txt``."""
    with open(os.path.join(out_dir, name + '.txt'), 'w') as file:
        file.write(HEADER)
        file.writelines(rows)


def main():
    """Read blogs/raw/dataset and write the gender_topic and age_topic splits."""
    datapath = 'blogs'
    # Train receives whatever remains after test and val (0.87 of the
    # balanced size with these values).
    test_proportion = 0.1
    val_proportion = 0.03
    SEED = 45

    with open(os.path.join(datapath, 'raw/dataset'), 'r') as file:
        lines = file.readlines()
    final_dic = parse_dataset(tqdm(lines, 'lines'))

    print('topic', Counter(final_dic['topic']))
    print('age', Counter(final_dic['age']))
    print('gender', Counter(final_dic['gender']))

    # (output sub-directory, column to balance on, value labelled 0, value labelled 1)
    tasks = [
        ('gender_topic', 'gender', 'm', 'f'),
        ('age_topic', 'age', '1', '3'),
    ]
    for out_name, column, value_0, value_1 in tasks:
        group_0 = [i for i, v in enumerate(final_dic[column]) if v == value_0]
        group_1 = [i for i, v in enumerate(final_dic[column]) if v == value_1]
        # A fresh seeded RNG per task keeps each task's split reproducible
        # and independent of the other task.
        test_pairs, val_pairs, train_pairs = balanced_split(
            group_0, group_1, test_proportion, val_proportion, random.Random(SEED))

        out_dir = os.path.join(datapath, out_name)
        os.makedirs(out_dir, exist_ok=True)
        x_train = make_rows(final_dic, train_pairs)
        x_test = make_rows(final_dic, test_pairs)
        x_val = make_rows(final_dic, val_pairs)
        write_split(out_dir, 'train', x_train)
        write_split(out_dir, 'test', x_test)
        write_split(out_dir, 'val', x_val)

        print('Train', len(x_train))
        print('Test', len(x_test))
        print('Val', len(x_val))


if __name__ == '__main__':
    main()
