import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from .SlotAttention import SlotAttention
from torch.autograd import Variable
from .FFNet import FFNet
from transformers import AutoTokenizer, AutoModel

# Process-wide PyTorch settings applied as a side effect of importing this module.
# Ask PyTorch to warn when an operation's result depends on the legacy
# (pre-broadcasting) element-wise semantics.
torch.utils.backcompat.broadcast_warning.enabled = True
# Print tensors in full up to 5000 elements before the repr is summarised.
torch.set_printoptions(threshold=5000)


class VariableSlotModel(nn.Module):
    """Binary sentence scorer built on a pretrained transformer encoder.

    Two modes are selected by ``is_baseline``:

    * baseline: mask-aware mean pooling of the encoder's token embeddings
      followed by ``linear_output``;
    * otherwise: token embeddings at "variable" positions (non-zero entries
      of ``bert_pos_var``) are replaced by the output of ``slot_att``, the
      sequence is run through a BiLSTM, mean pooled and scored by
      ``linear_output_2``.

    Several submodules created in ``__init__`` (``slot_att_var``, ``ff``,
    ``self_attention``, ``norm_1``, ``norm_2``, ``ff_embedding``, ``dense``,
    ``activation``) are not used by any method of this class; they are kept
    so that the set of parameters (and therefore ``state_dict`` keys) stays
    the same.
    """

    def __init__(self, device, model_name, is_baseline, dimension_size, num_iters):
        """Build the encoder and all heads.

        Args:
            device: Stored as ``self.device``; not otherwise used here.
            model_name: Identifier passed to ``AutoModel.from_pretrained``.
                It is also inspected to decide whether ``token_type_ids``
                are passed to the encoder.
            is_baseline: If true, skip the variable/slot pipeline.
            dimension_size: Feature width used for every layer; presumably
                the encoder's hidden size -- TODO confirm against callers.
            num_iters: ``iters`` argument forwarded to ``SlotAttention``.
        """
        super().__init__()
        self.device = device
        self.model_name = model_name
        self.is_baseline = is_baseline

        self.sentence_bert_model = AutoModel.from_pretrained(model_name)

        self.slot_att_var = SlotAttention(1, dimension_size)
        self.slot_att = SlotAttention(1, dimension_size, iters=num_iters)

        self.ff = FFNet(dimension_size)

        self.mlp = nn.Sequential(
            nn.Linear(dimension_size, dimension_size),
            nn.ReLU(inplace=True),
            nn.Linear(dimension_size, dimension_size),
        )
        self.self_attention = nn.MultiheadAttention(
            dimension_size,
            8,
            dropout=0.5,
            add_bias_kv=True,
        )

        self.norm_1 = nn.LayerNorm(dimension_size)
        self.norm_2 = nn.LayerNorm(dimension_size)

        # Two directions of size dimension_size // 2 are concatenated, so the
        # output width equals dimension_size when dimension_size is even.
        self.bilstm = nn.LSTM(
            dimension_size,
            dimension_size // 2,
            1,
            bidirectional=True,
            batch_first=True,
        )

        self.ff_embedding = FFNet(dimension_size, dimension_size)

        self.dropout = nn.Dropout(0.6)

        self.linear_output = nn.Linear(dimension_size, 1)
        self.linear_output_2 = nn.Linear(dimension_size, 1)

        self.dense = nn.Linear(dimension_size, dimension_size)
        self.activation = nn.Tanh()

    def _encode(
        self, sentence_input_ids, sentence_token_type_ids, sentence_attention_mask
    ):
        """Run the pretrained encoder and return its raw output.

        ``token_type_ids`` are withheld for model names containing
        ``"roberta"`` or ``"paraphrase"``. Both public entry points go
        through this helper so they always call the encoder the same way.
        """
        if "roberta" in self.model_name or "paraphrase" in self.model_name:
            return self.sentence_bert_model(
                input_ids=sentence_input_ids,
                attention_mask=sentence_attention_mask,
            )
        return self.sentence_bert_model(
            input_ids=sentence_input_ids,
            token_type_ids=sentence_token_type_ids,
            attention_mask=sentence_attention_mask,
        )

    @staticmethod
    def _mean_pooling(token_embeddings, attention_mask):
        """Average token embeddings over positions where the mask is non-zero.

        ``attention_mask`` is unsqueezed on its last axis and expanded to the
        shape of ``token_embeddings``; the divisor is clamped to ``1e-9`` so a
        fully masked row does not divide by zero.
        """
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _attend_to_variables(self, word_embeddings, position_variables, attention_mask):
        """Fill variable positions with slot vectors, then apply BiLSTM + dropout.

        Args:
            word_embeddings: Encoder token embeddings, indexed as
                ``(batch, seq, dim)`` by the code below.
            position_variables: Integer ids per token; ``0`` marks an
                ordinary token and ``k >= 1`` marks the k-th variable.
            attention_mask: ``1`` for real tokens, ``0`` for padding.

        Returns:
            The BiLSTM output after dropout.
        """
        # Use the actual sequence length instead of a hard-coded 256 so any
        # padding length works; for length 256 this is the same reshape.
        seq_len = word_embeddings.shape[1]
        position_variables = position_variables.reshape(-1, seq_len, 1)
        attention_mask = attention_mask.reshape(-1, seq_len, 1)

        # zeros_like already matches the device and dtype of its input.
        zeros = torch.zeros_like(word_embeddings)

        # Blank out every variable token and every padded token.
        context_only = torch.where(position_variables == 0, word_embeddings, zeros)
        context_only = torch.where(attention_mask == 1, context_only, zeros)

        # Convert once to a Python int for range().
        max_var = int(position_variables.max())

        new_word_embeddings = context_only.clone()
        for num_var in range(1, max_var + 1):
            # The slot output is tiled along the sequence axis so torch.where
            # can pick it at the positions of variable ``num_var``. Variables
            # filled in earlier iterations are part of the input to later ones.
            variable_slots = self.slot_att(new_word_embeddings).repeat(
                (1, seq_len, 1)
            )
            new_word_embeddings = torch.where(
                position_variables == num_var, variable_slots, new_word_embeddings
            )

        # Put the original encoder embeddings back at padded positions.
        new_word_embeddings = torch.where(
            attention_mask == 0, word_embeddings, new_word_embeddings
        )

        # Re-compact the LSTM weights before the call (see the PyTorch docs
        # for nn.LSTM.flatten_parameters).
        self.bilstm.flatten_parameters()
        output_lstm, _ = self.bilstm(new_word_embeddings)

        return self.dropout(output_lstm)

    def _attend_to_variables_mlp(self, word_embeddings, position_variables):
        """Slot-fill variable positions, then apply ``mlp`` and dropout.

        For each variable id the slot is computed from the original
        embeddings with that variable's own positions zeroed out.

        NOTE(review): this differs from the BiLSTM pipeline used by
        ``forward`` (which never calls ``self.mlp``); confirm whether
        ``get_embedding`` is meant to keep this variant.
        """
        seq_len = word_embeddings.shape[1]
        position_variables = position_variables.reshape(-1, seq_len, 1)
        zeros = torch.zeros_like(word_embeddings)
        max_var = int(position_variables.max())

        new_word_embeddings = word_embeddings
        for num_var in range(1, max_var + 1):
            is_current_var = position_variables == num_var
            word_embedding_without_vars = torch.where(
                is_current_var, zeros, word_embeddings
            )
            variable_slots = self.slot_att(word_embedding_without_vars).repeat(
                (1, seq_len, 1)
            )
            new_word_embeddings = torch.where(
                is_current_var, variable_slots, new_word_embeddings
            )

        new_word_embeddings = self.mlp(new_word_embeddings)
        return self.dropout(new_word_embeddings)

    def forward(
        self,
        sentence_input_ids,
        sentence_token_type_ids,
        sentence_attention_mask,
        bert_pos_var,
        bert_pos_exp,
        all_exp_pos,
    ):
        """Score a batch of sentences.

        Args:
            sentence_input_ids: Token ids for the encoder.
            sentence_token_type_ids: Segment ids; ignored for model names
                containing ``"roberta"`` or ``"paraphrase"``.
            sentence_attention_mask: ``1`` for real tokens, ``0`` for padding.
            bert_pos_var: Per-token variable ids (only read when not baseline).
            bert_pos_exp: Unused; kept for interface compatibility.
            all_exp_pos: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(out, preds)``: ``out`` is the raw output of the linear
            head (one logit per sentence) and ``preds`` is the boolean tensor
            ``sigmoid(out) > 0.5``.
        """
        token_embeddings = self._encode(
            sentence_input_ids, sentence_token_type_ids, sentence_attention_mask
        )[0]

        if self.is_baseline:
            sentence_embedding = self._mean_pooling(
                token_embeddings, sentence_attention_mask
            )
            out = self.linear_output(sentence_embedding)
        else:
            attended_with_vars_sentence = self._attend_to_variables(
                token_embeddings, bert_pos_var, sentence_attention_mask
            )
            sentence_embedding = self._mean_pooling(
                attended_with_vars_sentence, sentence_attention_mask
            )
            out = self.linear_output_2(sentence_embedding)

        probs = torch.sigmoid(out)
        preds = probs > 0.5

        return out, preds

    def get_embedding(
        self,
        sentence_input_ids,
        sentence_token_type_ids,
        sentence_attention_mask,
        bert_pos_var,
        bert_pos_exp,
        all_exp_pos,
    ):
        """Return per-token embeddings for a batch of sentences.

        In baseline mode these are the encoder's token embeddings; otherwise
        they are the output of ``_attend_to_variables_mlp``. Dropout inside
        that helper is active whenever the module is in training mode.

        ``bert_pos_exp`` and ``all_exp_pos`` are unused and kept only for
        interface compatibility.
        """
        word_embeddings = self._encode(
            sentence_input_ids, sentence_token_type_ids, sentence_attention_mask
        )[0]

        if not self.is_baseline:
            word_embeddings = self._attend_to_variables_mlp(
                word_embeddings, bert_pos_var
            )

        return word_embeddings
