import torch
from torch import nn, Tensor
from typing import Iterable, Dict
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import BertPreTrainedModel, BertModel
from torch import nn
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.adapters.model_mixin import ModelWithHeadsAdaptersMixin
from transformers.utils import logging

logger = logging.get_logger(__name__)

def dot_score(a: Tensor, b: Tensor):
    """
    Computes the row-wise dot-product dot_prod(a[i], b[i]) of two batches of vectors.

    Only matching rows are scored (this is the diagonal of the full pairwise
    score matrix), so row i of ``a`` is compared with row i of ``b`` only.
    If the inputs have a different number of rows, the longer one is
    truncated to the length of the shorter one.

    :param a: Tensor (or anything accepted by ``torch.tensor``) of shape (n, d) or (d,)
    :param b: Tensor (or anything accepted by ``torch.tensor``) of shape (m, d) or (d,)
    :return: Tensor of shape (1, min(n, m)) with res[0][i] = dot_prod(a[i], b[i])
    :raises RuntimeError: if an input has more than two dimensions or the
        vector sizes of ``a`` and ``b`` differ
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    # A single vector is treated as a batch containing one row.
    if a.dim() == 1:
        a = a.unsqueeze(0)

    if b.dim() == 1:
        b = b.unsqueeze(0)

    # Fail loudly instead of letting the element-wise product broadcast
    # incompatible shapes into a meaningless result.
    if a.dim() != 2 or b.dim() != 2 or a.shape[1] != b.shape[1]:
        raise RuntimeError(
            "dot_score expects inputs of shape (n, d) and (m, d), got "
            f"{tuple(a.shape)} and {tuple(b.shape)}"
        )

    # Multiply matching rows and sum over the feature axis. This avoids
    # building the full (n, m) score matrix only to keep its diagonal.
    rows = min(a.shape[0], b.shape[0])
    products = a[:rows] * b[:rows]
    return products.sum(dim=1, dtype=products.dtype).unsqueeze(0)

class BertForNCEDual(ModelWithHeadsAdaptersMixin, BertPreTrainedModel):
    """
    BERT encoder that scores one anchor sequence against three candidates.

    The input batch is read in consecutive groups of four sequences: the
    first sequence of each group is the anchor and the following three are
    its candidates. Every sequence is encoded by the same BERT model, the
    anchor's pooled output is dot-multiplied with each candidate's pooled
    output, and the three resulting scores are passed through a small linear
    layer to produce three logits per group.
    """

    # Number of candidates scored against each anchor, and the resulting
    # number of consecutive sequences that make up one group in the batch.
    _NUM_CANDIDATES = 3
    _GROUP_SIZE = _NUM_CANDIDATES + 1

    def __init__(self, config):
        """
        :param config: BERT configuration used to build the encoder; its
            ``hidden_dropout_prob`` sets the dropout applied to the pooled output
        """
        super().__init__(config)
        self.bert = BertModel(config)
        self.config = config
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Mixes the three raw anchor/candidate scores into three logits.
        self.classifier = nn.Linear(self._NUM_CANDIDATES, self._NUM_CANDIDATES)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        Encodes the batch and scores every anchor against its three candidates.

        ``input_ids``, ``attention_mask``, ``token_type_ids``, ``position_ids``,
        ``head_mask``, ``inputs_embeds``, ``output_attentions``,
        ``output_hidden_states`` and ``return_dict`` are forwarded to the BERT
        encoder unchanged. The encoded batch must contain a multiple of four
        sequences, ordered as (anchor, candidate, candidate, candidate) groups.

        :param labels: Optional targets, reshaped to (-1, 3) and cast to the
            dtype of the logits; one target per candidate of each group
        :return: ``SequenceClassifierOutput`` whose ``logits`` have shape
            (number of groups, 3) and whose ``loss`` is the
            ``BCEWithLogitsLoss`` against ``labels`` (``None`` without
            labels). ``hidden_states`` and ``attentions`` are always ``None``,
            and this output type is returned regardless of ``return_dict``.
        :raises ValueError: if the number of encoded sequences is not a
            multiple of four
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 1 is the pooled output for both tuple and ModelOutput returns.
        pooled_output = outputs[1]

        # Incomplete groups would otherwise be silently dropped or fail later
        # with an opaque concatenation error.
        if pooled_output.shape[0] % self._GROUP_SIZE != 0:
            raise ValueError(
                f"Expected a batch made of groups of {self._GROUP_SIZE} sequences "
                f"(1 anchor + {self._NUM_CANDIDATES} candidates), "
                f"got {pooled_output.shape[0]} sequences"
            )

        pooled_output = self.dropout(pooled_output)

        # Strided slicing de-interleaves the batch: offset 0 holds the
        # anchors, offsets 1..3 hold the first/second/third candidates.
        anchors = pooled_output[:: self._GROUP_SIZE]
        candidate_scores = [
            dot_score(anchors, pooled_output[offset :: self._GROUP_SIZE])
            for offset in range(1, self._GROUP_SIZE)
        ]
        # Each score tensor is (1, groups); stack to (3, groups), then flip
        # to (groups, 3) so every row holds the scores of one group.
        logits = torch.cat(candidate_scores, 0).transpose(0, 1)

        logits = self.classifier(logits).view(-1, self._NUM_CANDIDATES)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss needs floating-point targets; integer labels
            # would otherwise raise at loss time.
            labels = labels.view(-1, self._NUM_CANDIDATES).to(logits.dtype)
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )