
import os
import unittest

from src.pipeline import proposed_evaluation, official_conll_evaluation
from tests.utils import get_test_folders, compare_folders


# Absolute directory containing this test module; expected/output test
# folders are resolved relative to it.
TEST_DIR = os.path.split(os.path.abspath(__file__))[0]

################################################################################
################################################################################


class TestEvaluationConll09(unittest.TestCase):
    """Regression tests for the CoNLL 2009 evaluation pipeline.

    Each test runs both the official CoNLL evaluation and the proposed
    evaluation on a gold/predicted CoNLL-U file pair, then checks the output
    folder against the expected folder with `compare_folders`.
    """

    def setUp(self):
        """Data folder is the top-level `data` folder, so we do not use the
        input folder structure like the other tests.
        """
        # NOTE(review): 'data' is a relative path, so it is resolved against
        # the current working directory -- these tests assume they are run
        # from the repository root.
        self.data_folder = 'data'
        self.test_folder = os.path.join(TEST_DIR, 'data')

    def _evaluate_and_compare(self, test_name, gold_filename, pred_filename):
        """Run both evaluations on one gold/predicted pair and check outputs.

        Args:
            test_name: Name of the test case, passed to `get_test_folders` to
                locate its expected and output folders under
                `self.test_folder`.
            gold_filename: Gold-standard CoNLL-U file name inside
                `self.data_folder`.
            pred_filename: Predicted CoNLL-U file name inside
                `self.data_folder`.
        """
        # The input folder is not used: inputs come from the shared top-level
        # data folder instead (see `setUp`).
        expected_folder, _, output_folder = get_test_folders(
            self.test_folder, test_name)

        gold_conllu_fp = os.path.join(self.data_folder, gold_filename)
        pred_conllu_fp = os.path.join(self.data_folder, pred_filename)

        official_conll_evaluation(gold_conllu_fp, pred_conllu_fp, output_folder)
        proposed_evaluation(gold_conllu_fp, pred_conllu_fp, output_folder)

        compare_folders(expected_folder, output_folder)

    def test_eval_conll09_ood_brown(self):
        """Tests the out-of-domain CoNLL 2009 data (Brown test set).
        """
        self._evaluate_and_compare(
            'conll09-brown',
            'conll09.brown.test.conllu',
            'conll09.brown.pred.roberta-base.conllu',
        )

    def test_eval_conll09_indomain_wsj(self):
        """Tests the in-domain CoNLL 2009 data (WSJ test set).
        """
        self._evaluate_and_compare(
            'conll09-wsj',
            'conll09.wsj.test.conllu',
            'conll09.wsj.pred.roberta-base.conllu',
        )

################################################################################
################################################################################


if __name__ == '__main__':
    # Allow running this module directly as a script; `module='__main__'`
    # is unittest.main's default, spelled out for clarity.
    unittest.main(module='__main__')
