import torch
import torch.nn as nn
from transformers import BertPreTrainedModel, AutoModel


class BaselineModel(BertPreTrainedModel):
    """Pretrained encoder followed by three 2-way classification heads.

    The encoder is loaded with ``AutoModel.from_pretrained``. Its output at
    index 1 is fed to the ``get_logits`` head on every call and, on request,
    to the ``get_concept`` and ``get_action`` heads. Every head has the layout
    ``Linear(hidden, hidden) -> ReLU -> Dropout -> Linear(hidden, 2)`` and all
    heads share the same ``ReLU`` / ``Dropout`` module instances.
    """

    def __init__(self, config, model_name_or_path: str,
                 padding_idx: int = 1, drop: float = 0.25):
        """Load the encoder and build the classification heads.

        Args:
            config: Model configuration; ``config.hidden_size`` sets the width
                of the heads and the config is forwarded to the encoder.
            model_name_or_path: Identifier or path handed to
                ``AutoModel.from_pretrained``.
            padding_idx: Stored on the instance as ``self.padding_idx``; it is
                not read anywhere else in this class.
            drop: Dropout probability used inside every head.
        """
        super().__init__(config)
        self.padding_idx = padding_idx
        self.transformer = AutoModel.from_pretrained(
            model_name_or_path, config=config)
        self.dropout = nn.Dropout(drop)
        self.relu = nn.ReLU()

        hidden_size = config.hidden_size
        self.get_logits = self._build_head(hidden_size)
        self.get_concept = self._build_head(hidden_size)
        self.get_action = self._build_head(hidden_size)

        # Initialise only the freshly created heads. A model-wide
        # ``self.init_weights()`` would, on transformers versions where it
        # applies ``_init_weights`` to every submodule, also re-randomise the
        # pretrained encoder weights that were loaded just above.
        for head in (self.get_logits, self.get_concept, self.get_action):
            head.apply(self._init_weights)

    def _build_head(self, hidden_size: int) -> nn.Sequential:
        """Return one ``Linear -> ReLU -> Dropout -> Linear(…, 2)`` head."""
        return nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            self.relu,
            self.dropout,
            nn.Linear(hidden_size, 2),
        )

    def forward(self, batch, predict_concept_and_action: bool = False):
        """Encode ``batch`` and run the classification head(s).

        Args:
            batch: Mapping that provides ``'input_ids'`` and
                ``'attention_mask'`` for the encoder.
            predict_concept_and_action: When true, the concept and action
                heads are evaluated as well.

        Returns:
            A dict with ``'cls'`` (the encoder output at index 1) and
            ``'logits'``; ``'concept'`` and ``'action'`` are added only when
            ``predict_concept_and_action`` is true.
        """
        encoder_outputs = self.transformer(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'])
        # NOTE(review): index 1 is presumably the pooled [CLS] vector for
        # BERT-style encoders — confirm for other AutoModel architectures.
        cls_utt = encoder_outputs[1]

        result = {
            'cls': cls_utt,
            'logits': self.get_logits(cls_utt),
        }
        if predict_concept_and_action:
            result['concept'] = self.get_concept(cls_utt)
            result['action'] = self.get_action(cls_utt)
        return result


class BaselineModelMultiply(BertPreTrainedModel):
    """Score utterances against label descriptions with a shared encoder.

    Utterances and descriptions are encoded by the same pretrained encoder
    (loaded with ``AutoModel.from_pretrained``). Predictions are the pairwise
    dot products between every utterance vector and every description vector.
    """

    def __init__(self, config,
                 model_name_or_path: str,
                 padding_idx: int = 1,
                 drop: float = 0.25
                 ):
        """Load the shared encoder.

        Args:
            config: Model configuration, forwarded to the encoder.
            model_name_or_path: Identifier or path handed to
                ``AutoModel.from_pretrained``.
            padding_idx: Stored on the instance as ``self.padding_idx``; it is
                not read anywhere else in this class.
            drop: Probability for ``self.dropout``. The module is created but
                not used by ``forward`` in this class.
        """
        super().__init__(config)
        self.padding_idx = padding_idx
        self.transformer = AutoModel.from_pretrained(
            model_name_or_path, config=config)
        self.dropout = nn.Dropout(drop)
        self.relu = nn.ReLU()

        # Deliberately no ``self.init_weights()`` here: this class adds no
        # trainable parameters of its own, and on transformers versions where
        # that call applies ``_init_weights`` to every submodule it would
        # re-randomise the pretrained encoder weights loaded above.

    def _encode(self, inputs):
        """Return the encoder output at index 1 for one tokenised input."""
        # NOTE(review): index 1 is presumably the pooled [CLS] vector for
        # BERT-style encoders — confirm for other AutoModel architectures.
        return self.transformer(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'])[1]

    @staticmethod
    def _score(utterances, descriptions):
        """Return dot products of each utterance row with each description row.

        Inputs are 2-D, ``(b, h)`` and ``(i, h)``; the result is ``(b, i)``.
        """
        return torch.einsum('bh,ih->bi', utterances, descriptions)

    def forward(self,
                batch,
                intent_description,
                concept_description=None,
                action_description=None,
                ):
        """Encode utterances and descriptions and score every pair.

        Args:
            batch: Mapping with ``'input_ids'`` and ``'attention_mask'`` for
                the utterances.
            intent_description: Mapping with the same two keys for the intent
                descriptions.
            concept_description: Optional mapping with the same two keys for
                concept descriptions; skipped when ``None``.
            action_description: Optional mapping with the same two keys for
                action descriptions; skipped when ``None``.

        Returns:
            A dict with ``'utt_embeddings'``, ``'intent_embeddings'`` and
            ``'intent_prediction'``. For each optional description that is
            supplied, the matching ``'<name>_embeddings'`` and
            ``'<name>_prediction'`` entries (``concept`` / ``action``) are
            added as well.
        """
        result = {}
        bert_utt = self._encode(batch)
        intent_descr = self._encode(intent_description)
        result['utt_embeddings'] = bert_utt
        result['intent_embeddings'] = intent_descr
        result['intent_prediction'] = self._score(bert_utt, intent_descr)

        if concept_description is not None:
            concept_descr = self._encode(concept_description)
            result['concept_embeddings'] = concept_descr
            result['concept_prediction'] = self._score(bert_utt, concept_descr)

        if action_description is not None:
            action_descr = self._encode(action_description)
            result['action_embeddings'] = action_descr
            result['action_prediction'] = self._score(bert_utt, action_descr)

        return result
