
import unittest
import torch
from emb2emb.hausdorff import hausdorff_similarity, _local_hausdorff_similarities, local_bag_losses, get_local_hausdorff_similarities_function
import time
from itertools import product


class TestHausdorff(unittest.TestCase):
    """Tests for the Hausdorff similarity helpers imported from ``emb2emb.hausdorff``.

    The shared fixtures ``self.X`` and ``self.Y`` are two batches of three
    bags, each bag holding two 2-dimensional points.
    """

    def setUp(self):
        """Build the ``X`` / ``Y`` fixtures, both of shape (3, 2, 2)."""
        # every bag of X holds the same two points
        self.X = torch.tensor(
            [
                [[1.0, 0.0], [-1.0, 0.0]],
                [[1.0, 0.0], [-1.0, 0.0]],
                [[1.0, 0.0], [-1.0, 0.0]],
            ]
        )

        self.Y = torch.tensor(
            [
                [[2.0, 0.0], [1.0, 1.0]],
                [[2.0, 0.0], [1.0, -1.0]],
                [[2.0, -2.0], [1.0, -1.0]],
            ]
        )

    def _assert_allclose(self, actual, expected):
        """Fail, showing both tensors, unless ``torch.allclose(actual, expected)``.

        The argument order matters: ``torch.allclose`` scales its relative
        tolerance by the second argument.
        """
        if not torch.allclose(actual, expected):
            # the message is only built on failure, tensor reprs are not free
            self.fail("tensors are not close:\nactual={}\nexpected={}".format(
                actual, expected))

    def test_X_Y(self):
        """Compare euclidean similarities of X and Y with hand-computed values."""
        # assertEqual instead of a bare assert, which is stripped under -O
        self.assertEqual(self.X.size(0), self.Y.size(0))

        result = hausdorff_similarity(
            self.X, self.Y, similarity_function="euclidean")

        # These values equal the negated average of the two directed
        # mean-of-minimum euclidean distances between the bags.
        expected_result = [-1.309016994, -1.309016994, -1.618033989]
        self._assert_allclose(result, torch.tensor(expected_result))

        # with mask; the expected values below correspond to only the points
        # marked True taking part in the computation
        X_mask = torch.tensor([
            [True, False],
            [False, True],
            [True, True],
        ])
        Y_mask = torch.tensor([
            [True, True],
            [False, True],
            [True, False],
        ])
        result = hausdorff_similarity(
            self.X, self.Y, mask_X=X_mask, mask_Y=Y_mask,
            similarity_function="euclidean")

        expected_result = [-1., -2.236067977, -2.57845]
        self._assert_allclose(result, torch.tensor(expected_result))

    def test_same(self):
        """The similarity of a batch with itself must be (close to) zero."""
        zeros = torch.zeros(self.X.size(0))
        configurations = [
            # nondifferentiable
            {},
            # differentiable
            {"differentiable": True, "softmax_temp": 0.1},
        ]
        for extra in configurations:
            with self.subTest(**extra):
                # NOTE(review): the original test issues the identical call
                # twice; possibly one call was meant to use naive=False -
                # confirm. Both calls are kept as they were.
                result1 = hausdorff_similarity(
                    self.X, self.X, similarity_function="euclidean", **extra)
                result2 = hausdorff_similarity(
                    self.X, self.X, similarity_function="euclidean", **extra)

                self._assert_allclose(result1, zeros)
                self._assert_allclose(result2, zeros)

    def test_symmetry(self):
        """Swapping the two arguments must not change the result."""
        configurations = [
            # non differentiable
            {},
            # differentiable
            {"differentiable": True},
        ]
        for extra in configurations:
            with self.subTest(**extra):
                result1 = hausdorff_similarity(
                    self.X, self.Y, similarity_function="euclidean", **extra)
                result2 = hausdorff_similarity(
                    self.Y, self.X, similarity_function="euclidean", **extra)
                self._assert_allclose(result1, result2)

    def test_complex_implementation(self):
        """The default call and the ``naive=False`` call must agree."""
        result1 = hausdorff_similarity(
            self.X, self.Y, similarity_function="euclidean")
        result2 = hausdorff_similarity(
            self.X, self.Y, similarity_function="euclidean", naive=False)
        self._assert_allclose(result1, result2)

    def test_return_similarities(self):
        """Similarities returned by one call can be fed back via ``similarities``."""
        result1, sims = hausdorff_similarity(
            self.X, self.Y, similarity_function="euclidean",
            return_similarities=True)
        result2 = hausdorff_similarity(
            self.X, self.Y, similarity_function="euclidean", naive=False,
            similarities=sims)
        self._assert_allclose(result1, result2)

    def test_runtime(self):
        """Print the wall time of both implementations; only checks they run."""
        X = torch.randn((64, 30, 128))
        Y = torch.randn((64, 30, 128))

        # perf_counter is monotonic and meant for measuring short durations
        start = time.perf_counter()
        hausdorff_similarity(
            X, Y, similarity_function="euclidean")
        end = time.perf_counter()
        print("Naive implementation took", end - start)

        start = time.perf_counter()
        hausdorff_similarity(
            X, Y, similarity_function="euclidean", naive=False)
        end = time.perf_counter()
        print("'Complex' implementation took", end - start)

    def _check_local_against_prefix_masks(self, X_orig, Y, Y_mask, differentiable):
        """Check ``_local_hausdorff_similarities`` column by column.

        For every position ``i``, column ``i`` of the local result is compared
        with ``hausdorff_similarity`` evaluated under a mask on X that is one
        for positions ``0..i`` and zero elsewhere. Gradients are compared too.
        """
        shared = dict(similarity_function="euclidean",
                      naive=False,
                      differentiable=differentiable,
                      softmax_temp=0.01)

        X = X_orig.detach().clone()
        X.requires_grad = True
        X_mask = torch.zeros((X.size(0), X.size(1)), device=Y.device)

        result = _local_hausdorff_similarities(
            X, Y, mask_Y=Y_mask, max_X_len=-1, detach=True, **shared)
        result.sum().backward()

        # Xclone is shared by all iterations, so its gradient accumulates the
        # backward passes of every mask seen so far.
        Xclone = X_orig.detach().clone()
        Xclone.requires_grad = True
        for i in range(X_orig.size(1)):
            X_mask[:, i] = 1.
            # continue with a fresh copy so that the tensor handed to the
            # previous call is not the one modified in place next iteration
            X_mask = X_mask.clone().detach()
            res2, _ = hausdorff_similarity(
                Xclone, Y, mask_X=X_mask, mask_Y=Y_mask, similarities=None,
                return_similarities=True, **shared)
            res2.sum().backward()

            # check that the gradients are zero when they should be: positions
            # after i have been masked out in every call so far (one slice
            # comparison instead of a Python loop over the positions)
            self.assertTrue(
                (Xclone.grad[:, i + 1:result.size(1)] == 0).all())
            self._assert_allclose(result[:, i], res2)

            # a fresh leaf receives the gradient of this single mask only; at
            # position i it has to match the gradient of the local variant
            Xclone2 = X_orig.detach().clone()
            Xclone2.requires_grad = True
            res3, _ = hausdorff_similarity(
                Xclone2, Y, mask_X=X_mask, mask_Y=Y_mask, similarities=None,
                return_similarities=True, **shared)
            res3.sum().backward()
            self.assertTrue((Xclone2.grad[:, i] == X.grad[:, i]).all())

    def test_local_hausdorffs(self):
        """Validate the local similarities and compare both local variants."""
        # Used as a context manager, set_detect_anomaly restores the previous
        # global autograd setting on exit instead of leaking it to other tests.
        with torch.autograd.set_detect_anomaly(True):
            X_orig = torch.randn((64, 50, 512))
            Y = torch.randn((64, 50, 512))
            Y_mask = torch.ones((Y.size(0), Y.size(1)), device=Y.device)

            for diff in [False, True]:
                with self.subTest(differentiable=diff):
                    self._check_local_against_prefix_masks(
                        X_orig, Y, Y_mask, diff)

            # test whether naive and non-naive is the same when not detaching
            X = X_orig.detach().clone()
            local_kwargs = dict(mask_Y=Y_mask,
                                similarity_function="cosine",
                                naive=False,
                                # previously the leaked loop variable, whose
                                # last value was True
                                differentiable=True,
                                softmax_temp=0.01,
                                max_X_len=-1,
                                detach=False)

            start = time.perf_counter()
            result1 = _local_hausdorff_similarities(
                X, Y, naive_local=False, **local_kwargs)
            end = time.perf_counter()
            print("Complex-local implementation took", end - start)

            start = time.perf_counter()
            result2 = _local_hausdorff_similarities(
                X, Y, naive_local=True, **local_kwargs)
            end = time.perf_counter()
            print("Naive-local implementation took", end - start)

            self._assert_allclose(result1, result2)

    def test_local_bag_losses(self):
        """``local_bag_losses`` must reproduce ``_local_hausdorff_similarities``.

        All combinations of ``differentiable``, ``naive``, ``naive_local`` and
        ``detach`` are tried, except ``detach`` without ``naive_local``.
        """
        # restore the global anomaly-detection setting when the test is done
        with torch.autograd.set_detect_anomaly(True):
            X_orig = torch.randn((3, 10, 32))
            Y = torch.randn((3, 10, 32))
            Y_mask = torch.ones((Y.size(0), Y.size(1)), device=Y.device)

            for diff, naive, naive_local, detach in product([True, False], repeat=4):
                if detach and not naive_local:
                    continue

                with self.subTest(differentiable=diff, naive=naive,
                                  naive_local=naive_local, detach=detach):
                    result1 = _local_hausdorff_similarities(
                        X_orig.detach().clone(), Y, mask_Y=Y_mask,
                        similarity_function="euclidean",
                        naive=naive,
                        differentiable=diff,
                        softmax_temp=0.01,
                        max_X_len=-1,
                        naive_local=naive_local,
                        detach=detach)

                    # naive_local used to be fed from `naive` here, so the two
                    # results were not computed with the same configuration
                    bag_loss_f = get_local_hausdorff_similarities_function(
                        similarity_function="euclidean",
                        naive=naive,
                        differentiable=diff,
                        softmax_temp=0.01,
                        naive_local=naive_local)
                    result2 = local_bag_losses(
                        X_orig.detach().clone(), Y, Y_mask, bag_loss_f,
                        detach, max_X_len=-1)

                    self._assert_allclose(result1, result2)


# Run every test case of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
