import json
import os
import sys
from collections import defaultdict
import nltk
import re
import random

# Input/output locations. They are empty here and are opened by the
# module-level code further down, so they must be set before running.
PRO_FILE = ""     # JSON loaded into `get_pronoun`
GENDER_FILE = ""  # JSON loaded into `gender_info`
ENTITY_FILE = ""  # JSON loaded into `entity_info`
FILE_DEV = ""     # JSON list of development samples
FILE_TEST = ""    # JSON list of test samples
FILE_OUT = ""     # directory the result files are written to

# Matches a leading YYYY-MM-DD date.
regex = r'([0-9]{4})-([0-9]{2})-([0-9]{2})'

# When True, non-pronoun surface forms are lower-cased on realisation.
is_lower = False

# Each base pronoun paired with every inflected form that belongs to it.
_PRONOUN_FAMILIES = (
    ('he', ('he', 'his', 'him')),
    ('she', ('she', 'her', 'hers')),
    ('it', ('it', 'its')),
    ('we', ('we', 'our', 'ours', 'us')),
    ('they', ('they', 'their', 'theirs', 'them')),
)

# Maps any pronoun form to its base pronoun; two entities whose pronouns
# share a base are treated as competitors.
competitors = {
    form: base
    for base, forms in _PRONOUN_FAMILIES
    for form in forms
}

def _get_pronoun_from_feature(gender, entity_type):
    if gender == "male":
        return "he"
    elif gender == "female":
        return "she"
    elif gender == "neutral":
        return "it"
    else:
        return None


def extend_get_pronoun(get_pronoun, gender_info, entity_info):
    """Build an entity -> pronoun mapping for the entities in `gender_info`.

    For every entity listed in `gender_info`, the pronoun already recorded
    in `get_pronoun` wins; otherwise one is derived from the entity's
    gender and kept only when the derivation succeeds.

    Note that only keys of `gender_info` are visited: an entity that is in
    `get_pronoun` but not in `gender_info` does not appear in the result.
    An entity missing from both `get_pronoun` and `entity_info` raises
    KeyError.

    Returns a new dict; the inputs are not modified.
    """
    merged = {}
    for entity, gender in gender_info.items():
        if entity in get_pronoun:
            # An explicitly listed pronoun takes precedence over the gender.
            merged[entity] = get_pronoun[entity]
        else:
            derived = _get_pronoun_from_feature(gender, entity_info[entity])
            if derived is not None:
                merged[entity] = derived
    return merged

def _load_json(path):
    """Read the file at `path` and return its parsed JSON content."""
    with open(path) as handle:
        return json.load(handle)


# The three resources are read in the same order as before: explicit
# pronouns, then gender labels, then entity types.
get_pronoun = _load_json(PRO_FILE)
gender_info = _load_json(GENDER_FILE)
entity_info = _load_json(ENTITY_FILE)

# Replace the raw table with the one restricted to / completed from
# the entities that have gender information.
get_pronoun = extend_get_pronoun(get_pronoun, gender_info, entity_info)

# Every pronoun form recognised by check_pronoun.  Kept as a frozenset so
# a missing comma cannot silently merge two adjacent string literals.
_PRONOUN_FORMS = frozenset((
    'he', 'his', 'him',
    'she', 'hers', 'her',
    'it', 'its',
    'we', 'our', 'ours', 'us',
    'they', 'theirs', 'them', 'their',
))


def check_pronoun(refex):
    """Return True when `refex` is a single pronoun.

    The comparison ignores case and surrounding whitespace.

    Args:
        refex: the referring expression as one string.

    Returns:
        True if the normalised expression is one of the known pronoun
        forms, False otherwise.
    """
    return refex.lower().strip() in _PRONOUN_FORMS

def is_competiter(entity1, entity2):
    """Tell whether two distinct entities would share the same pronoun.

    Returns False when the entities are equal or when either one has no
    entry in `get_pronoun`.  Otherwise the two pronouns are reduced to
    their base form through `competitors` and compared.
    """
    if entity1 == entity2:
        return False
    if not (entity1 in get_pronoun and entity2 in get_pronoun):
        return False
    base1 = competitors[get_pronoun[entity1]]
    base2 = competitors[get_pronoun[entity2]]
    return base1 == base2

# Lower-case month names, indexed by month number minus one.
_MONTH_NAMES = (
    'january', 'february', 'march', 'april', 'may', 'june',
    'july', 'august', 'september', 'october', 'november', 'december',
)


def realize_date(date):
    """Turn a 'YYYY-MM-DD' string into '<month> <day><suffix> <year>'.

    Single and double quotes are removed before parsing.  The day keeps
    its original digits (including a leading zero) and receives an
    English ordinal suffix; the month number is replaced by its
    lower-case name, or by the bare number when it is outside 1-12.

    Args:
        date: date text made of exactly three '-'-separated fields.

    Returns:
        The realised date, e.g. 'april 11th 2001'.

    Raises:
        ValueError: if the text does not split into three fields or the
            month field is not an integer.
        IndexError: if the day field is empty.
    """
    year, month, day = date.replace('\'', '').replace('\"', '').split('-')

    # 11, 12 and 13 take 'th' even though they end in 1, 2 and 3.
    if day[-2:] in ('11', '12', '13'):
        suffix = 'th'
    else:
        suffix = {'1': 'st', '2': 'nd', '3': 'rd'}.get(day[-1], 'th')
    day = day + suffix

    month_number = int(month)
    if 1 <= month_number <= 12:
        month = _MONTH_NAMES[month_number - 1]
    else:
        # Unknown month numbers are passed through as plain digits.
        month = str(month_number)

    return ' '.join([month, day, year])

def preprocessing(raw_data):
    """Group raw samples into documents and attach discourse features.

    Samples are grouped under the key `eid + lid`.  Within each document
    they are ordered by the length of their preceding context, and every
    sample receives two extra fields:

    * 'discourse_old' - True when an earlier sample of the same document
      mentions the same entity;
    * 'is_pronoun'    - result of check_pronoun on the space-joined
      'refex' tokens.

    'index' stores the sample's position in `raw_data` so results can be
    written back in the original order.

    Returns:
        (documents, number_of_raw_samples), where documents maps the
        document key to its ordered list of sample dicts.
    """
    data = defaultdict(list)
    for idx, sample in enumerate(raw_data):
        record = {
            'pre_context': sample['pre_context'],
            'position': len(sample['pre_context']),
            'sentence': sample['sentence'],
            'refex': sample['refex'],
            'pos_context': sample['pos_context'],
            'entity': sample['entity'],
            'syntax': sample['syntax'],
            'index': idx
        }
        data[sample['eid'] + sample['lid']].append(record)

    for doc_id in data:
        # Order mentions by how much text precedes them.
        ordered = sorted(data[doc_id], key=lambda e: e['position'])
        data[doc_id] = ordered
        for idx, sample in enumerate(ordered):
            entity = sample['entity']
            earlier_mentions = ordered[:idx]
            sample['discourse_old'] = any(
                previous['entity'] == entity for previous in earlier_mentions
            )
            sample['is_pronoun'] = check_pronoun(" ".join(sample['refex']))
    return data, len(raw_data)

def tagging_form(data, threshold):
    """Decide for every mention whether it is realised as a pronoun.

    A mention becomes a pronoun candidate only when it has already been
    mentioned in the document ('discourse_old'), its entity has an entry
    in `get_pronoun`, and no earlier mention from the previous sentence
    onwards refers to a competing entity (see is_competiter).  A candidate
    is realised as a pronoun when random.random() <= threshold.

    The decision is stored in sample['realise_pronoun'] for every sample,
    so `data` is modified in place.

    Args:
        data: mapping of document id to a list of sample dicts carrying
            'is_pronoun', 'discourse_old', 'entity' and 'sentence'.
        threshold: probability threshold for pronominalising a candidate.

    Returns:
        (data, num_pronoun, precision, recall, F1), where the scores
        compare the decisions against the samples' 'is_pronoun' flags.
        Precision, recall and F1 are 0 when their denominator is 0.
    """
    total, find_total, correct, num_pronoun = 0.0, 0.0, 0.0, 0.0
    for doc_id in data:
        samples = data[doc_id]
        for idx, sample in enumerate(samples):
            if sample['is_pronoun']:
                total += 1

            # First mentions and entities without a known pronoun are
            # never pronominalised.
            if not sample['discourse_old'] or sample['entity'] not in get_pronoun:
                sample['realise_pronoun'] = False
                continue

            # A pronoun would be ambiguous if a recent earlier mention
            # (previous sentence or later) maps to the same base pronoun.
            has_competiter = any(
                earlier['sentence'] >= sample['sentence'] - 1
                and is_competiter(sample['entity'], earlier['entity'])
                for earlier in samples[:idx]
            )
            if has_competiter:
                sample['realise_pronoun'] = False
                continue

            if random.random() <= threshold:
                sample['realise_pronoun'] = True
                find_total += 1
                num_pronoun += 1
                if sample['is_pronoun']:
                    correct += 1
            else:
                sample['realise_pronoun'] = False

    precision = correct / find_total if find_total else 0
    # Guard against data without any gold pronoun (or no data at all),
    # which previously raised ZeroDivisionError.
    recall = correct / total if total else 0
    if precision == 0:
        F1 = 0
    else:
        F1 = 2 * precision * recall / (precision + recall)
    return data, num_pronoun, precision, recall, F1

def _inflect_pronoun(pronoun, syntax):
    """Pick the surface form of a base pronoun for a syntactic slot."""
    if pronoun == 'he':
        if syntax == 'np-obj':
            return 'him'
        if syntax == 'subj-det':
            return 'his'
        return pronoun
    if pronoun == 'she':
        # Everything except the subject slot uses 'her'.
        return pronoun if syntax == 'np-subj' else 'her'
    if pronoun == 'it':
        return 'its' if syntax == 'subj-det' else 'it'
    if pronoun == 'they':
        if syntax == 'np-obj':
            return 'them'
        if syntax == 'subj-det':
            return 'their'
        return pronoun
    # Any other pronoun is used as stored.
    return pronoun


def realisation(data):
    """Write the surface string of every mention into sample['realised'].

    For each sample, in this order of precedence:

    * an entity whose unquoted text starts with a YYYY-MM-DD date is
      rendered by realize_date, whatever the pronoun decision was;
    * if 'realise_pronoun' equals False, the entity name is cleaned
      (quotes and underscores become spaces), tokenised with nltk and
      lower-cased when `is_lower` is set;
    * if 'realise_pronoun' equals True, the entity's pronoun from
      `get_pronoun` is inflected according to 'syntax'.

    Any other 'realise_pronoun' value prints the sample and terminates
    the process with exit status -1.

    Returns `data`, which is modified in place.
    """
    for samples in data.values():
        for sample in samples:
            entity = sample['entity']
            syntax = sample['syntax']
            form = sample['realise_pronoun']

            unquoted = entity.replace('\'', '').replace('\"', '')
            if re.match(regex, unquoted) is not None:
                sample['realised'] = realize_date(entity)
                continue

            if form == False:
                spaced = entity.replace('\'', ' ').replace('\"', ' ').replace('_', ' ')
                tokenised = ' '.join(nltk.word_tokenize(spaced))
                surface = ' '.join(tokenised.split('_'))
                if is_lower:
                    surface = surface.lower()
                sample['realised'] = surface
                continue

            if form == True:
                sample['realised'] = _inflect_pronoun(get_pronoun[entity], syntax)
                continue

            # The decision flag is neither True nor False: abort loudly.
            print("pikachuuuuuuuuuuuuuuuuuuuuuu")
            print(sample)
            sys.exit(-1)
    return data

# The development split is loaded and preprocessed; only the grouped
# documents are kept.
with open(FILE_DEV) as _in:
    dev_samples = json.load(_in)
dev, _ = preprocessing(dev_samples)

with open(FILE_TEST) as _in:
    test_samples = json.load(_in)
test, length_test = preprocessing(test_samples)

# With a threshold of 1 every eligible mention is realised as a pronoun,
# because random.random() never exceeds 1.
tagged_test, num_pronoun, precision, recall, F1 = tagging_form(test, 1)
print(precision, recall, F1, num_pronoun)
realised_test = realisation(tagged_test)

# Full structured output, grouped by document.
json_out_path = os.path.join(FILE_OUT, "undetermine_rule_based_out_no_lower.json")
with open(json_out_path, "w") as _out:
    json.dump(realised_test, _out)

# One realised expression per line, restored to the input order through
# the 'index' recorded during preprocessing.
refex_result = [None] * length_test
for samples in realised_test.values():
    for sample in samples:
        refex_result[sample['index']] = sample['realised'] + "\n"

text_out_path = os.path.join(FILE_OUT, "undetermine_rule_based_out_no_lower.txt")
with open(text_out_path, "w") as _out:
    _out.writelines(refex_result)