from naming_conventions import (
    languages,
    languages_lowercase,
    train_languages,
    validation_languages,
    my_order,
)
from xmljson import abdera, badgerfish
from json import dumps
import json
import matplotlib.pyplot as plt
import numpy as np
import sys
from xml.etree.ElementTree import fromstring
from collections import defaultdict

# deps_per_language[treebank][label] -> occurrence count of that dependency
# label in the treebank; the extra "total" key accumulates all label counts.
deps_per_language = defaultdict(lambda: defaultdict(int))
# avg_sent_length[treebank] -> average number of words per sentence.
avg_sent_length = defaultdict(int)

for lan in languages:
    # Read raw bytes so the XML parser decodes the file according to its own
    # declaration instead of the platform's default text encoding.
    with open("data/ud-treebanks-v2.3/" + lan + "/stats.xml", "rb") as f:
        stats = badgerfish.data(fromstring(f.read()))

    totals = stats["treebank"]["size"]["total"]
    words = totals["words"]["$"]
    sents = totals["sentences"]["$"]
    avg_sent_length[lan] = words / sents

    dep_entries = stats["treebank"]["deps"]["dep"]
    if isinstance(dep_entries, dict):
        # BadgerFish only produces a list for repeated sibling tags; a
        # treebank with a single <dep> element yields one mapping instead.
        dep_entries = [dep_entries]
    for entry in dep_entries:
        deps_per_language[lan][entry["@name"]] = entry["$"]
        deps_per_language[lan]["total"] += entry["$"]

    print(lan.split("-"), words, sents, words / sents)


# Mean of the per-treebank average sentence lengths.
avg_length = np.mean([avg_sent_length[lan] for lan in languages])
# Every key seen in any treebank's counts (this includes the "total" key).
all_deps = {dep for counts in deps_per_language.values() for dep in counts}
# Labels that pass the frequency threshold below.
all_deps_new = set()

# Parallel lists consumed by the bar chart: label names and their values.
d = []
l = []

avg_per_dep = []

# Sum of all label counts over all treebanks; it does not depend on the
# label, so compute it once instead of once per label.
occurrences_total = sum(deps_per_language[lan]["total"] for lan in languages)

# Sort the labels so the printed indices and the bar order are the same on
# every run (set iteration order of strings varies between interpreter runs).
for i, dep in enumerate(sorted(all_deps)):
    occurrences = sum(deps_per_language[lan][dep] for lan in languages)
    # Share of this label among all labels, scaled by 20 average-length
    # sentences. NOTE(review): 20 is presumably the support-set size - confirm.
    expected = avg_length * 20 * (occurrences / occurrences_total)
    print(i, dep, int(expected) if expected < 0.01 else round(expected, 2))
    # Keep only frequent labels; "total" is bookkeeping, not a real label.
    if occurrences > 60000 and dep != "total":
        d.append(dep)
        avg_per_dep.append(expected)
        l.append(expected)
        all_deps_new.add(dep)

# Bar chart of the values collected in `l`, one bar per label in `d`.
plt.bar(d, l)
plt.xticks([])  # suppress the per-label tick text on the x axis
plt.xlabel("Class (dependency label)")
plt.ylabel("Expected occurrences in support set")  # spelling fixed
plt.savefig("results/deps.pdf")

# From here on only the frequent labels are used, in alphabetical order.
all_deps = sorted(all_deps_new)

# LaTeX header row: characters 3..6 of every training treebank name.
header_cells = [name[3:7] for name in train_languages]
print(" & ".join(header_cells), "\\\\")


# One LaTeX row per label: for each training treebank, the label's share of
# all labels scaled by 20 sentences of that treebank's average length,
# truncated to an integer.
for dep in all_deps:
    row_cells = []
    for lan in train_languages:
        share = deps_per_language[lan][dep] / deps_per_language[lan]["total"]
        row_cells.append(str(int(avg_sent_length[lan] * 20 * share)))
    print(dep, " & ".join(row_cells), "\\\\")

print("\n\n\n\n")

# Treebanks shown in this table: not used for training, no "EWT" in the name,
# and not among the first six entries of `my_order`.
first_six = my_order[:6]
held_out = [
    name
    for name in languages
    if not (name in train_languages or "EWT" in name or name in first_six)
]

# LaTeX header row with an empty leading cell for the label column.
print("&", " & ".join([name[3:7] for name in held_out]), "\\\\")


# One LaTeX row per label (name cut to four characters). Each cell holds the
# truncated expected count for 20 average-length sentences, followed by the
# raw count in parentheses.
for dep in all_deps:
    cells = []
    for lan in held_out:
        count = deps_per_language[lan][dep]
        share = count / deps_per_language[lan]["total"]
        expected = int(avg_sent_length[lan] * 20 * share)
        cells.append(f"{expected}({count})")
    print(dep[:4], "&", " & ".join(cells), "\\\\")


# Treebanks shown in this table: not used for training, no "EWT" in the name,
# and listed among the first six entries of `my_order`.
leading_six = my_order[:6]
selected = []
for name in languages:
    if name in train_languages or "EWT" in name:
        continue
    if name in leading_six:
        selected.append(name)

# LaTeX header row with an empty leading cell for the label column.
print("&", " & ".join(name[3:7] for name in selected), "\\\\")


# One LaTeX row per label (name cut to four characters). Each cell holds the
# truncated expected count followed by the raw count in parentheses.
# NOTE(review): the scale factor here is 40, whereas the other tables in this
# script use 20 - presumably a larger support set; confirm this is intended.
for dep in all_deps:
    cells = []
    for lan in selected:
        count = deps_per_language[lan][dep]
        total = deps_per_language[lan]["total"]
        expected = int(avg_sent_length[lan] * 40 * (count / total))
        cells.append(f"{expected}({count})")
    print(dep[:4], "&", " & ".join(cells), "\\\\")
