import torch
import torch.nn as nn
import utils.init as init
from .bertCRF import CRF
from transformers import AutoModel
from modules.modeling_feature import BertModel as FeatureModel
from modules.rnn import RNN


class SEModel(nn.Module):
    """Pretrained encoder followed by a CRF head, with optional extra features.

    Token ids are encoded by a ``transformers.AutoModel`` encoder. When a
    ``features`` tensor is supplied to :meth:`forward`, it is projected to the
    encoder's hidden size by ``se_linear`` and added to the encoder output
    ("feature after BERT" fusion). The result is handed to the CRF module,
    whose return value is returned unchanged.
    """

    def __init__(self, bert_path, label_vocab, d_model: int = 768, trainsets=None, for_ner: bool = False,
                 use_feature: bool = False, *, feature_dim: int = 17):
        """Build the encoder, the feature projection and the CRF head.

        Args:
            bert_path: Name or path passed to ``AutoModel.from_pretrained``.
            label_vocab: Label vocabulary; ``label_vocab.max_n_words`` is used
                as the number of CRF tags.
            d_model: Input dimension given to the CRF.
            trainsets: Accepted for interface compatibility; not used here.
            for_ner: Accepted for interface compatibility; not used here.
            use_feature: Accepted for interface compatibility; not used here
                (feature usage is governed by ``self.nofeature``).
            feature_dim: Size of the last dimension of the ``features`` tensor
                given to :meth:`forward`. Defaults to 17, the previously
                hard-coded value.
        """
        super().__init__()

        self.d_model = d_model
        # hidden_size / num_layers are stored but not read anywhere in this class.
        self.hidden_size = 400

        self.label_vocab = label_vocab
        self.num_tags = label_vocab.max_n_words
        self.num_layers = 1

        self.bert = AutoModel.from_pretrained(bert_path)
        # Projects the extra per-position features up to the encoder width so
        # they can be added element-wise to the encoder output.
        self.se_linear = nn.Linear(feature_dim, self.bert.config.hidden_size)

        # NOTE(review): the CRF is sized with d_model while se_linear uses
        # bert.config.hidden_size; these presumably must be equal -- confirm.
        self.CRF = CRF(num_tags=self.num_tags, input_dim=self.d_model)

        # When True, `features` passed to forward() are ignored.
        self.nofeature = False
        print('#### use after se model, model nofeature: ', self.nofeature)

    def init_model(self, param_path=None, device=None):
        """Optionally load saved parameters.

        Args:
            param_path: Checkpoint path; nothing is loaded when ``None``.
            device: CUDA device index used for ``map_location``; ``None``
                loads onto the CPU.
        """
        if param_path is not None:
            self._load_param(param_path, device)

    def _save_pos_params(self, path):
        """Placeholder; currently does nothing."""
        pass

    def _load_param(self, path, device):
        """Load a state dict from ``path`` into ``self.pos_linear``.

        Args:
            path: File readable by ``torch.load`` containing a state dict.
            device: CUDA device index, or ``None`` to map tensors to the CPU.

        Raises:
            AttributeError: If the model has no ``pos_linear`` sub-module.
                This class does not create one itself, so it must be attached
                (e.g. by a subclass) before loading.
        """
        # Fail early with a clear message instead of after reading the file.
        if not hasattr(self, 'pos_linear'):
            raise AttributeError(
                f"{type(self).__name__} has no 'pos_linear' module to load parameters into"
            )

        # `init_model` defaults device to None; 'cuda:None' is not a valid
        # map_location, so fall back to the CPU in that case.
        if device is None:
            map_location = 'cpu'
            device_desc = 'cpu'
        else:
            map_location = 'cuda:' + str(device)
            device_desc = f'cuda: {device}'

        state_dict = torch.load(path, map_location=map_location)
        self.pos_linear.load_state_dict(state_dict)
        print(f'load pos params from {path} into device {device_desc}')

    def forward(self, seqs, mask, labels, features=None, finetune=False, pos=None):
        """Encode ``seqs``, optionally add projected features, and run the CRF.

        Args:
            seqs: Token ids, passed to the encoder's embedding layer as
                ``input_ids``.
            mask: Forwarded unchanged to the CRF.
            labels: Forwarded unchanged to the CRF.
            features: Optional tensor whose last dimension matches the input
                size of ``se_linear``; its projection is added to the encoder
                output unless ``self.nofeature`` is True.
            finetune: When exactly ``False``, the encoder output is detached
                so no gradient flows into the encoder. The feature projection
                and the CRF still receive gradients.
            pos: Unused; kept for interface compatibility.

        Returns:
            Whatever ``self.CRF(ctx, mask, labels)`` returns.
        """
        # NOTE(review): for HF BERT-style models `embeddings(...)` already
        # applies position/type embeddings and LayerNorm, and feeding the
        # result back as `inputs_embeds` applies them a second time. Kept
        # as-is because existing checkpoints depend on it -- confirm intent.
        inputs_embeds = self.bert.embeddings(input_ids=seqs)

        # NOTE(review): no attention_mask is passed, so padded positions
        # (if any) take part in self-attention -- confirm whether `mask`
        # should be forwarded to the encoder.
        bert_outs = self.bert(inputs_embeds=inputs_embeds)
        ctx = bert_outs.last_hidden_state

        # Freeze the encoder *before* fusing the features; detaching after
        # the addition would also cut the gradient to `se_linear`, leaving
        # it untrained whenever finetune is False.
        if finetune is False:
            ctx = ctx.detach()

        # Add the projected features after the encoder.
        if features is not None and self.nofeature is False:
            ctx = ctx + self.se_linear(features)

        return self.CRF(ctx, mask, labels)
