import torch
import torch.nn as nn
import utils.init as init
from .bertCRF import CRF
from transformers import AutoModel, BertModel
from ZEN import ZenModel
from modules.modeling_feature import BertModel as FeatureModel
from modules.modeling_nezha import NezhaModel
from modules.rnn import RNN


class BaseModel(nn.Module):
    """BERT-style encoder followed by a CRF tagging head.

    The encoder is loaded from ``bert_path`` and its hidden states are passed,
    together with the mask and labels, to a ``CRF`` module that is sized from
    ``label_vocab.max_n_words``.
    """

    def __init__(self, bert_path, label_vocab, d_model: int = 768, trainsets=None, for_ner: bool = False, use_feature: bool = False):
        """Build the encoder and the CRF head.

        Args:
            bert_path: Name or path handed to ``BertModel.from_pretrained``.
            label_vocab: Label vocabulary; ``max_n_words`` is read as the
                number of tags.
            d_model: Input dimension given to the CRF head. It presumably has
                to equal the encoder's hidden size -- confirm when changing
                the backbone.
            trainsets: Accepted for interface compatibility; not used here.
            for_ner: Accepted for interface compatibility; not used here.
            use_feature: Accepted for interface compatibility; not used here
                (the feature-aware backbone is currently not instantiated).
        """
        super().__init__()

        self.d_model = d_model
        # Not read anywhere in this class; kept because external code may.
        self.hidden_size = 400

        self.label_vocab = label_vocab
        self.num_tags = label_vocab.max_n_words
        # Not read anywhere in this class; kept because external code may.
        self.num_layers = 1

        # Backbone. Other BERT-like encoders imported by this file
        # (AutoModel, NezhaModel, ZenModel, FeatureModel) can be swapped in
        # here via their own ``from_pretrained(bert_path)``.
        self.bert = BertModel.from_pretrained(bert_path)

        self.CRF = CRF(num_tags=self.num_tags, input_dim=self.d_model)

        # When True, ``forward`` ignores any ``features`` it is given.
        self.nofeature = True
        print('#### model nofeature: ', self.nofeature)

    def init_model(self, param_path=None, device=None):
        """Optionally load saved ``pos_linear`` parameters.

        Args:
            param_path: Checkpoint path; nothing happens when it is ``None``.
            device: CUDA device index for ``map_location``; ``None`` loads
                onto the CPU.
        """
        if param_path is not None:
            self._load_param(param_path, device)

    def _save_pos_params(self, path):
        """Do nothing; kept so that existing callers keep working."""
        pass

    def _load_param(self, path, device):
        """Load a state dict from ``path`` into ``self.pos_linear``.

        Args:
            path: Checkpoint file readable by ``torch.load``.
            device: CUDA device index; ``None`` maps the tensors to the CPU
                (previously this produced the invalid location 'cuda:None').

        Raises:
            AttributeError: If the instance has no ``pos_linear`` attribute.
                This class never creates one, so it is checked up front
                instead of failing only after the checkpoint has been read.
        """
        if not hasattr(self, 'pos_linear'):
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'pos_linear'; "
                'there is nothing to load the parameters into'
            )
        if device is None:
            location, shown = 'cpu', 'cpu'
        else:
            location, shown = 'cuda:' + str(device), f'cuda: {device}'
        state_dict = torch.load(path, map_location=location)
        self.pos_linear.load_state_dict(state_dict)
        print(f'load pos params from {path} into device {shown}')

    def _add_cur_pos_linear(self):
        """Prepend a new ``(d_model, d_model)`` parameter to ``self.pos_linear``.

        NOTE(review): requires ``self.pos_linear`` to exist already; this
        class does not create it and nothing in this class calls the method.
        """
        print(f'origin pos params size: {len(self.pos_linear)}')
        linear = nn.Parameter(init.default_init(torch.zeros((self.d_model, self.d_model), dtype=torch.float32)))
        self.pos_linear = nn.ParameterList([linear]).extend(self.pos_linear)
        print(f'successful add a pos param at the first of the paramsList, cur size: {len(self.pos_linear)}')

    def forward(
        self,
        seqs,
        mask,
        labels,
        ngram_ids: torch.Tensor = None,
        ngram_positions: torch.Tensor = None,
        ngram_mask: torch.Tensor = None,
        features=None,
        finetune=False,
        pos=None
    ):
        """Encode ``seqs`` and run the CRF head on the hidden states.

        Args:
            seqs: Token ids passed to the encoder as ``input_ids``.
            mask: Passed to the CRF head only.
                NOTE(review): it is not given to the encoder as an attention
                mask, so padded positions are presumably attended to --
                confirm that this is intended.
            labels: Passed to the CRF head.
            ngram_ids: When given (and ``features`` is ``None``) the encoder
                is called with the n-gram keyword arguments below; this
                presumably requires a ZEN-style backbone.
            ngram_positions: Forwarded as ``ngram_position_matrix``.
            ngram_mask: Forwarded as ``ngram_attention_mask``.
            features: Optional feature states; selects the feature branch but
                is ignored while ``self.nofeature`` is True.
            finetune: Only the literal ``False`` freezes the encoder; any
                other value lets gradients flow into it.
            pos: Unused.

        Returns:
            Whatever ``self.CRF(ctx, mask, labels)`` returns.
        """
        frozen = finetune is False
        # A frozen encoder needs no autograd graph. Building one and detaching
        # afterwards keeps every encoder activation alive through the CRF
        # call. ``is_grad_enabled()`` makes sure an outer ``torch.no_grad()``
        # is never overridden.
        track_grad = torch.is_grad_enabled() and not frozen
        with torch.set_grad_enabled(track_grad):
            if features is not None:
                if self.nofeature is True:
                    # Features are disabled: do not send ``feature_states`` at
                    # all, because a plain BERT encoder does not document
                    # that keyword.
                    bert_outs = self.bert(input_ids=seqs)
                else:
                    bert_outs = self.bert(input_ids=seqs, feature_states=features)
                ctx = bert_outs[0]
            elif ngram_ids is not None:
                bert_outs = self.bert(
                    input_ids=seqs,
                    input_ngram_ids=ngram_ids,
                    ngram_position_matrix=ngram_positions,
                    ngram_attention_mask=ngram_mask,
                    output_all_encoded_layers=False
                )
                ctx, _ = bert_outs
            else:
                bert_outs = self.bert(input_ids=seqs)
                ctx = bert_outs.last_hidden_state
        if frozen:
            # Kept as a guarantee in case the encoder hands back a tensor
            # that still requires grad.
            ctx = ctx.detach()

        # The CRF is the only head here; an alternative linear projection +
        # cross-entropy head is not part of this model.
        return self.CRF(ctx, mask, labels)
