
import torch
import torch.nn as nn
import torch as th
import torch.nn.functional as F
import logging
logger = logging.getLogger(__name__)
def cosine_sim(visual_emb, text_emb):
    """
    Compute the pairwise similarity matrix between two sets of embeddings.

    The result is the plain matrix product ``visual_emb @ text_emb^T``; it
    equals the cosine similarity only when the rows of both inputs are
    already L2-normalized (no normalization is performed here).

    Args:
        visual_emb: Visual embedding with shape (num_datapoints, dim_embedding)
        text_emb: Text embedding with shape (num_datapoints, dim_embedding)

    Returns:
        Similarity matrix whose entry (i, j) is the dot product of row i of
        ``visual_emb`` with row j of ``text_emb``.
    """
    text_emb_transposed = text_emb.t()
    return torch.mm(visual_emb, text_emb_transposed)


class ContrastiveLoss(nn.Module):
    """
    Margin-based contrastive (ranking) loss between two groups of embeddings.

    For every sample in the batch a similarity matrix between the visual and
    the text embeddings is built; every off-diagonal entry that comes within
    ``margin`` of the matching (diagonal) entry is penalized. The per-sample
    losses are then multiplied by a dynamic loss scale (see ``scale_loss``).

    Args:
        margin: Margin added to each negative score before it is compared
            with the positive (diagonal) score.
        max_violation: If True, keep only the largest violation per query
            instead of all of them.
        norm: If True, each sample's loss is reduced to a scalar and divided
            by ``len * len``; otherwise the raw cost tensors are kept.
        use_cuda: Kept for backward compatibility. Helper tensors are created
            on the device of the inputs, so this flag no longer triggers any
            device transfer.
    """
    def __init__(self, margin: float, max_violation: bool = False, norm: bool = True, use_cuda: bool = True):
        super().__init__()
        self.margin = margin
        self.sim = cosine_sim
        self.norm = norm
        self.max_violation = max_violation
        self.use_cuda = use_cuda

        # Dynamic loss-scale state used by ``scale_loss``.
        self.init_scale = 100.0  # starting value and upper bound of the scale
        self.scale_factor = 2.0  # divisor on overflow / multiplier on growth
        self.scale_window = 2000  # overflow-free calls required before growing
        self.loss_scale = self.init_scale  # current multiplier
        self.min_loss_scale = 1e-4  # lower bound of the scale
        self.last_not_overflow_iter = 0  # calls since the last overflow

    def forwar_each_sample(self, im, s):
        """
        Compute the contrastive cost for a single sample.

        Args:
            im: Visual embeddings of one sample, shape (len, embed_dim).
            s: Text embeddings of one sample, shape (len, embed_dim).

        Returns:
            If ``self.norm`` is True, a scalar tensor: the summed cost divided
            by ``im.shape[0] * s.shape[0]``. Otherwise a tuple
            ``(cost_s, cost_im)`` of cost tensors, each of shape (len, len),
            or (len,) when ``self.max_violation`` is set.
        """
        im = F.normalize(im)
        s = F.normalize(s)

        # scores[i, j] is the similarity between im[i] and s[j].
        scores = self.sim(im, s)
        diagonal = scores.diag().view(im.size(0), 1)
        d1 = diagonal.expand_as(scores)  # d1[i, j] == scores[i, i]
        d2 = diagonal.t().expand_as(scores)  # d2[i, j] == scores[j, j]

        # Caption retrieval: cost_s[i, j] = margin + scores[i, j] - scores[i, i]
        cost_s = (self.margin + scores - d1).clamp(min=0)
        # Image retrieval: cost_im[i, j] = margin + scores[i, j] - scores[j, j]
        cost_im = (self.margin + scores - d2).clamp(min=0)

        # Clear the diagonal, where only the margin is left. The mask is
        # created directly on the device of ``scores`` so that CPU and GPU
        # inputs both work regardless of ``use_cuda``.
        mask = torch.eye(scores.shape[0], dtype=torch.bool, device=scores.device)
        cost_s = cost_s.masked_fill(mask, 0)
        cost_im = cost_im.masked_fill(mask, 0)

        # Keep only the maximum violating negative for each query.
        if self.max_violation:
            cost_s = cost_s.max(1)[0]
            cost_im = cost_im.max(0)[0]

        if self.norm:
            return (cost_s.sum() + cost_im.sum()).div(im.shape[0] * s.shape[0])
        return cost_s, cost_im

    def scale_loss(self, total_loss):
        """
        Multiply the loss by the dynamic loss scale and reduce it to a scalar.

        While the sum of the scaled loss is infinite, the scale is divided by
        ``scale_factor`` (never below ``min_loss_scale``). If the sum is still
        infinite at the minimum scale, the loss is multiplied by zero for this
        call. After every ``scale_window`` consecutive calls without overflow
        the scale is multiplied by ``scale_factor`` (never above
        ``init_scale``).

        Args:
            total_loss: Tensor of unscaled loss values; it is not modified.

        Returns:
            Scalar tensor: the sum of the scaled loss values.
        """
        self.last_not_overflow_iter += 1
        scaled = total_loss * self.loss_scale

        while torch.isinf(scaled.sum()):
            self.loss_scale = max(self.loss_scale / self.scale_factor, self.min_loss_scale)
            self.last_not_overflow_iter = 0
            scaled = total_loss * self.loss_scale

            if self.loss_scale == self.min_loss_scale and torch.isinf(scaled.sum()):
                # Cannot shrink the scale any further: drop this step's loss.
                scaled = scaled * 0
                logger.info("skip loss contrasive")
                break

            logger.info("small scale loss with %s", self.loss_scale)

        if self.last_not_overflow_iter != 0 and self.last_not_overflow_iter % self.scale_window == 0:
            self.loss_scale = min(self.loss_scale * self.scale_factor, self.init_scale)
            logger.info("bigger scale loss with %s", self.loss_scale)
            logger.info(self.last_not_overflow_iter)
            logger.info(self.scale_window)

        return scaled.sum()

    def forward(self, im, s):
        """
        Compute the loss for each sample of the batch separately and combine.

        Args:
            im: Visual embeddings (batch, len, embed_dim)
            s: Text embeddings (batch, len, embed_dim)

        Returns:
            Scalar tensor on the device of the inputs: the sum of all scaled
            per-sample losses divided by the batch size.

        Raises:
            ValueError: If the inputs are not 3-D, their shapes differ, or
                the batch is empty.
        """
        if im.dim() != 3 or im.shape != s.shape:
            raise ValueError(
                "im and s must share the shape (batch, len, embed_dim), got "
                "{} and {}".format(tuple(im.shape), tuple(s.shape))
            )
        b = im.shape[0]
        if b == 0:
            raise ValueError("ContrastiveLoss requires a non-empty batch")

        sample_losses = []
        for im_sample, s_sample in zip(im, s):
            sample_loss = self.forwar_each_sample(im_sample, s_sample)
            if isinstance(sample_loss, tuple):
                # norm=False: separate caption/image retrieval cost tensors.
                sample_losses.extend(sample_loss)
            else:
                # norm=True: a single, already reduced scalar per sample.
                sample_losses.append(sample_loss)

        # The stacked tensor already lives on the input device.
        total_loss = self.scale_loss(torch.stack(sample_losses))

        return total_loss / b
        

# Example settings for ContrastiveLoss; the smoke test below reads "margin".
contrastive_loss_config = {
    "margin": 0.2,
    "norm": False,
}

if __name__ == "__main__":
    # Smoke test: ContrastiveLoss.forward unpacks a 3-D shape
    # (batch, len, embed_dim), so the features must not be flattened to 2-D.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    ctr_loss = ContrastiveLoss(contrastive_loss_config['margin'], use_cuda=use_cuda, norm=False)
    audio_feats = F.normalize(torch.rand((2, 2, 678)), dim=2).to(device)
    text_feats = F.normalize(torch.rand((2, 2, 678)), dim=2).to(device)

    print(ctr_loss(audio_feats, text_feats))
