import torch
from torch import nn, Tensor
from typing import Iterable, Dict
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import BertPreTrainedModel, BertModel
from torch import nn
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.adapters.model_mixin import ModelWithHeadsAdaptersMixin

# from transformers.file_utils import (
#     ModelOutput,
#     add_code_sample_docstrings,
#     add_start_docstrings,
#     add_start_docstrings_to_model_forward,
#     replace_return_docstrings,
# )
from transformers.utils import logging

# Module-level logger from the transformers logging utilities, named after this
# module. It is not referenced in the visible part of this file.
logger = logging.get_logger(__name__)

# _CHECKPOINT_FOR_DOC = "bert-base-uncased"
# _CONFIG_FOR_DOC = "BertConfig"
# _TOKENIZER_FOR_DOC = "BertTokenizer"

# BERT_START_DOCSTRING = r"""

#     This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
#     methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
#     pruning heads, etc.).

#     This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
#     subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
#     general usage and behavior.

#     Parameters:
#         config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
#             Initializing with a config file does not load the weights associated with the model, only the
#             configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
#             weights.
# """

# BERT_INPUTS_DOCSTRING = r"""
#     Args:
#         input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
#             Indices of input sequence tokens in the vocabulary.

#             Indices can be obtained using :class:`~transformers.BertTokenizer`. See
#             :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
#             details.

#             `What are input IDs? <../glossary.html#input-ids>`__
#         attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
#             Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:

#             - 1 for tokens that are **not masked**,
#             - 0 for tokens that are **masked**.

#             `What are attention masks? <../glossary.html#attention-mask>`__
#         token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
#             Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
#             1]``:

#             - 0 corresponds to a `sentence A` token,
#             - 1 corresponds to a `sentence B` token.

#             `What are token type IDs? <../glossary.html#token-type-ids>`_
#         position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
#             Indices of the position of each input sequence token in the position embeddings. Selected in the range
#             ``[0, config.max_position_embeddings - 1]``.

#             `What are position IDs? <../glossary.html#position-ids>`_
#         head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
#             Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:

#             - 1 indicates the head is **not masked**,
#             - 0 indicates the head is **masked**.

#         inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
#             Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
#             This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
#             vectors than the model's internal embedding lookup matrix.
#         output_attentions (:obj:`bool`, `optional`):
#             Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
#             tensors for more detail.
#         output_hidden_states (:obj:`bool`, `optional`):
#             Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
#             more detail.
#         return_dict (:obj:`bool`, `optional`):
#             Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
# """

# # @add_start_docstrings(
# #     """
# #     Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
# #     output) e.g. for GLUE tasks.
# #     """,
# #     BERT_START_DOCSTRING,
# # )      
class BertForNCE(ModelWithHeadsAdaptersMixin, BertPreTrainedModel):
    """BERT encoder with a single-logit scoring head, scored in fixed-size groups.

    Every input sequence is encoded with BERT; its pooled output goes through
    dropout and a linear layer that produces one scalar score. The scores are
    then reshaped into rows of ``_GROUP_SIZE`` consecutive sequences, and, when
    labels are supplied, a ``BCEWithLogitsLoss`` is computed over those rows.

    NOTE(review): the class name suggests noise-contrastive estimation with one
    positive and two negative candidates per group -- confirm against the data
    pipeline that builds the batches.
    """

    # Number of consecutive sequences whose scores form one row of ``logits``.
    # The flattened batch size must be a multiple of this value.
    _GROUP_SIZE = 3

    def __init__(self, config):
        """Build the encoder, dropout and scoring head from ``config``.

        :param config: model configuration; ``hidden_dropout_prob`` and
            ``hidden_size`` are read here, and the same object is handed to
            ``BertModel``.
        """
        super().__init__(config)
        self.bert = BertModel(config)
        self.config = config
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # One scalar score per input sequence.
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Score every input sequence and optionally compute the grouped BCE loss.

        :param input_ids: forwarded unchanged to the BERT encoder.
        :param attention_mask: forwarded unchanged to the BERT encoder.
        :param token_type_ids: forwarded unchanged to the BERT encoder.
        :param position_ids: forwarded unchanged to the BERT encoder.
        :param head_mask: forwarded unchanged to the BERT encoder.
        :param inputs_embeds: forwarded unchanged to the BERT encoder.
        :param labels: optional targets; reshaped to ``(-1, _GROUP_SIZE)`` and
            cast to the dtype of the logits before the loss is computed.
        :param output_attentions: forwarded to the encoder; the resulting
            attentions are returned in the output.
        :param output_hidden_states: forwarded to the encoder; the resulting
            hidden states are returned in the output.
        :param return_dict: accepted for interface compatibility only; this
            method always returns a ``SequenceClassifierOutput``.
        :return: ``SequenceClassifierOutput`` whose ``logits`` have shape
            ``(-1, _GROUP_SIZE)`` and whose ``loss`` is ``None`` when no labels
            are given.
        """
        # Always request a ModelOutput from the encoder so that the pooled
        # output, hidden states and attentions can be read by name regardless of
        # what the caller passed as ``return_dict``.
        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        pooled_output = self.dropout(encoder_outputs.pooler_output)
        # One score per sequence, regrouped into rows of _GROUP_SIZE candidates.
        logits = self.classifier(pooled_output).view(-1, self._GROUP_SIZE)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss needs floating-point targets; integer label
            # tensors would otherwise raise a dtype error.
            targets = labels.view(-1, self._GROUP_SIZE).to(dtype=logits.dtype)
            loss = nn.BCEWithLogitsLoss()(logits, targets)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            # Populated only when the caller (or the config) requested them;
            # otherwise the encoder leaves these as None.
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )