import sys
from conllu import parse_incr

from utils import Meter


# UPOS values treated as punctuation when --ignore-punct is given: 'PUNCT' is
# the Universal Dependencies tag; '.' is also accepted as a punctuation tag.
PUNCT_TAGS = frozenset({'.', 'PUNCT'})


def parse_args(argv):
    """Parse the command line ``[-i|--ignore-punct] GOLD PRED -v``.

    Only the last four positions are inspected, mirroring the original
    contract: the trailing ``-v`` is mandatory and the optional punctuation
    flag, if present, sits directly before the two file paths.

    Args:
        argv: Full argument vector, ``argv[0]`` being the program name.

    Returns:
        Tuple ``(gold_path, pred_path, ignore_punct)``.

    Raises:
        SystemExit: with a usage message when the arguments do not match.
    """
    usage = f'usage: {argv[0] if argv else "evaluate"} [-i|--ignore-punct] GOLD PRED -v'
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if len(argv) not in (4, 5) or argv[-1] != '-v':
        sys.exit(usage)
    ignore_punct = False
    if len(argv) == 5:
        if not argv[-4].startswith(('-i', '--ignore-punct')):
            sys.exit(usage)
        ignore_punct = True
    return argv[-3], argv[-2], ignore_punct


def iter_attachment_scores(gold_trees, pred_trees, ignore_punct=False):
    """Yield per-token ``(head_correct, labeled_correct)`` flags as 0/1 ints.

    Sentences and tokens are paired positionally. Tokens whose gold ``head``
    is ``None`` (as produced for lines carrying no head) are not scored. When
    ``ignore_punct`` is true, tokens whose gold ``upos`` is in ``PUNCT_TAGS``
    are also skipped. A token is labeled-correct only if both its head and its
    ``deprel`` match the gold annotation.

    Args:
        gold_trees: Sequence of gold sentences, each a sequence of token
            mappings with ``form``, ``upos``, ``head`` and ``deprel`` keys.
        pred_trees: Predicted sentences, aligned with ``gold_trees``.
        ignore_punct: Skip punctuation tokens when true.

    Raises:
        ValueError: if the sentence counts, token counts or word forms of the
            two inputs do not line up.
    """
    if len(gold_trees) != len(pred_trees):
        raise ValueError(
            f'sentence count mismatch: {len(gold_trees)} gold vs {len(pred_trees)} predicted')
    for sent_idx, (gold_tree, pred_tree) in enumerate(zip(gold_trees, pred_trees)):
        if len(gold_tree) != len(pred_tree):
            raise ValueError(
                f'sentence {sent_idx}: token count mismatch '
                f'({len(gold_tree)} gold vs {len(pred_tree)} predicted)')
        for gw, pw in zip(gold_tree, pred_tree):
            if gw['form'] != pw['form']:
                raise ValueError(
                    f'sentence {sent_idx}: form mismatch {gw["form"]!r} vs {pw["form"]!r}')
            if gw['head'] is None:
                continue
            if ignore_punct and gw['upos'] in PUNCT_TAGS:
                continue
            head_ok = int(gw['head'] == pw['head'])
            label_ok = int(head_ok and gw['deprel'] == pw['deprel'])
            yield head_ok, label_ok


def main(argv):
    """Print UAS and LAS (percent, two decimals) of PRED against GOLD."""
    gold_path, pred_path, ignore_punct = parse_args(argv)
    # `with` ensures both input files are closed once fully parsed.
    with open(gold_path, encoding='utf-8') as gold_file:
        gold_trees = list(parse_incr(gold_file))
    with open(pred_path, encoding='utf-8') as pred_file:
        pred_trees = list(parse_incr(pred_file))

    uas, las = Meter(), Meter()
    n_scored = 0
    try:
        for head_ok, label_ok in iter_attachment_scores(gold_trees, pred_trees, ignore_punct):
            uas.update(head_ok)
            las.update(label_ok)
            n_scored += 1
    except ValueError as err:
        sys.exit(f'error: {err}')
    # Avoid asking Meter for an average over zero observations.
    if n_scored == 0:
        sys.exit('error: no tokens were scored')
    print(f'UAS | {uas.average * 100:.02f} ')
    print(f'LAS | {las.average * 100:.02f} ')


if __name__ == '__main__':
    main(sys.argv)
