
import torch
import torch.nn as nn
from torch.nn import functional as F

class ClusterLearner(nn.Module):
	"""Training wrapper that computes a two-part ranking loss and applies one update.

	Calling the module on a batch runs the wrapped ``model`` on four input
	views, builds an instance-level loss (L2 distances between normalised
	``model.head`` outputs) and a cluster-level loss (symmetric KL between
	``model.cluster_head`` outputs), then back-propagates, clips gradients,
	and steps the optimizer and the scheduler.

	The wrapped ``model`` must provide ``get_sentence_encoding``, ``head`` and
	``cluster_head`` callables.
	"""

	def __init__(self, model, optimizer, scheduler):
		"""Store the model together with the optimizer/scheduler that update it."""
		super().__init__()
		self.model = model
		self.optimizer = optimizer
		self.scheduler = scheduler

	def compute_kl_loss(self, p, q):
		"""Return the per-row symmetric KL divergence between ``softmax(p)`` and ``softmax(q)``.

		Both arguments are treated as logits: a softmax over the last
		dimension is applied to each before the divergence is computed.

		Args:
			p: Tensor whose last dimension indexes the categories.
			q: Tensor of the same shape as ``p``.

		Returns:
			Tensor with the last dimension reduced away, holding
			``(KL(softmax(q) || softmax(p)) + KL(softmax(p) || softmax(q))) / 2``.
		"""
		# F.kl_div(input, target) expects `input` as log-probabilities and
		# `target` as probabilities, and yields target * (log target - input).
		p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
		q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')

		# Sum over categories, then average the two directions.
		return (p_loss.sum(dim=-1) + q_loss.sum(dim=-1)) / 2

	@staticmethod
	def _chained_ranking_loss(first, second, third, batch_size):
		"""Combine two margin-ranking terms over three per-sample dissimilarities.

		With a target of -1, ``margin_ranking_loss(x1, x2)`` evaluates to
		``max(0, x1 - x2 + margin)``. The two terms therefore penalise
		``first > second`` (margin 0) and ``second + 2 > third`` (margin 2).

		Args:
			first, second, third: Per-sample dissimilarity tensors.
			batch_size: Number of samples; sizes the target and the divisor.

		Returns:
			Scalar tensor: the summed terms divided by ``3.0 * batch_size``.
		"""
		# Build the target directly on the right device instead of creating a
		# Python list / host tensor and copying it over on every step.
		target = torch.full((batch_size,), -1.0, device=first.device)
		order_term = F.margin_ranking_loss(
			first, second, target, margin=0, reduction="none")
		gap_term = F.margin_ranking_loss(
			second, third, target, margin=2, reduction="none")
		# NOTE(review): two terms are summed but the divisor is 3 * batch_size;
		# kept exactly as in the original — confirm this scaling is intended.
		return (order_term + gap_term).sum() / (3.0 * batch_size)

	def forward(self, inputs):
		"""Compute both losses for one batch and perform a full optimisation step.

		Besides computing the loss this method runs ``backward()``, clips the
		gradient norm of ``self.model`` to 5, steps the optimizer, zeroes the
		gradients and steps the scheduler.

		Args:
			inputs: Indexable object whose items 0..3 are each passed to
				``model.get_sentence_encoding``. Item 0 is the reference that
				items 1, 2 and 3 are compared against (presumably ordered from
				most to least similar — confirm with the data pipeline).

		Returns:
			Dict of Python floats under the keys ``"Instance-CL_loss"``,
			``"clustering_loss"`` and ``"total_loss"``.
		"""
		encodings = [self.model.get_sentence_encoding(inputs[idx]) for idx in range(4)]
		batch_size = encodings[0].shape[0]

		# Instance loss: L2 distance between unit-normalised projections of
		# the reference and each of the three other views.
		projections = [F.normalize(self.model.head(enc), dim=1) for enc in encodings]
		dist_1, dist_2, dist_3 = (
			F.pairwise_distance(projections[0], other) for other in projections[1:]
		)
		instance_loss = self._chained_ranking_loss(dist_1, dist_2, dist_3, batch_size)

		# Cluster loss: symmetric KL between the reference's cluster
		# assignment and each of the three other views.
		# NOTE(review): these tensors are already softmax outputs, and
		# compute_kl_loss applies softmax again; left unchanged to preserve
		# training behaviour — confirm the double softmax is deliberate.
		assignments = [
			F.softmax(F.normalize(self.model.cluster_head(enc), dim=1), dim=1)
			for enc in encodings
		]
		kl_1, kl_2, kl_3 = (
			self.compute_kl_loss(assignments[0], other) for other in assignments[1:]
		)
		cluster_loss = self._chained_ranking_loss(kl_1, kl_2, kl_3, batch_size)

		total_loss = instance_loss + cluster_loss
		total_loss.backward()
		nn.utils.clip_grad_norm_(self.model.parameters(), 5)
		self.optimizer.step()
		# Gradients are cleared after the step, so the next call starts clean.
		self.optimizer.zero_grad()
		self.scheduler.step()
		return {
			"Instance-CL_loss": instance_loss.item(),
			"clustering_loss": cluster_loss.item(),
			"total_loss": total_loss.item(),
		}

