
from datasets.wmt21 import load_human, load_metric
from datasets.spot_the_bot import load_annotated, load_additional

from emnlp2022.showcase_appendix import WMT_SYSTEMS
from emnlp2022.showcase_appendix import STB_SYSTEMS


def _print_counts(title, by_system, systems, attr):
    """Print ``title``, one ``"<system> <count>"`` line per system, then a blank line.

    Args:
        title: Header line printed before the per-system counts.
        by_system: Mapping indexed by system name (``by_system[system]``).
        systems: Iterable of system names to report, in print order.
        attr: Name of the attribute on each entry whose ``shape[0]`` is
            reported (e.g. ``"binary_scores"`` or ``"scores"``).
    """
    print(title)
    for system in systems:
        # NOTE(review): assumes the attribute is array-like with a ``.shape``
        # (e.g. a numpy array); the count is its first-axis length.
        print(system, getattr(by_system[system], attr).shape[0])
    print()


def wmt():
    """Print per-system sample counts for the WMT human and metric data.

    Reports, for every system in ``WMT_SYSTEMS``:
      * the number of human ``binary_scores`` entries,
      * the number of ``bleurt-20-ref-C`` metric scores,
      * the number of ``sentBLEU-ref-C`` metric scores.
    Each section is followed by a blank line.
    """
    print("WMT")

    wmt_human = load_human()
    wmt_metric = load_metric()

    _print_counts("N HUMAN", wmt_human, WMT_SYSTEMS, "binary_scores")
    _print_counts("N bleurt", wmt_metric["bleurt-20-ref-C"], WMT_SYSTEMS, "scores")
    _print_counts("N sentBLEU", wmt_metric["sentBLEU-ref-C"], WMT_SYSTEMS, "scores")


def stb():
    """Print per-bot sample counts for the Spot-the-Bot ``convai2`` data.

    First reports the number of human ``binary_scores`` entries for every bot
    in ``STB_SYSTEMS``. Then, for the ``usr_ret`` metric, reports the number
    of scores in the annotated set, the number in the additional set, and
    their sum, formatted as ``"<a> + <b> = <a+b>"``.
    """
    print("STB")

    annotated_all = load_annotated()
    additional_all = load_additional()

    convai2 = annotated_all['convai2']
    human_by_bot = convai2['human']
    metric_by_name = convai2['metric']
    extra_by_name = additional_all['convai2']

    print("N Human")
    for bot_name in STB_SYSTEMS:
        print(bot_name, human_by_bot[bot_name].binary_scores.shape[0])
    print()

    print("N usr_ret")
    usr_annotated = metric_by_name["usr_ret"]
    usr_extra = extra_by_name["usr_ret"]
    for bot_name in STB_SYSTEMS:
        annotated_entry = usr_annotated[bot_name]
        extra_entry = usr_extra[bot_name]
        n_annotated = annotated_entry.scores.shape[0]
        n_extra = extra_entry.scores.shape[0]
        print(bot_name, f"{n_annotated} + {n_extra} = {n_annotated + n_extra}")

    print()


if __name__ == "__main__":
    # Print the WMT report, a divider line of asterisks, then the STB report.
    divider = '*' * 20
    wmt()
    print(divider)
    stb()
