import unittest

import numpy as np

from fairseq.data.data_utils import take_percent_of_dataset


class TestDataUtils(unittest.TestCase):
    """Checks ``take_percent_of_dataset`` against small hand-written fixtures.

    Each case passes a 10-element index array together with a fraction and
    compares the returned array, converted to a plain list, with the
    expected indices.
    """

    def _assert_subset(self, indices, fraction, expected):
        # Shared arrange/act/assert: wrap the fixture in a NumPy array, call
        # the function under test, and compare plain Python lists so that
        # unittest prints a readable diff on failure.
        subset = take_percent_of_dataset(np.array(indices), fraction)
        self.assertEqual(subset.tolist(), list(expected))

    def test_take_60_percent_of_shuffled_indices(self):
        # The expected result holds the six entries of the shuffled input
        # whose value is below 6, in the same order as they appear in it.
        self._assert_subset(
            indices=[0, 3, 7, 4, 6, 5, 2, 1, 9, 8],
            fraction=0.6,
            expected=[0, 3, 4, 5, 2, 1],
        )

    def test_take_50_percent_of_ordered_indices(self):
        # For an already ordered input the expected result is simply the
        # first five entries.
        self._assert_subset(
            indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            fraction=0.5,
            expected=[0, 1, 2, 3, 4],
        )


# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
