from models import SingleModel
import torch
import torch.nn as nn
import constants as C
import pdb as pdb


class DualModel(nn.Module):
    """Run one shared ``SingleModel`` over two encodings of an input pair.

    ``forward`` feeds the "lr" and "rl" inputs through the same ``self.model``
    (so weights are shared) and stacks the two sets of outputs. When
    ``config.model_type == 'dual'`` and ``config.consistency`` is falsy, an
    extra siamese head is applied to ``[h0, h1, |h0 - h1|]`` and its logits are
    stacked as a third entry. When ``config.additional_cls`` is set, the same is
    done for the auxiliary 2-way ("para") logits.
    """

    def __init__(self, args):
        """Build the shared encoder and, in dual mode, the siamese head(s).

        Args:
            args: configuration namespace. Read attributes: ``model``,
                ``model_type``, ``consistency``, ``n_classes``,
                ``additional_cls`` and (via ``_init_weights``)
                ``initializer_range``.

        Raises:
            NotImplementedError: if ``args.model`` is not in
                ``C.AVAILABLE_MODELS``.
        """
        super().__init__()
        self.config = args

        if self.config.model not in C.AVAILABLE_MODELS:
            raise NotImplementedError(f"Unsupported model: {self.config.model!r}")
        self.model = SingleModel(args)

        if self._uses_siamese_head():
            self.siameseclass = self.add_siamese_class(self.config.n_classes)
            if self.config.additional_cls:
                # Auxiliary head is always binary.
                self.siameseclass_para = self.add_siamese_class(2)

    def _uses_siamese_head(self):
        """Return True when the extra siamese classifier(s) are active."""
        return self.config.model_type == 'dual' and not self.config.consistency

    def update_classifier(self, n_classes):
        """Resize the wrapped model's classifier to ``n_classes`` outputs."""
        self.model.update_classifier(n_classes)

    def add_siamese_class(self, n_classes):
        """Create a freshly initialised ``Linear -> Tanh -> Linear`` head.

        The input width comes from ``C.HIDDEN_DIM[config.model]['siamese']``.
        It presumably equals 3x the encoder hidden size, because ``forward``
        feeds it ``[h0, h1, |h0 - h1|]`` (confirm against ``constants``).

        Args:
            n_classes: number of output logits.

        Returns:
            The new ``nn.Sequential`` head. It is not assigned to ``self``.
        """
        d_in = C.HIDDEN_DIM[self.config.model]['siamese']
        head = nn.Sequential(
            nn.Linear(d_in, 768),
            nn.Tanh(),
            nn.Linear(768, n_classes),
        )
        # apply() visits each child (then the Sequential itself, which
        # _init_weights ignores). This is the same order as iterating the
        # children directly.
        head.apply(self._init_weights)
        return head

    def update_siamese_class(self, n_classes):
        """Replace ``self.siameseclass`` with a new head of ``n_classes`` outputs.

        ``self.siameseclass_para`` (the binary auxiliary head) is left as is.
        """
        self.siameseclass = self.add_siamese_class(n_classes)

    def _init_weights(self, module):
        """Initialize the weights of a single ``Linear``/``Embedding``/``LayerNorm``.

        Weights of ``Linear`` and ``Embedding`` are drawn from
        N(0, ``config.initializer_range``). Biases and the embedding padding
        row are zeroed. ``LayerNorm`` is reset to weight=1, bias=0. Other
        module types are left unchanged.
        """
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, lr_input_ids, lr_attention_mask, rl_input_ids, rl_attention_mask):
        """Encode both inputs with the shared model and stack the outputs.

        Args:
            lr_input_ids, lr_attention_mask: inputs for the first pass.
            rl_input_ids, rl_attention_mask: inputs for the second pass.

        Returns:
            ``(logits, hidden_last, logits_para)``:

            * ``logits``: ``torch.stack`` of ``[logits_0, logits_1]``, plus the
              siamese logits as a third entry when the siamese head is active.
            * ``hidden_last``: ``torch.stack([hidden_0, hidden_1])``, or
              ``None`` when the wrapped model returns no hidden state.
            * ``logits_para``: the analogous stack of auxiliary logits when
              ``config.additional_cls`` is set, else ``None``.
        """
        logits_0, hidden_0, logits_para_0 = self.model(lr_input_ids, lr_attention_mask)
        logits_1, hidden_1, logits_para_1 = self.model(rl_input_ids, rl_attention_mask)

        hidden_last = None if hidden_0 is None else torch.stack([hidden_0, hidden_1])

        logit_parts = [logits_0, logits_1]
        para_parts = [logits_para_0, logits_para_1]

        if self._uses_siamese_head():
            # Concatenating on dim=1 presumes hidden states shaped
            # (batch, hidden) -- TODO confirm against SingleModel.
            all_hidden = torch.cat([hidden_0, hidden_1, torch.abs(hidden_0 - hidden_1)], dim=1)
            logit_parts.append(self.siameseclass(all_hidden))
            if self.config.additional_cls:
                para_parts.append(self.siameseclass_para(all_hidden))

        # Stack once, after the optional third entry has been appended, rather
        # than building a 2-entry tensor and then rebuilding it with 3 entries.
        logits = torch.stack(logit_parts)
        logits_para = torch.stack(para_parts) if self.config.additional_cls else None

        return logits, hidden_last, logits_para
