from transformers import PreTrainedModel
from typing import Iterable, Dict
import torch
import torch.nn.functional as F
from torch import nn, Tensor
import numpy as np
from sklearn.metrics.pairwise import paired_cosine_distances

class SiameseDistanceMetric:
    """
    Collection of distance functions usable as the ``distance_metric`` of a siamese-style loss.

    Every member takes two tensors ``x`` and ``y`` and returns the distance that the
    corresponding ``torch.nn.functional`` routine computes between them.
    """

    @staticmethod
    def EUCLIDEAN(x, y):
        """Return ``F.pairwise_distance`` between ``x`` and ``y`` using the 2-norm."""
        l2_distance = F.pairwise_distance(x, y, p=2)
        return l2_distance

    @staticmethod
    def MANHATTAN(x, y):
        """Return ``F.pairwise_distance`` between ``x`` and ``y`` using the 1-norm."""
        l1_distance = F.pairwise_distance(x, y, p=1)
        return l1_distance

    @staticmethod
    def COSINE_DISTANCE(x, y):
        """Return one minus the cosine similarity of ``x`` and ``y``."""
        similarity = F.cosine_similarity(x, y)
        return 1 - similarity



class OnlineContrastiveLoss(nn.Module):
    """
    Online Contrastive loss. Selects hard positive (positives that are far apart)
    and hard negative pairs (negatives that are close) and computes the loss only for these pairs.

    Both embedding batches are L2-normalized before the distance metric is applied,
    so every metric operates on unit-length vectors.

    :param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains pre-defined metrics that can be used
    :param margin: Negative samples (label == 0) should have a distance of at least the margin value.

    Example::

        import torch

        loss_fn = OnlineContrastiveLoss(
            distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin=0.5
        )

        # Row i of embeddings1 is paired with row i of embeddings2.
        embeddings1 = torch.randn(4, 16)
        embeddings2 = torch.randn(4, 16)
        labels = torch.tensor([1, 0, 1, 0])

        loss = loss_fn(embeddings1, embeddings2, labels)
    """

    def __init__(self, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5):
        super().__init__()
        self.margin = margin
        self.distance_metric = distance_metric

    def forward(self, embeddings1: Tensor, embeddings2: Tensor, labels: Tensor) -> Tensor:
        """
        Forward pass of the Online Contrastive Loss.

        :param embeddings1: Tensor with the embeddings of sentence list 1; normalized along dim 1
        :param embeddings2: Tensor with the embeddings of sentence list 2; normalized along dim 1
        :param labels: Tensor of labels indicating positive (1) or negative (0) pairs. A label
            tensor whose shape differs from the distance tensor but holds the same number of
            elements (e.g. ``(N, 1)`` instead of ``(N,)``) is reshaped to match.
        :return: Sum of the squared distances of the selected positive pairs plus the sum of the
            squared margin violations of the selected negative pairs
        """
        # Normalize embeddings
        embeddings1 = F.normalize(embeddings1, p=2, dim=1)
        embeddings2 = F.normalize(embeddings2, p=2, dim=1)

        # One distance per (embeddings1[i], embeddings2[i]) pair
        distances = self.distance_metric(embeddings1, embeddings2)

        # Boolean-mask indexing needs the mask to line up with the distances; column-shaped
        # labels such as (N, 1) would otherwise raise an IndexError against (N,) distances.
        if labels.shape != distances.shape and labels.numel() == distances.numel():
            labels = labels.reshape(distances.shape)

        negs = distances[labels == 0]
        poss = distances[labels == 1]

        # Select hard positive and hard negative pairs.
        # NOTE(review): the fallback to the mean is also taken when exactly one pair of the
        # other class is present (the guards are `> 1`, not `> 0`). Kept unchanged so existing
        # loss values are preserved — confirm whether this is intentional.
        negative_threshold = poss.max() if len(poss) > 1 else negs.mean()
        positive_threshold = negs.min() if len(negs) > 1 else poss.mean()
        negative_pairs = negs[negs < negative_threshold]
        positive_pairs = poss[poss > positive_threshold]

        # Positives are pulled together; negatives are only penalized while inside the margin.
        positive_loss = positive_pairs.pow(2).sum()
        negative_loss = F.relu(self.margin - negative_pairs).pow(2).sum()
        return positive_loss + negative_loss
