
import unittest
import torch
from emb2emb.gmm import _gmm_kl_variational


class TestGMM(unittest.TestCase):
    """Consistency test for the two code paths of ``_gmm_kl_variational``.

    The function under test is called once with ``naive=True`` and once with
    ``naive=False`` on identical inputs; the two results must agree.
    """

    def setUp(self):
        """Create the fixed inputs shared by the tests in this class.

        Two groups of tensors are built, ``X_*`` and ``Y_*``:

        * means:   shape ``(4, 10, 16)`` for X and ``(4, 8, 16)`` for Y,
          drawn from a standard normal distribution;
        * sigma:   all ones, same shape as the corresponding means;
        * weights: shape ``(4, 10)`` / ``(4, 8)``, every entry ``1/10`` /
          ``1/8`` so each row sums to one.

        NOTE(review): the axes presumably mean (batch, mixture components,
        embedding dim) -- confirm against ``_gmm_kl_variational``.
        """
        # A private, seeded generator makes the random means identical on
        # every run (so failures are reproducible) without touching the
        # global torch RNG state that other tests may rely on.
        rng = torch.Generator().manual_seed(0)

        self.X_means = torch.randn((4, 10, 16), generator=rng)
        self.X_sigma = torch.ones((4, 10, 16))
        self.X_weights = torch.ones((4, 10)) / 10.

        self.Y_means = torch.randn((4, 8, 16), generator=rng)
        self.Y_sigma = torch.ones((4, 8, 16))
        self.Y_weights = torch.ones((4, 8)) / 8.

    def test_naive_fast_equivalence(self):
        """``naive=True`` and ``naive=False`` must give element-wise close results."""
        # Both calls receive exactly the same positional arguments; only the
        # ``naive`` switch differs.
        shared_args = (
            self.X_means, self.Y_means,
            self.X_sigma, self.Y_sigma,
            self.X_weights, self.Y_weights,
        )

        naive_result = _gmm_kl_variational(*shared_args, naive=True)
        fast_result = _gmm_kl_variational(*shared_args, naive=False)

        # Use a unittest assertion rather than a bare ``assert``: it is not
        # stripped under ``python -O`` and it reports both values on failure
        # (replacing the former unconditional debug prints).
        self.assertTrue(
            torch.isclose(naive_result, fast_result).all().item(),
            msg="naive and fast results differ:\nnaive={}\nfast={}".format(
                naive_result, fast_result),
        )


# Run the tests in this module via unittest's command-line runner when the
# file is executed directly as a script (not when it is imported).
if __name__ == "__main__":
    unittest.main()
