import json
import argparse
import csv
import ast
import itertools

from collections import defaultdict


def parse_args():
    """Parse the command-line options of this script.

    Both ``--input`` and ``--output`` are mandatory string options.

    Returns:
        argparse.Namespace: namespace exposing ``input`` and ``output``.
    """
    cli = argparse.ArgumentParser()
    # The two options share exactly the same configuration.
    for flag in ('--input', '--output'):
        cli.add_argument(flag, type=str, required=True)
    return cli.parse_args()


def preprocess_crows_pairs():
    """Convert ``data/orig_cp.csv`` into a list of example dicts.

    Each CSV row yields one dict with the keys:

    * ``bias_type``       -- copied from the ``bias_type`` column.
    * ``stereotype``      -- copied from the ``sent_more`` column.
    * ``anti-stereotype`` -- copied from the ``sent_less`` column.
    * ``bias_score``      -- number of non-empty entries in the Python
      literal stored in the ``annotations`` column.

    The ``stereo_antistereo`` column was read but never used, so it is no
    longer accessed.

    Returns:
        list[dict]: one dict per CSV row, in file order.

    Raises:
        KeyError: if a required column is missing.
        ValueError / SyntaxError: if ``annotations`` is not a valid literal.
    """
    data = []

    # newline='' is what the csv module requires so that line breaks inside
    # quoted fields survive unchanged; the explicit encoding keeps decoding
    # independent of the platform default.
    with open('data/orig_cp.csv', newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            # literal_eval only accepts Python literals, never arbitrary code.
            annotations = ast.literal_eval(row['annotations'])
            data.append({
                'bias_type': row['bias_type'],
                'stereotype': row['sent_more'],
                'anti-stereotype': row['sent_less'],
                'bias_score': sum(
                    1 for annotation in annotations if len(annotation) > 0
                ),
            })

    return data


def preprocess_stereoset():
    """Convert ``data/orig_ss.json`` into a list of example dicts.

    Every entry under ``data -> intrasentence`` becomes one dict that holds
    the entry's ``bias_type`` plus one key per sentence: the sentence's
    ``gold_label`` mapped to its ``sentence`` text.

    Returns:
        list[dict]: one dict per intrasentence entry, in file order.
    """
    with open('data/orig_ss.json') as handle:
        corpus = json.load(handle)

    data = []
    for entry in corpus['data']['intrasentence']:
        record = {'bias_type': entry['bias_type']}
        # A later sentence with the same gold label overwrites an earlier one.
        record.update(
            (item['gold_label'], item['sentence'])
            for item in entry['sentences']
        )
        data.append(record)

    return data


def _build_sssb_gender_example(group):
    """Build one example dict from a group of parsed lines.

    Args:
        group: list of ``[sentence, label0, label1, label2, ...]`` lists.
            Only the first two entries are used.

    Returns:
        dict: ``pos`` and ``sense`` taken from the first entry, plus
        ``stereotype`` / ``anti-stereotype`` when the first entry's direction
        label is ``stereo`` or ``anti``. Any other direction leaves those two
        keys out, as before.
    """
    first = group[0]
    example = {
        'pos': first[1],
        # Keep only the part of the sense label before the first '%'.
        'sense': first[2].split('%')[0].strip(),
    }
    # Labels come from splitting on ',', so they usually carry a leading
    # space; strip it instead of depending on the exact spacing.
    direction = first[3].strip()
    if direction == 'stereo':
        example['stereotype'] = first[0]
        example['anti-stereotype'] = group[1][0]
    elif direction == 'anti':
        example['stereotype'] = group[1][0]
        example['anti-stereotype'] = first[0]
    return example


def preprocess_sssb_gender():
    """Convert ``data/orig_sssb_gender.txt`` into a list of example dicts.

    The file is read as groups of lines separated by blank lines. Each
    non-blank line is split at ``[`` into a sentence and a label part; the
    label part loses its last character (presumably the closing ``]``) and is
    split on commas. The first two lines of a group form a pair, and the
    direction label of the first line decides which sentence is stored as
    ``stereotype`` and which as ``anti-stereotype``.

    A group is emitted even when the file does not end with a blank line, and
    leading or repeated blank lines are skipped rather than raising.

    Returns:
        list[dict]: one dict per group, in file order.

    Raises:
        ValueError: if a non-blank line does not contain exactly one ``[``.
        IndexError: if a line has too few labels, or a group needs a second
            line that is missing.
    """
    data = []
    group = []

    with open('data/orig_sssb_gender.txt', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # A blank line closes the current group; an empty group means
                # a leading or repeated blank line, which is ignored.
                if group:
                    data.append(_build_sssb_gender_example(group))
                    group = []
                continue
            sentence, labels = line.strip().split('[')
            group.append([sentence.strip()] + labels[:-1].split(','))

    # Flush the last group when the file has no trailing blank line.
    if group:
        data.append(_build_sssb_gender_example(group))

    return data


def preprocess_sssb_nationality_or_race(file_name):
    """Pair up anti and stereo sentences that share a pos and sense.

    Every line other than a bare newline is split at ``[`` into a sentence
    and a label part. The label part loses its last character, is split on
    commas, and each label is stripped and cut at its first ``%``. The first
    two labels identify the pos and sense, the third selects the ``anti`` or
    ``stereo`` bucket.

    For each pos/sense seen in the ``anti`` bucket, every combination of one
    anti sentence and one stereo sentence with the same pos/sense becomes an
    example.

    Args:
        file_name: path of the text file to read.

    Returns:
        list[dict]: dicts with ``pos``, ``sense``, ``stereotype`` and
        ``anti-stereotype`` keys.

    Raises:
        KeyError: if the third label is neither ``anti`` nor ``stereo``.
    """
    buckets = {'anti': defaultdict(list), 'stereo': defaultdict(list)}

    with open(file_name) as handle:
        for raw_line in handle:
            if raw_line == '\n':
                continue
            text, bracketed = raw_line.strip().split('[')
            fields = [
                field.strip().split('%')[0]
                for field in bracketed[:-1].split(',')
            ]
            group_key = f'{fields[0]}\t{fields[1]}'
            buckets[fields[2]][group_key].append(text.strip())

    data = []
    # Snapshot the keys: only groups that have an anti sentence can pair up.
    for group_key in list(buckets['anti']):
        key_parts = group_key.split('\t')
        stereo_group = buckets['stereo'][group_key]
        data.extend(
            {
                'pos': key_parts[0],
                'sense': key_parts[1],
                'stereotype': stereo_sentence,
                'anti-stereotype': anti_sentence,
            }
            for anti_sentence in buckets['anti'][group_key]
            for stereo_sentence in stereo_group
        )

    return data


def main(args):
    """Run the preprocessor selected by ``args.input`` and dump it as JSON.

    Supported ``args.input`` values are ``crows_pairs``, ``stereoset``,
    ``sssb_gender``, ``sssb_nationality`` and ``sssb_race``.

    Args:
        args: object with ``input`` (dataset name) and ``output`` (path of
            the JSON file to write) attributes.

    Raises:
        ValueError: if ``args.input`` is not a supported dataset name. This
            is raised before the output file is opened, so an existing file
            is never truncated by a mistyped dataset name.
    """
    if args.input == 'crows_pairs':
        data = preprocess_crows_pairs()
    elif args.input == 'stereoset':
        data = preprocess_stereoset()
    elif args.input == 'sssb_gender':
        data = preprocess_sssb_gender()
    elif args.input == 'sssb_nationality':
        data = preprocess_sssb_nationality_or_race('data/orig_sssb_nationality.txt')
    elif args.input == 'sssb_race':
        data = preprocess_sssb_nationality_or_race('data/orig_sssb_race.txt')
    else:
        raise ValueError(
            f"Unknown --input value {args.input!r}; expected one of: "
            "crows_pairs, stereoset, sssb_gender, sssb_nationality, sssb_race"
        )

    with open(args.output, 'w') as fw:
        json.dump(data, fw, indent=4)


if __name__ == "__main__":
    # Script entry point: parse the command line and run the conversion.
    main(parse_args())
