import os
import random
from transformers import *

"""
Preprocessing for the dial and tweet corpora: builds labeled train/test/val splits.
"""


# positive: 1, negative: 0
# afro-american: 1, white: 0
def get_labeled_data(pos_pos, pos_neg, neg_pos, neg_neg, test_s, val_s, train_s, max_length=92):
    """Build tab-separated, labeled rows for the train, test and val splits.

    Each of the four input lists is cut into three consecutive slices, in
    this order: the first ``train_s`` items go to train, the next ``test_s``
    to test and the next ``val_s`` to val. A sentence is kept only when its
    number of space-separated tokens is strictly smaller than ``max_length``.

    Every kept sentence becomes one row ``"<text>\\t<label1>\\t<label2>\\n"``
    where the labels depend on the source list:
    ``pos_pos -> (1, 1)``, ``pos_neg -> (1, 0)``,
    ``neg_pos -> (0, 1)``, ``neg_neg -> (0, 0)``.

    Progress (cumulative split sizes after each source list) and the longest
    kept sentence per split are printed to stdout; a split with no kept
    sentence reports ``-1``.

    :param pos_pos: sentences labeled (1, 1)
    :param pos_neg: sentences labeled (1, 0)
    :param neg_pos: sentences labeled (0, 1)
    :param neg_neg: sentences labeled (0, 0)
    :param test_s: number of items taken from each list for the test split
    :param val_s: number of items taken from each list for the val split
    :param train_s: number of items taken from each list for the train split
    :param max_length: exclusive upper bound on the token count of a sentence
    :return: tuple ``(x_train, x_test, x_val)`` of lists of formatted rows
    """
    x_train, x_test, x_val = [], [], []

    # Slice boundaries are shared by all four source lists.
    test_end = train_s + test_s
    val_end = test_end + val_s
    splits = (
        ('train', slice(None, train_s), x_train),
        ('test', slice(train_s, test_end), x_test),
        ('val', slice(test_end, val_end), x_val),
    )
    longest = {'train': -1, 'test': -1, 'val': -1}

    labeled_groups = (
        (pos_pos, 1, 1),
        (pos_neg, 1, 0),
        (neg_pos, 0, 1),
        (neg_neg, 0, 0),
    )
    for sentences, first_label, second_label in labeled_groups:
        for split_name, bounds, rows in splits:
            for sentence in sentences[bounds]:
                n_tokens = len(sentence.split(' '))
                if n_tokens < max_length:
                    rows.append('{}\t{}\t{}\n'.format(sentence, first_label, second_label))
                    longest[split_name] = max(longest[split_name], n_tokens)
        # Cumulative sizes after each source list has been distributed.
        print(len(x_train), len(x_test), len(x_val))

    print('Max Length train ', longest['train'])
    print('Max Length val ', longest['val'])
    print('Max Length test ', longest['test'])
    return x_train, x_test, x_val


# Loaded at import time. NOTE(review): in the visible part of this file the
# tokenizer is referenced only by a commented-out line inside open_file --
# confirm whether it is still needed before relying on or removing it.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
import codecs


def open_file(file):
    """Read a file as UTF-8 text lines, skipping lines that fail to decode.

    The file is read in binary mode and every line is decoded on its own, so
    a single malformed line does not abort the whole read. The fraction of
    undecodable lines and the number of lines kept are printed to stdout.

    :param file: path of the file to read
    :return: list of decoded lines with every ``'\\n'`` removed
    """
    with open(file, 'rb') as handle:
        raw_lines = handle.readlines()

    cleaned_lines = []
    count = 0
    for raw_line in raw_lines:
        try:
            cleaned_lines.append(raw_line.decode('utf-8'))
        except UnicodeDecodeError:
            count += 1

    # Guard against an empty file: there is nothing to divide by.
    error_rate = count / len(raw_lines) if raw_lines else 0.0
    print('Coddec errors', error_rate)
    print('Number of lines lefts', len(cleaned_lines))

    return [line.replace('\n', '') for line in cleaned_lines]


if __name__ == '__main__':
    # Fractions of the smallest of the four groups assigned to each split.
    train_proportion = 0.87
    val_proportion = 0.03
    test_proportion = 0.1
    SEED = 45

    suffix = 'age'  # 'gender' #'mention_race'  # mention_race sentiment_race
    dataset = 'tweet'  # 'dial'
    new_folder_path = os.path.join(dataset, '{}_{}_processed'.format(dataset, suffix))
    folder_path = os.path.join(dataset, suffix)

    os.makedirs(new_folder_path, exist_ok=True)
    # Seed once up front; the order of the shuffles below is kept fixed so
    # that the generated splits stay reproducible.
    random.seed(SEED)
    pos_pos = open_file(os.path.join(folder_path, 'pos_pos'))
    pos_neg = open_file(os.path.join(folder_path, 'pos_neg'))
    neg_pos = open_file(os.path.join(folder_path, 'neg_pos'))
    neg_neg = open_file(os.path.join(folder_path, 'neg_neg'))
    random.shuffle(pos_pos)
    random.shuffle(pos_neg)
    random.shuffle(neg_pos)
    random.shuffle(neg_neg)

    # Size every split from the smallest group so all four groups contribute
    # the same number of candidate sentences.
    length = min(len(pos_pos), len(pos_neg), len(neg_pos), len(neg_neg))
    test_s, val_s, train_s = int(length * test_proportion), int(length * val_proportion), int(length * train_proportion)
    x_train, x_test, x_val = get_labeled_data(pos_pos, pos_neg, neg_pos, neg_neg, test_s, val_s, train_s)

    random.shuffle(x_train)
    random.shuffle(x_test)
    random.shuffle(x_val)

    header = 'text\tmain attribute\tprivate attribute\n'
    for split_name, rows in (('x_test', x_test), ('x_val', x_val), ('x_train', x_train)):
        # The input was decoded as UTF-8, so write it back as UTF-8 instead of
        # depending on the platform's default encoding.
        with open(os.path.join(new_folder_path, split_name), 'w', encoding='utf-8') as out_file:
            out_file.write(header)
            out_file.writelines(rows)
