import json
import os
import sys
from collections import defaultdict
import nltk
import re

# Input / output locations. These are blank placeholders that must be filled
# in before the script is run.
PRO_FILE = " "     # JSON object: entity -> pronoun
GENDER_FILE = " "  # JSON object: entity -> gender label
ENTITY_FILE = " "  # JSON object keyed by entity, loaded alongside GENDER_FILE
FILE_IN = " "      # JSON list of mention samples
FILE_OUT = " "     # directory that receives the rule_based_out_* files

# A YYYY-MM-DD date; used with re.match, so only the start of the text is anchored.
regex = r'([0-9]{4})-([0-9]{2})-([0-9]{2})'
# When True, realised proper names are lower-cased.
is_lower = False

# Each base pronoun together with all of its case forms.
_PRONOUN_GROUPS = (
    ('he', ('he', 'his', 'him')),
    ('she', ('she', 'her', 'hers')),
    ('it', ('it', 'its')),
    ('we', ('we', 'our', 'ours', 'us')),
    ('they', ('they', 'their', 'theirs', 'them')),
)
# Maps every pronoun form to its base pronoun; entities whose pronouns share
# a base compete with each other for pronominalisation.
competitors = {form: base for base, forms in _PRONOUN_GROUPS for form in forms}
        

def _get_pronoun_from_feature(gender, entity_type):
    if gender == "male":
        return "he"
    elif gender == "female":
        return "she"
    elif gender == "neutral":
        return "it"
    else:
        return None


def extend_get_pronoun(get_pronoun, gender_info, entity_info):
    """Return a copy of *get_pronoun* extended with pronouns derived from gender.

    Every entry of *get_pronoun* is kept. Explicit pronouns always win. Entities
    that appear only in *gender_info* get the pronoun returned by
    ``_get_pronoun_from_feature``, when it returns one.

    Args:
        get_pronoun: dict mapping entity -> pronoun (not modified).
        gender_info: dict mapping entity -> gender label.
        entity_info: dict keyed by entity; the looked-up value is forwarded as
            the entity type. Entities missing from it are tolerated.

    Returns:
        A new dict mapping entity -> pronoun.
    """
    # Start from the explicit pronouns so entities absent from gender_info are
    # not silently dropped.
    updated = dict(get_pronoun)
    for entity, gender in gender_info.items():
        if entity in updated:
            continue
        # .get: the entity type is not required to derive a pronoun, so a
        # missing entity record must not abort the whole run.
        pronoun = _get_pronoun_from_feature(gender, entity_info.get(entity))
        if pronoun is not None:
            updated[entity] = pronoun
    return updated
        

def is_parallel(entity, antecedent):
    """Tell whether two mentions fill the same grammatical role.

    Both mentions are dicts with a 'syntax' entry. They are parallel when both
    syntax tags contain "subj", or else when both contain "obj".
    """
    return any(
        role in entity['syntax'] and role in antecedent['syntax']
        for role in ("subj", "obj")
    )


def in_focus(entity, focus):
    """Return True if any mention in *focus* refers to *entity*.

    *focus* is an iterable of mention dicts, each carrying an 'entity' key.
    """
    return any(mention['entity'] == entity for mention in focus)


def is_competiter(entity1, entity2, pronoun_map=None, groups=None):
    """Return True if two entities would be referred to by the same pronoun.

    Args:
        entity1, entity2: entity identifiers.
        pronoun_map: dict entity -> pronoun; defaults to the module-level
            ``get_pronoun``.
        groups: dict pronoun form -> base pronoun; defaults to the
            module-level ``competitors``.

    Returns:
        False if either entity has no pronoun. Otherwise, whether both pronouns
        reduce to the same base form. A pronoun missing from *groups* is
        treated as its own base instead of raising KeyError.
    """
    if pronoun_map is None:
        pronoun_map = get_pronoun
    if groups is None:
        groups = competitors
    if entity1 not in pronoun_map or entity2 not in pronoun_map:
        return False
    pronoun1 = pronoun_map[entity1]
    pronoun2 = pronoun_map[entity2]
    return groups.get(pronoun1, pronoun1) == groups.get(pronoun2, pronoun2)

# English month names, indexed by month number - 1.
_MONTH_NAMES = (
    'january', 'february', 'march', 'april', 'may', 'june',
    'july', 'august', 'september', 'october', 'november', 'december',
)


def _ordinal_day(day):
    """Return day-of-month text *day* with an English ordinal suffix.

    Numeric days lose any leading zero and follow the 11th/12th/13th
    exceptions: '01' -> '1st', '12' -> '12th', '23' -> '23rd'.
    Non-numeric text just gets 'th' appended.
    """
    try:
        number = int(day)
    except ValueError:
        return day + 'th'
    if 11 <= number % 100 <= 13:
        suffix = 'th'
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
    return '%d%s' % (number, suffix)


def realize_date(date):
    """Spell out a 'YYYY-MM-DD' date as '<month name> <ordinal day> <year>'.

    Used when realising mentions annotated with the "date" form. Single and
    double quotes are stripped first, e.g. "'1999-01-13'" -> 'january 13th 1999'.
    Month numbers outside 1-12 are kept as their integer value.

    Raises:
        ValueError: if the text does not split into exactly three '-'-separated
            parts, or if the month part is not an integer.
    """
    year, month, day = date.replace('\'', '').replace('\"', '').split('-')
    month_number = int(month)
    if 1 <= month_number <= 12:
        month_text = _MONTH_NAMES[month_number - 1]
    else:
        month_text = str(month_number)
    return ' '.join([month_text, _ordinal_day(day), year])


# Load the pronoun, gender and entity lexicons (JSON is UTF-8 by definition, so
# don't depend on the platform's default encoding), then fill in pronouns for
# entities that only have a gender label.
with open(PRO_FILE, encoding='utf-8') as _in:
    get_pronoun = json.load(_in)

with open(GENDER_FILE, encoding='utf-8') as _in:
    gender_info = json.load(_in)

with open(ENTITY_FILE, encoding='utf-8') as _in:
    entity_info = json.load(_in)

get_pronoun = extend_get_pronoun(get_pronoun, gender_info, entity_info)

# JSON is UTF-8 by definition; don't depend on the platform default encoding.
with open(FILE_IN, encoding='utf-8') as _in:
    raw_data = json.load(_in)

# Group mentions by document. The key joins the sample's 'eid' and 'lid'.
test_data = defaultdict(list)

for idx, sample in enumerate(raw_data):
    test_data[sample['eid'] + sample['lid']].append({
        'pre_context': sample['pre_context'],
        # Length of the preceding context; used to order mentions within a
        # document (presumably a character offset if pre_context is a string).
        'position': len(sample['pre_context']),
        'sentence': sample['sentence'],
        'refex': sample['refex'],
        'pos_context': sample['pos_context'],
        'entity': sample['entity'],
        'syntax': sample['syntax'],
        # Position in raw_data, so outputs can be written back in input order.
        'index': idx,
    })

# Order each document's mentions by position and mark discourse features:
#   discourse_old - the entity was already mentioned earlier in the document;
#   focus         - the mention is discourse-old or its syntax tag contains "subj".
for doc_id in test_data:
    mentions = sorted(test_data[doc_id], key=lambda e: e['position'])
    test_data[doc_id] = mentions
    # One pass with a seen-set replaces rescanning all earlier mentions (O(n^2)).
    seen = set()
    for sample in mentions:
        sample['discourse_old'] = sample['entity'] in seen
        sample['focus'] = sample['discourse_old'] or 'subj' in sample['syntax']
        seen.add(sample['entity'])

# Annotation phase: choose a form ("date", "proper" or "pronoun") for every
# mention.
#   1. Date-looking entities are dates; entities without a pronoun are proper.
#   2. Strong parallelism: the same entity in the previous sentence in the same
#      grammatical role -> pronoun.
#   3. Otherwise pronoun only if the entity is in the focus set of the previous
#      sentence and no other focused entity competes for the same pronoun.
for doc_id in test_data:
    for idx, sample in enumerate(test_data[doc_id]):
        entity = sample['entity']
        if re.match(regex, entity.replace('\'', '').replace('\"', '')) is not None:
            sample['form'] = "date"
            continue
        if entity not in get_pronoun:
            sample['form'] = "proper"
            continue

        focus, form = [], None
        previous_sentence = sample['sentence'] - 1
        # Build the focus set from focused mentions of the previous sentence,
        # stopping early on a strongly parallel mention of the same entity.
        for ss in test_data[doc_id][:idx]:
            if ss['sentence'] != previous_sentence:
                continue
            if ss['focus']:
                focus.append(ss)
            if ss['entity'] == entity and is_parallel(sample, ss):
                form = "pronoun"
                break

        # Only fall back to the focus rules when parallelism did not decide.
        # Previously an empty (possibly truncated) focus list overwrote a
        # parallelism "pronoun" with "proper".
        if form is None:
            if not focus or not in_focus(entity, focus):
                form = "proper"
            elif any(antecedent['entity'] != entity
                     and is_competiter(antecedent['entity'], entity)
                     for antecedent in focus):
                form = "proper"
            else:
                form = "pronoun"
        sample['form'] = form

def _inflect_pronoun(pronoun, syntax):
    """Return the case form of base *pronoun* for a mention whose tag is *syntax*.

    'np-obj' selects the object form and 'subj-det' the possessive determiner.
    'she' becomes 'her' for every tag except 'np-subj'. Unknown pronouns or tags
    leave the pronoun unchanged.
    """
    if pronoun == 'she':
        # 'her' serves as both the object and the possessive determiner.
        return 'she' if syntax == 'np-subj' else 'her'
    case_forms = {
        'he': {'np-obj': 'him', 'subj-det': 'his'},
        'it': {'subj-det': 'its'},
        'we': {'np-obj': 'us', 'subj-det': 'our'},
        'they': {'np-obj': 'them', 'subj-det': 'their'},
    }
    return case_forms.get(pronoun, {}).get(syntax, pronoun)


num_proper, num_pronoun, num_date = 0, 0, 0
# Realisation: turn each annotated form into its surface string.
for doc_id in test_data:
    for sample in test_data[doc_id]:
        entity = sample['entity']
        form = sample['form']
        if form == "proper":
            num_proper += 1
            # Quotes and underscores become spaces before tokenising.
            cleaned = entity.replace('\'', ' ').replace('\"', ' ').replace('_', ' ')
            surface = ' '.join(nltk.word_tokenize(cleaned))
            sample['realised'] = surface.lower() if is_lower else surface
        elif form == "pronoun":
            num_pronoun += 1
            sample['realised'] = _inflect_pronoun(get_pronoun[entity], sample['syntax'])
        elif form == "date":
            num_date += 1
            sample['realised'] = realize_date(entity)
        else:
            # Unknown form: dump the offending sample and abort.
            print("pikachuuuuuuuuuuuuuuuuuuuuuu")
            print(sample)
            sys.exit(-1)

# One realised expression per input sample, restored to the input order.
refex_result = [None] * len(raw_data)

for doc_id in test_data:
    for sample in test_data[doc_id]:
        refex_result[sample['index']] = sample['realised'] + "\n"

print(num_proper, num_pronoun, num_date)
# Name outputs after the casing mode; previously lower-cased runs were also
# written as "..._not_lower".
out_stem = "rule_based_out_lower" if is_lower else "rule_based_out_not_lower"
with open(os.path.join(FILE_OUT, out_stem + ".json"), "w", encoding="utf-8") as _out:
    json.dump(test_data, _out)

# Realised names may contain non-ASCII characters; write UTF-8 explicitly.
with open(os.path.join(FILE_OUT, out_stem + ".txt"), "w", encoding="utf-8") as _out:
    _out.writelines(refex_result)