import pandas as pd
import os
import json
from tqdm import tqdm
from collections import Counter
import random

"""
Raw CSV columns: text, rating, gender, age, location
"""

def parse_row(raw_line):
    """Split one raw CSV line into ``(text, rating, gender, age, location)``.

    The raw files are split naively on commas, so a review text that itself
    contains commas yields more than five fields.  The last four fields are
    always taken as rating, gender, age and location; everything before them
    is the text, re-joined with ``' , '`` (the separator the original script
    used, kept so existing outputs stay comparable).

    Returns ``None`` for lines with fewer than five fields (e.g. blank
    lines), which previously either crashed with ``IndexError`` or were
    mis-parsed with an empty text.
    """
    fields = raw_line.rstrip('\n').split(',')
    if len(fields) < 5:
        return None
    *text_parts, rating, gender, age, location = fields
    # A tab inside the text would add a spurious column to the TSV output.
    text = ' , '.join(text_parts).replace('\t', ' ')
    return text, rating, gender, age, location


def load_split(csv_path):
    """Read one raw split and return its columns as a dict of parallel lists.

    Keys are ``'rating'``, ``'gender'``, ``'age'``, ``'location'`` and
    ``'text'``.  Malformed lines are skipped and their count is printed.
    """
    columns = {
        'rating': [], 'gender': [], 'age': [], 'location': [], 'text': []
    }
    # Explicit encoding so the result does not depend on the platform
    # default.  Assumes the raw files are UTF-8 -- adjust if they are not.
    with open(csv_path, 'r', encoding='utf-8') as file:
        raw_lines = file.readlines()

    skipped = 0
    for raw_line in tqdm(raw_lines, 'lines'):
        row = parse_row(raw_line)
        if row is None:
            skipped += 1
            continue
        text, rating, gender, age, location = row
        columns['text'].append(text)
        columns['rating'].append(rating)
        columns['gender'].append(gender)
        columns['age'].append(age)
        columns['location'].append(location)

    if skipped:
        print('skipped {} malformed line(s) in {}'.format(skipped, csv_path))
    return columns


def balanced_rows(columns, attribute):
    """Build class-balanced TSV rows for the private ``attribute``.

    Samples whose ``attribute`` value is ``'1'`` and ``'0'`` are paired up in
    file order, so both classes contribute the same number of rows (the size
    of the smaller class); the surplus of the larger class and any other
    value are dropped.  Rows alternate between the two classes and have the
    form ``text<TAB>rating<TAB>label``.

    NOTE(review): value ``'1'`` is written as label ``0`` and value ``'0'`` as
    label ``1``.  This inversion is preserved from the original script --
    confirm it is intended before relying on the label meaning.
    """
    values = columns[attribute]
    ones = [i for i, value in enumerate(values) if value == '1']
    zeros = [i for i, value in enumerate(values) if value == '0']

    rows = []
    # zip() stops at the shorter list, which is what balances the classes.
    for idx_one, idx_zero in zip(ones, zeros):
        rows.append('{}\t{}\t{}\n'.format(columns['text'][idx_one],
                                          columns['rating'][idx_one], 0))
        rows.append('{}\t{}\t{}\n'.format(columns['text'][idx_zero],
                                          columns['rating'][idx_zero], 1))
    return rows


def write_split(out_path, rows):
    """Write the TSV header followed by ``rows`` to ``out_path``."""
    with open(out_path, 'w', encoding='utf-8') as file:
        file.write('text\tmain attribute\tprivate attribute\n')
        file.writelines(rows)


def main(data_path='trust_pilot'):
    """Convert the raw train/test/valid CSVs under ``data_path``.

    For every split, ``<data_path>/raw/<split>.csv`` is read, the value
    distribution of each metadata column is printed, and two balanced TSV
    files are written: one with gender as the private attribute
    (``blog_gender_processed``) and one with age (``blog_age_processed``).
    The rating is the main attribute in both.
    """
    gender_dir = os.path.join(data_path, 'blog_gender_processed')
    age_dir = os.path.join(data_path, 'blog_age_processed')
    os.makedirs(gender_dir, exist_ok=True)
    os.makedirs(age_dir, exist_ok=True)

    for name in ('train', 'test', 'valid'):
        columns = load_split(
            os.path.join(data_path, 'raw', '{}.csv'.format(name)))

        print('gender', Counter(columns['gender']))
        print('age', Counter(columns['age']))
        print('rating', Counter(columns['rating']))
        print('location', Counter(columns['location']))
        print('---------------------')

        write_split(os.path.join(gender_dir, '{}.txt'.format(name)),
                    balanced_rows(columns, 'gender'))
        write_split(os.path.join(age_dir, '{}.txt'.format(name)),
                    balanced_rows(columns, 'age'))


if __name__ == '__main__':
    # These settings are not used by the code visible in this script; they
    # are kept in case something else relies on them.
    train_proportion = 0.87
    val_proportion = 0.03
    test_proportion = 0.1
    SEED = 45
    data_path = 'trust_pilot'
    main(data_path)
