import torch
import torch.nn as nn
import utils.init as init
from .bertCRF import CRF
from transformers import AutoModel
# from ZEN import ZenModel
from LEBERT.wcbert_modeling import WCBertModel
from LEBERT.function.utils import build_pretrained_embedding_for_corpus
from modules.rnn import RNN


class LexiconModel(nn.Module):
    """Sequence-labelling model: lexicon-aware BERT encoder followed by a CRF head.

    ``forward`` looks up embeddings for the matched lexicon words, passes them
    together with the token ids to ``WCBertModel``, applies dropout to the
    first encoder output and feeds the result to a CRF layer.

    ``build_lexicon_embedding`` must be called before ``forward``; it creates
    the lexicon embedding table (``self.word_embeddings``).
    """

    def __init__(
        self,
        bert_path,
        label_vocab,
        d_model: int = 768,
        for_ner: bool = False,
        use_feature: bool = False
    ) -> None:
        """Load the encoder and build the dropout + CRF head.

        Args:
            bert_path: path/name passed to ``WCBertModel.from_pretrained``.
            label_vocab: label vocabulary; its ``max_n_words`` is used as the
                number of CRF tags.
            d_model: width of the encoder output fed to the CRF.
            for_ner: currently unused; kept for interface compatibility.
            use_feature: currently has no effect (both former code paths
                loaded the same ``WCBertModel``); kept for interface
                compatibility.
        """
        super().__init__()

        self.d_model = d_model

        self.label_vocab = label_vocab
        self.num_tags = label_vocab.max_n_words
        self.num_layers = 1
        # Created later by build_lexicon_embedding(); forward() requires it.
        self.word_embeddings = None

        # The former `if use_feature is True: ... else: ...` had two identical
        # branches, so the flag never selected a different encoder.
        self.bert = WCBertModel.from_pretrained(bert_path)

        self.dropout_layer = nn.Dropout(0.2)
        self.CRF = CRF(num_tags=self.num_tags, input_dim=self.d_model)

        # NOTE(review): always True, independent of `use_feature`; preserved as-is.
        self.nofeature = True
        print('#### model nofeature: ', self.nofeature)

    def build_lexicon_embedding(self, hp, word_vocab):
        """Create ``self.word_embeddings`` initialised from pretrained lexicon vectors.

        Args:
            hp: hyper-parameter object providing ``lexicon_embedding_path``,
                ``max_scan_num`` and ``saved_lexicon_embedding_dir``.
            word_vocab: lexicon vocabulary forwarded to
                ``build_pretrained_embedding_for_corpus``.
        """
        # The helper also returns an embedding dim; the table size is taken
        # from the returned matrix itself instead, so that value is unused.
        pretrained_embeddings, _ = build_pretrained_embedding_for_corpus(
            embedding_path=hp.lexicon_embedding_path,
            word_vocab=word_vocab,
            max_scan_num=hp.max_scan_num,
            saved_corpus_embedding_dir=hp.saved_lexicon_embedding_dir
        )
        # Presumably a numpy array of shape (vocab_size, embed_dim) — required
        # by torch.from_numpy below; TODO confirm against the helper.
        word_vocab_size = pretrained_embeddings.shape[0]
        embed_dim = pretrained_embeddings.shape[1]
        self.word_embeddings = nn.Embedding(word_vocab_size, embed_dim)
        # copy_ also casts the numpy dtype to the embedding's dtype.
        self.word_embeddings.weight.data.copy_(torch.from_numpy(pretrained_embeddings))
        print("Load pretrained embedding from file.........")

    def init_model(self, param_path=None, device=None):
        """Optionally load saved ``pos_linear`` parameters from ``param_path``.

        Args:
            param_path: checkpoint path; nothing is loaded when ``None``.
            device: CUDA device index to map tensors onto; ``None`` loads on CPU.
        """
        if param_path is not None:
            self._load_param(param_path, device)
        # self._add_cur_pos_linear()

    def _save_pos_params(self, path):
        """Placeholder; intentionally does nothing."""
        pass

    def _load_param(self, path, device):
        """Load a state dict from ``path`` into ``self.pos_linear``.

        Args:
            path: file produced by ``torch.save`` of a ``pos_linear`` state dict.
            device: CUDA device index (int or digit string) -> ``'cuda:<idx>'``;
                a ``torch.device`` is used unchanged; ``None`` maps to CPU.

        NOTE(review): no visible code in this class creates ``self.pos_linear``;
        it must be attached elsewhere before this is called, otherwise
        ``AttributeError`` is raised.
        """
        if device is None:
            # 'cuda:' + str(None) would yield the invalid location 'cuda:None'.
            map_location = 'cpu'
        elif isinstance(device, torch.device):
            map_location = device
        else:
            map_location = 'cuda:' + str(device)
        state_dict = torch.load(path, map_location=map_location)
        self.pos_linear.load_state_dict(state_dict)
        print(f'load pos params from {path} into device {map_location}')

    def _add_cur_pos_linear(self):
        """Prepend a freshly initialised (d_model x d_model) parameter to ``pos_linear``.

        NOTE(review): requires ``self.pos_linear`` (a ``nn.ParameterList``) to
        already exist; it is not created in this class.
        """
        print(f'origin pos params size: {len(self.pos_linear)}')
        linear = nn.Parameter(init.default_init(torch.zeros((self.d_model, self.d_model), dtype=torch.float32)))
        self.pos_linear = nn.ParameterList([linear]).extend(self.pos_linear)
        print(f'successful add a pos param at the first of the paramsList, cur size: {len(self.pos_linear)}')

    def forward(
        self,
        seqs,
        mask,
        labels,
        matched_word_ids=None,
        matched_word_mask=None,
        boundary_ids=None,
    ):
        """Encode ``seqs`` with lexicon features and score them with the CRF.

        Args:
            seqs: token id tensor passed to the encoder as ``input_ids``.
            mask: attention mask, also forwarded to the CRF.
            labels: gold labels forwarded to the CRF.
            matched_word_ids: lexicon word ids looked up in ``self.word_embeddings``.
            matched_word_mask: mask over matched words, forwarded to the encoder.
            boundary_ids: forwarded to the encoder unchanged.

        Returns:
            Whatever ``self.CRF(ctx, mask, labels)`` returns (presumably a dict
            of loss/predictions — confirm against the CRF module).

        Raises:
            RuntimeError: if ``build_lexicon_embedding`` has not been called.
            ValueError: if ``matched_word_ids`` is ``None``.
        """
        if self.word_embeddings is None:
            raise RuntimeError(
                'word_embeddings is not initialised; call build_lexicon_embedding() first'
            )
        if matched_word_ids is None:
            raise ValueError('matched_word_ids is required by LexiconModel.forward')

        matched_word_embeddings = self.word_embeddings(matched_word_ids)
        outputs = self.bert(
            input_ids=seqs,
            attention_mask=mask,
            matched_word_embeddings=matched_word_embeddings,
            matched_word_mask=matched_word_mask,
            boundary_ids=boundary_ids,
            return_dict=False
        )

        # With return_dict=False the first element is presumably the
        # per-token hidden states — TODO confirm against WCBertModel.
        ctx = outputs[0]
        ctx = self.dropout_layer(ctx)
        res_dic = self.CRF(ctx, mask, labels)

        return res_dic
