import sys
from statistics import mean, median
from collections import Counter, defaultdict

from best.data_iterators import iter_best_files, iter_best_old_files
from best.test_split import test_ids, valid_ids

# Root directories of the newer and the older annotation release.
new_path = 'data/ldc2016e114/data/eng'
old_path = 'data/ldc2016e27v2/data/eng'

# Newer release: keep only documents whose id belongs to the test or the
# validation split.  Older release: keep every document it yields.
docs = []
for candidate in iter_best_files(new_path):
    if candidate.doc_id in test_ids or candidate.doc_id in valid_ids:
        docs.append(candidate)
docs.extend(iter_best_old_files(old_path))

# Accumulators filled by the collection pass over `docs`: number of mentions
# per relation / per event hopper, and the mention elements themselves.
mentions_per_rel, mentions_per_evt = [], []
relation_mentions, event_mentions = [], []

def _mentions_of(parent, mention_tag, doc_id):
    """Return the direct children of ``parent`` whose tag is ``mention_tag``.

    Any other child (an unexpected element, an XML comment, ...) is reported
    on stderr as ``<doc_id> <tag>`` and otherwise skipped.
    """
    mentions = []
    for kid in parent:
        if kid.tag == mention_tag:
            mentions.append(kid)
        else:
            print(doc_id, kid.tag, file=sys.stderr)
    return mentions


# Single pass over all documents: gather every event/relation mention element
# and record how many mentions each event hopper / relation has.  Only real
# mention elements are counted; anything else is reported and ignored.
for doc in docs:
    hoppers = list(doc.ere.iterdescendants('hopper'))
    # Sanity check: every document is expected to contain at least one hopper.
    # Raised explicitly (not via assert) so it also fires under `python -O`.
    if not hoppers:
        raise ValueError("{}: no hoppers found".format(doc.doc_id))

    for hop in hoppers:
        mentions = _mentions_of(hop, 'event_mention', doc.doc_id)
        event_mentions.extend(mentions)
        mentions_per_evt.append(len(mentions))

    for rel in doc.ere.iterdescendants('relation'):
        mentions = _mentions_of(rel, 'relation_mention', doc.doc_id)
        relation_mentions.extend(mentions)
        mentions_per_rel.append(len(mentions))


def _summary(x):
    """Print min, mean (2 decimals), median and max of the numbers in ``x``.

    ``x`` may be any iterable of numbers; it is materialized once, so a
    generator or iterator works too.  An empty input prints "no data"
    instead of raising, because min()/max()/mean() are undefined for it.
    """
    values = list(x)
    if not values:
        print("no data")
        return
    print("min={} avg={:.2f} median={} max={}".format(
        min(values), mean(values), median(values), max(values)))


# Distribution of mentions per relation and per event hopper.
for heading, counts in (("mentions per relation:", mentions_per_rel),
                        ("mentions per event:", mentions_per_evt)):
    print(heading)
    _summary(counts)

print()


def _count_signatures(mentions):
    """Print how often each child-tag signature occurs among ``mentions``.

    A mention's signature is the sorted list of its direct children's tags
    joined with " / ".  One line per signature is printed, most frequent
    first, as the count left-aligned in a 5-character column followed by
    the signature.
    """
    def signature_of(mention):
        return " / ".join(sorted(child.tag for child in mention))

    tally = Counter(map(signature_of, mentions))
    for signature, n in tally.most_common():
        print("{:<5} {}".format(n, signature))


print("relation mention signatures and their count")
_count_signatures(relation_mentions)

print()
print("event mention signatures and their count")
_count_signatures(event_mentions)

# Top-20 trigger texts for each mention kind (`triggers` ends up holding the
# event-trigger counter, as before).
for heading, mentions in (("common relation triggers", relation_mentions),
                          ("common event triggers", event_mentions)):
    print(heading)
    triggers = Counter(
        trigger.text
        for mention in mentions
        for trigger in mention.findall('trigger')
    )
    print(triggers.most_common(20))

print("Typed relation signatures")


# rel_types[type][subtype] counts "<arg1 role> / <arg2 role>" signatures;
# rel_roles collects every argument role seen on any relation mention.
rel_types = defaultdict(lambda: defaultdict(Counter))
rel_roles = set()

for doc in docs:
    for rel in doc.ere.iterdescendants('relation'):
        type_ = rel.attrib['type']
        subtype = rel.attrib['subtype']

        # Visit only relation_mention children.  Other children (e.g. XML
        # comments, which lxml also returns as children) need not have
        # rel_arg1/rel_arg2, and find() would return None -> AttributeError.
        for mention in rel.findall('relation_mention'):
            role1 = mention.find('rel_arg1').attrib['role']
            role2 = mention.find('rel_arg2').attrib['role']
            signature = "{} / {}".format(role1, role2)
            rel_roles.update([role1, role2])
            rel_types[type_][subtype].update([signature])


# Report: each relation type, its subtypes, and their argument-role
# signatures (most frequent first), then overall type/subtype/role totals.
print("relation types and signatures")
for type_name in rel_types:
    print(type_name)
    print('===')
    by_subtype = rel_types[type_name]
    for subtype_name in by_subtype:
        print("\t", subtype_name)
        for sig_text, n in by_subtype[subtype_name].most_common():
            print("\t\t", n, sig_text)
    print()


n_rel_subtypes = sum(map(len, rel_types.values()))
print("{} types, {} subtypes, {} roles".format(
    len(rel_types), n_rel_subtypes, len(rel_roles)))


# evt_types[type][subtype] counts argument-role signatures of event mentions;
# evt_roles collects every em_arg role seen.
evt_types = defaultdict(lambda: defaultdict(Counter))
evt_roles = set()

# A signature is the sorted em_arg roles of one event mention joined by
# " / " (the empty string for a mention without em_arg children).
for doc in docs:
    for mention in doc.ere.iterdescendants('event_mention'):
        kind = mention.attrib['type']
        subkind = mention.attrib['subtype']
        arg_roles = sorted(arg.attrib['role']
                           for arg in mention.findall('em_arg'))
        evt_roles.update(arg_roles)
        evt_types[kind][subkind][" / ".join(arg_roles)] += 1


# Report: each event type, its subtypes, and their argument-role
# signatures (most frequent first), then overall type/subtype/role totals.
print("event types and signatures")
for evt_type, subtypes in evt_types.items():
    print(evt_type)
    print('===')
    for evt_subtype, signature_counts in subtypes.items():
        print("\t", evt_subtype)
        for signature_text, occurrences in signature_counts.most_common():
            print("\t\t", occurrences, signature_text)
    print()


n_evt_subtypes = sum(len(s) for s in evt_types.values())
print("{} types, {} subtypes, {} roles".format(
    len(evt_types), n_evt_subtypes, len(evt_roles)))
