import argparse
from nltk import Tree
from collections import defaultdict

def load_corpus(filename):
    """Read a bracketed treebank file containing one tree per line.

    Each line is stripped of surrounding whitespace and parsed with
    ``nltk.Tree.fromstring``.

    Args:
        filename: path to a UTF-8 text file.

    Returns:
        A list of ``nltk.Tree`` objects, in file order.
    """
    # The context manager closes the file handle deterministically instead
    # of leaving it open until garbage collection.
    with open(filename, encoding="utf8") as f:
        return [Tree.fromstring(line.strip()) for line in f]

def load_discosuite(filename):
    """Load a DiscoSuite comma-separated file grouped by sentence id.

    The first line is treated as a header and skipped. Every following
    non-blank line is stripped and split on commas. Its first field is
    parsed as an integer sentence id, and the remaining fields are appended
    (as a list of strings) to that sentence's entry.

    NOTE: the split is a plain ``str.split(",")`` with no CSV quoting, so a
    field that itself contains a comma is split into several fields.

    Args:
        filename: path to a UTF-8 comma-separated file.

    Returns:
        ``defaultdict(list)`` mapping sentence id -> list of rows, where
        each row is the list of fields following the id.
    """
    suite = defaultdict(list)
    with open(filename, encoding="utf8") as f:
        f.readline()  # skip the header row
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline) that would
                # otherwise crash on int("").
                continue
            row = line.split(",")
            suite[int(row[0])].append(row[1:])
    return suite


def extract_constituents(tree):
    """Collect every constituent of ``tree`` together with its token span.

    Leaves are strings of the form ``"<index>=<token>"``. In the discbracket
    format the index is 0-based; this function shifts it to be 1-based.

    Pre-terminals (nodes whose single child is a string) are not recorded as
    constituents. They only contribute their token index to the span of
    their ancestors.

    Args:
        tree: an ``nltk.Tree`` (any list-like node exposing ``label()``).

    Returns:
        A pair ``(constituents, span)``:
        - ``constituents``: set of ``(label, sorted_span_tuple)`` for this
          node and every non-pre-terminal descendant;
        - ``span``: list of 1-based token indices covered by ``tree``, in
          leaf order (not sorted when the tree is discontinuous).
    """
    if len(tree) == 1 and isinstance(tree[0], str):
        index, _token = tree[0].split("=", 1)
        return set(), [int(index) + 1]

    constituents = set()
    span = []
    for child in tree:
        child_constituents, child_span = extract_constituents(child)
        constituents |= child_constituents
        span.extend(child_span)
    # Spans are stored sorted so that the same set of tokens compares equal
    # regardless of the order in which the leaves appear in the tree.
    constituents.add((tree.label(), tuple(sorted(span))))
    return constituents, span

def is_proj(span):
    """Return True when the integer indices in ``span`` form a gap-free range.

    The check compares the extent of the span (``highest - lowest + 1``)
    with its size. It assumes the indices are distinct. An empty ``span``
    raises ``ValueError``.
    """
    lowest, highest = min(span), max(span)
    return highest - lowest + 1 == len(span)
    
def filter_constituents(all_consts):
    """Split constituents into contiguous (projective) and discontinuous ones.

    The ``"ROOT"`` constituent is skipped.

    Args:
        all_consts: iterable of ``(label, span)`` pairs, where ``span`` is a
            sorted tuple of token indices (as produced by
            ``extract_constituents``).

    Returns:
        ``(projective, discontinuous)``: two dicts mapping span -> label.
        In ``projective``, several constituents may share a span (unary
        rewrites). Only the one seen last in iteration order of
        ``all_consts`` is kept.

    Raises:
        ValueError: if two discontinuous constituents share the same span.
    """
    projective = {}
    discontinuous = {}
    for label, span in all_consts:
        if label == "ROOT":
            continue
        if is_proj(span):
            projective[span] = label
        else:
            # An explicit check (not `assert`) so it still runs under -O.
            if span in discontinuous:
                raise ValueError(
                    f"duplicate discontinuous span {span}: "
                    f"{discontinuous[span]!r} and {label!r}"
                )
            discontinuous[span] = label
    return projective, discontinuous

def expand_span(span_str):
    """Turn a bracketed span description into a tuple of token indices.

    ``span_str`` is a sequence of bracketed pieces such as ``"[3-5][8]"``.
    A piece ``a-b`` expands to every integer from ``a`` to ``b`` inclusive,
    and a single number stands for itself. Pieces are expanded in order,
    so ``"[3-5][8]"`` gives ``(3, 4, 5, 8)``.
    """
    indices = []
    for chunk in span_str.split("]"):
        chunk = chunk.strip("][")
        if not chunk:
            continue
        if "-" not in chunk:
            indices.append(int(chunk))
            continue
        first, last = chunk.split("-")
        indices += range(int(first), int(last) + 1)
    return tuple(indices)


def main():
    """Score predicted discontinuous constituents against DiscoSuite items.

    Positional command-line arguments are the DiscoSuite file, the gold
    treebank and the predicted treebank (one tree per line, aligned by line
    number).

    For every DiscoSuite item, the item's span must be a discontinuous
    constituent of the gold tree for the same sentence. An item counts as
    unlabeled-correct when the predicted tree has a discontinuous
    constituent with that span. It counts as labeled-correct when the
    labels also match.

    Prints one tab-separated line per phenomenon with labeled accuracy (%),
    unlabeled accuracy (%) and the number of items.

    Raises:
        ValueError: if a DiscoSuite span is not a discontinuous constituent
            of the corresponding gold tree.
    """
    LA, UA, TOT = 0, 1, 2  # positions in each phenomenon's score triple
    eval_dict = defaultdict(lambda: [0, 0, 0])
    parser = argparse.ArgumentParser()
    parser.add_argument("discosuite")
    parser.add_argument("gold")
    parser.add_argument("pred")
    args = parser.parse_args()

    gold_trees = load_corpus(args.gold)
    pred_trees = load_corpus(args.pred)
    if len(gold_trees) != len(pred_trees):
        # zip() would silently truncate and drop sentences from the evaluation.
        parser.error(
            f"gold has {len(gold_trees)} trees but pred has {len(pred_trees)}"
        )
    discosuite = load_discosuite(args.discosuite)

    # DiscoSuite sentence ids are 1-based.
    for sent_id, (g, p) in enumerate(zip(gold_trees, pred_trees), start=1):
        # Only discontinuous constituents are used here. The projective
        # dicts are discarded; they are unreliable anyway, because unary
        # chains collapse onto one span.
        _, gdisc = filter_constituents(extract_constituents(g)[0])
        _, pdisc = filter_constituents(extract_constituents(p)[0])

        for item in discosuite.get(sent_id, []):
            phenomenon = item[0]
            span = expand_span(item[1])

            if span not in gdisc:
                raise ValueError(
                    f"sentence {sent_id}: DiscoSuite item {phenomenon} {span} "
                    f"{item[2:]} is not a discontinuous constituent of the "
                    f"gold tree (gold discontinuous constituents: {gdisc})"
                )
            score = eval_dict[phenomenon]
            score[TOT] += 1
            if span in pdisc:
                score[UA] += 1
                if pdisc[span] == gdisc[span]:
                    score[LA] += 1

    for phenomenon in sorted(eval_dict):
        la, ua, total = eval_dict[phenomenon]
        if phenomenon == "PH/RE":
            phenomenon = "Placeholder/repeated element"
        print(f"{phenomenon}\t{la/total*100:.1f}\t{ua/total*100:.1f}\t{total}")

# Run only when executed as a script, so importing this module (e.g. to
# reuse the helper functions) does not parse sys.argv and start evaluating.
if __name__ == "__main__":
    main()
