import os
import warnings
from unsupervised.utils import compute_statistics_info
from unsupervised.Trie import HatTrie
from typing import List, Dict, Tuple

import torch
from torch import nn
from torch.nn import functional as F
import torch.nn.init as init
from transformers import BertModel, BertConfig, AutoModel, AutoModelForMaskedLM, AutoConfig, PretrainedConfig, \
    RobertaModel
from transformers.models.bert.modeling_bert import BertPooler, BertOnlyMLMHead, BertPreTrainingHeads, BertLayer
#from transformers.modeling_bert import BertPooler, BertOnlyMLMHead, BertPreTrainingHeads, BertLayer
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPooling, MaskedLMOutput
from modeling_feature import BertModel as FeatureModel
from modeling_feature import BertForPreTraining as FeatureForPreTraining
from modules.span import MeanSpanExtractor, DotProductAttentiveSpanExtractor

from arguments import DataTrainingArguments, ModelArguments
from transformers import TrainingArguments
import logging

# Module-level logger named after this module; the from_pretrained loaders
# in this file use it to report when extra weights are read from disk.
logger = logging.getLogger(__name__)


class CondenserForPretraining(nn.Module):
    """Masked-LM wrapper that adds a per-token regression ("SEG") objective.

    The last hidden layer of the wrapped masked-LM is scored twice: by the
    model's own ``cls`` head for the MLM cross-entropy, and by a one-unit
    linear projection followed by a sigmoid for an MSE loss against
    ``labels['SEG']``.
    """

    def __init__(
        self,
        bert: BertModel,
        model_args: ModelArguments,
        data_args: DataTrainingArguments,
        train_args: TrainingArguments
    ):
        """Wrap ``bert`` and create the segmentation projection and losses.

        Args:
            bert: masked-LM model; must expose ``config``, ``cls`` and
                ``save_pretrained`` (all are used by this class).
            model_args: stored as-is and written out by ``save_pretrained``.
            data_args: stored as-is and written out by ``save_pretrained``.
            train_args: stored as-is and written out by ``save_pretrained``.
        """
        super(CondenserForPretraining, self).__init__()
        self.lm = bert
        # Size the projection from the wrapped model instead of assuming a
        # 768-wide encoder; 768 stays as the fallback when the config does
        # not expose ``hidden_size``.
        self.seg_proj = nn.Linear(getattr(bert.config, 'hidden_size', 768), 1)
        self.cross_entropy = nn.CrossEntropyLoss()
        self.mse_critic = nn.MSELoss()
        self.activation = nn.Sigmoid()

        self.model_args = model_args
        self.train_args = train_args
        self.data_args = data_args

    def forward(self, model_input, labels, is_eval: bool = False):
        """Return the training (MLM + SEG) or evaluation (MLM only) loss.

        Args:
            model_input: mapping unpacked as keyword arguments into the
                wrapped model.
            labels: mapping with key ``'MLM'`` and, unless ``is_eval`` is
                ``True``, key ``'SEG'``.
            is_eval: when exactly ``False`` the SEG loss is added.

        Returns:
            The summed loss tensor.
        """
        # ``labels`` is deliberately not forwarded: the loss the wrapped model
        # would compute was never read, and ``mlm_loss`` recomputes it below.
        lm_out: MaskedLMOutput = self.lm(
            **model_input,
            output_hidden_states=True,
            return_dict=True
        )
        hiddens = lm_out.hidden_states[-1]

        loss = self.mlm_loss(hiddens, labels['MLM'])
        if is_eval is False:
            # Out-of-place add so the MLM loss tensor is not mutated.
            loss = loss + self.seg_loss(hiddens, labels['SEG'])

        return loss

    def mlm_loss(self, hiddens, labels):
        """Cross-entropy between ``self.lm.cls(hiddens)`` and ``labels``.

        Scores are flattened to ``(-1, vocab_size)`` and labels to ``(-1,)``
        before being passed to ``nn.CrossEntropyLoss``.
        """
        pred_scores = self.lm.cls(hiddens)
        masked_lm_loss = self.cross_entropy(
            pred_scores.view(-1, self.lm.config.vocab_size),
            labels.view(-1)
        )
        return masked_lm_loss

    def seg_loss(self, hiddens: torch.Tensor, labels: torch.Tensor, attention_mask: torch.Tensor = None):
        """MSE between sigmoid-squashed per-token scores and ``labels``.

        Args:
            hiddens: hidden states whose first dimension is the batch and
                whose last dimension matches ``seg_proj``'s input width.
            labels: regression targets; presumably ``(batch, seq_len)``
                floats in [0, 1] -- TODO confirm against the data collator.
            attention_mask: accepted for interface compatibility; unused.

        Returns:
            Scalar loss tensor.
        """
        batch_size = hiddens.size(0)
        pred_scores = self.seg_proj(hiddens)
        # Collapse the trailing singleton so predictions are (batch, -1).
        logits = self.activation(pred_scores).view(batch_size, -1)
        return self.mse_critic(logits, labels)

    @classmethod
    def from_pretrained(
            cls, model_args: ModelArguments, data_args: DataTrainingArguments, train_args: TrainingArguments,
            *args, **kwargs
    ):
        """Build the wrapper from a pretrained checkpoint.

        ``*args``/``**kwargs`` are forwarded to
        ``AutoModelForMaskedLM.from_pretrained``. If the checkpoint location
        is a directory containing ``model.pt``, those extra weights are
        loaded non-strictly on top.
        """
        hf_model = AutoModelForMaskedLM.from_pretrained(*args, **kwargs)
        model = cls(hf_model, model_args, data_args, train_args)
        # The location may be given positionally or by keyword; indexing
        # ``args[0]`` unconditionally raised IndexError for the latter.
        path = args[0] if args else kwargs.get('pretrained_model_name_or_path')
        if path is not None and os.path.exists(os.path.join(path, 'model.pt')):
            logger.info('loading extra weights from local files')
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            model_dict = torch.load(os.path.join(path, 'model.pt'), map_location="cpu")
            load_result = model.load_state_dict(model_dict, strict=False)
            print(load_result, flush=True)
        return model

    @classmethod
    def from_config(
            cls,
            config: PretrainedConfig,
            model_args: ModelArguments,
            data_args: DataTrainingArguments,
            train_args: TrainingArguments,
    ):
        """Build the wrapper around a masked-LM created from ``config``."""
        hf_model = AutoModelForMaskedLM.from_config(config)
        return cls(hf_model, model_args, data_args, train_args)

    def save_pretrained(self, output_dir: str):
        """Save the wrapped model, the extra weights and the argument objects.

        The wrapped model is saved through its own ``save_pretrained``; the
        remaining parameters go to ``model.pt`` and the three argument
        objects to ``args.pt`` inside ``output_dir``.
        """
        self.lm.save_pretrained(output_dir)
        model_dict = self.state_dict()
        # Match on 'lm.' (with the dot) so only parameters of the ``lm``
        # sub-module are dropped, not any attribute whose name merely
        # starts with "lm".
        hf_weight_keys = [k for k in model_dict if k.startswith('lm.')]
        warnings.warn(f'omiting {len(hf_weight_keys)} transformer weights')
        for k in hf_weight_keys:
            model_dict.pop(k)
        torch.save(model_dict, os.path.join(output_dir, 'model.pt'))
        torch.save([self.data_args, self.model_args, self.train_args], os.path.join(output_dir, 'args.pt'))


class Featurer(nn.Module):
    """Wrapper that feeds n-gram statistics features into a feature-aware BERT.

    For every input text, per-position statistics are computed with
    ``compute_statistics_info`` against a ``HatTrie`` and passed to the
    wrapped model as ``feature_states`` alongside the token ids.
    """

    # Width of one per-position feature vector, as used throughout the class.
    FEATURE_DIM = 17

    def __init__(
        self,
        bert: FeatureModel,
        model_args: ModelArguments,
        data_args: DataTrainingArguments,
        train_args: TrainingArguments,
        trie_path: str = 'trie path'
    ):
        """Store the wrapped model and arguments and load the trie.

        Args:
            bert: the feature-aware model; called in ``forward`` and saved in
                ``save_pretrained``.
            model_args: stored and written out by ``save_pretrained``.
            data_args: stored and written out by ``save_pretrained``.
            train_args: stored and written out by ``save_pretrained``.
            trie_path: location passed to ``HatTrie.load``. The default keeps
                the previously hard-coded value.
        """
        super(Featurer, self).__init__()
        self.bert = bert

        self.model_args = model_args
        self.train_args = train_args
        self.data_args = data_args
        self.trie, _ = HatTrie.load(trie_path)

    def _pad_feature(self, seq: List, tgt_len: int, val: float = 0.0, feature_dim: int = 17):
        """Right-pad ``seq`` with ``[val] * feature_dim`` rows up to ``tgt_len``.

        Returns a new list; ``seq`` is not modified.

        Raises:
            ValueError: if ``seq`` is already longer than ``tgt_len``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if len(seq) > tgt_len:
            raise ValueError(f'seq: {seq} is wrong')
        return seq + [[val] * feature_dim for _ in range(tgt_len - len(seq))]

    def _compute_feature(self, meta: List[str], offsets: List[List[int]], entropy_label: torch.Tensor) -> torch.Tensor:
        """Build the padded feature tensor for a batch.

        Args:
            meta: one text per example.
            offsets: for each example, the indices into that text's
                statistics to keep (presumably one per token -- TODO confirm).
            entropy_label: only its second dimension (target length) and its
                device are used.

        Returns:
            A float32 tensor on ``entropy_label``'s device. Each example is
            one leading all-zero row, then the selected statistics rows, then
            zero padding up to the target length.
        """
        tgt_len = entropy_label.size(1)
        features = []
        for text, offset in zip(meta, offsets):
            statistics_info = compute_statistics_info(text, self.trie)
            feature = [statistics_info[idx] for idx in offset]
            # The leading zero row occupies position 0 (presumably the
            # [CLS] slot -- TODO confirm).
            features.append(self._pad_feature([[0.0] * self.FEATURE_DIM] + feature, tgt_len))

        # ``entropy_label.device`` works on CPU as well; ``torch.get_device``
        # returns -1 for CPU tensors, which is not a valid ``device=`` value.
        return torch.tensor(features, dtype=torch.float32, device=entropy_label.device)

    def forward(self, model_input, labels, is_eval: bool = False):
        """Compute features and return whatever the wrapped model returns.

        Args:
            model_input: mapping with ``'input_ids'``, ``'attention_mask'``,
                ``'meta'`` and ``'offsets'``. ``'meta'`` and ``'offsets'``
                are popped, so the mapping is modified in place.
            labels: mapping with keys ``'MLM'`` and ``'SEG'``.
            is_eval: when exactly ``True`` the entropy label is replaced by
                ``None`` before calling the wrapped model.
        """
        masked_lm_labels, entropy_label = labels['MLM'], labels['SEG']
        meta, offsets = model_input.pop('meta'), model_input.pop('offsets')
        feature_states = self._compute_feature(meta, offsets, entropy_label)
        del meta, offsets

        if is_eval is True:
            entropy_label = None

        loss = self.bert(
            input_ids=model_input['input_ids'],
            attention_mask=model_input['attention_mask'],
            feature_states=feature_states,
            masked_lm_labels=masked_lm_labels,
            entropy_label=entropy_label
        )

        return loss

    @classmethod
    def from_pretrained(
            cls, model_args: ModelArguments, data_args: DataTrainingArguments, train_args: TrainingArguments,
            *args, **kwargs
    ):
        """Build the wrapper from a pretrained checkpoint.

        ``*args`` and ``**kwargs`` (minus ``'config'``) are forwarded to
        ``FeatureForPreTraining.from_pretrained``. If the checkpoint location
        is a directory containing ``model.pt``, those weights are loaded
        non-strictly on top.
        """
        model_kwargs = kwargs.copy()
        # 'config' is dropped before forwarding, as before; tolerate callers
        # that did not pass it instead of raising KeyError.
        model_kwargs.pop('config', None)
        hf_model = FeatureForPreTraining.from_pretrained(*args, **model_kwargs)
        model = cls(hf_model, model_args, data_args, train_args)
        # The location may be positional or keyword; ``args[0]`` alone raised
        # IndexError for the latter.
        path = args[0] if args else kwargs.get('pretrained_model_name_or_path')
        if path is not None and os.path.exists(os.path.join(path, 'model.pt')):
            logger.info('loading extra weights from local files')
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            model_dict = torch.load(os.path.join(path, 'model.pt'), map_location="cpu")
            load_result = model.load_state_dict(model_dict, strict=False)
            print(load_result, flush=True)
        return model

    @classmethod
    def from_config(
            cls,
            config: PretrainedConfig,
            model_args: ModelArguments,
            data_args: DataTrainingArguments,
            train_args: TrainingArguments,
    ):
        """Build the wrapper around a model created from ``config``.

        NOTE(review): relies on ``FeatureForPreTraining`` providing a
        ``from_config`` constructor -- confirm in ``modeling_feature``.
        """
        hf_model = FeatureForPreTraining.from_config(config)
        return cls(hf_model, model_args, data_args, train_args)

    def save_pretrained(self, output_dir: str):
        """Save the wrapped model and the argument objects to ``output_dir``.

        Only the wrapped model (via its own ``save_pretrained``) and
        ``args.pt`` are written; no separate ``model.pt`` is produced.
        """
        self.bert.save_pretrained(output_dir)
        # Count only parameters of the ``bert`` sub-module ('bert.' with the
        # dot); the count is reported through the warning below.
        n_hf_weights = sum(1 for k in self.state_dict() if k.startswith('bert.'))
        warnings.warn(f'omiting {n_hf_weights} transformer weights')
        torch.save([self.data_args, self.model_args, self.train_args], os.path.join(output_dir, 'args.pt'))


class EntropyForPretraining(nn.Module):
    """Masked-LM wrapper with an auxiliary n-gram-statistics regression loss.

    On top of the MLM cross-entropy, selected hidden layers are projected to
    ``feature_size`` dimensions, squashed with a sigmoid and regressed (MSE)
    onto per-position statistics computed by ``compute_statistics_info``.
    """

    def __init__(
        self,
        bert: BertModel,
        model_args: ModelArguments,
        data_args: DataTrainingArguments,
        train_args: TrainingArguments,
        mse_layers: List[int] = [3],
        trie_path: str = 'trie path'
    ):
        """Wrap ``bert`` and set up the projection, span extractor and trie.

        Args:
            bert: masked-LM model; must expose ``config``, ``cls`` and
                ``save_pretrained`` (all are used by this class).
            model_args: stored as-is.
            data_args: stored as-is.
            train_args: stored as-is.
            mse_layers: indices into the model's ``hidden_states`` tuple that
                receive the MSE loss. The list is copied.
            trie_path: location passed to ``HatTrie.load``.
        """
        super(EntropyForPretraining, self).__init__()
        self.bert = bert

        self.model_args = model_args
        self.train_args = train_args
        self.data_args = data_args

        # Feature-type switches; they determine the projection width below.
        self.get_t_test = False
        self.only_pmi = False

        self.feature_size = 18 if self.get_t_test is True else 17
        self.feature_size = 9 if self.only_pmi is True else self.feature_size
        self.hidden_transfer = nn.Linear(self.bert.config.hidden_size, self.feature_size)

        print(f'only pmi is: {self.only_pmi}, and feature size is: {self.feature_size}.')

        self.span_extractor = DotProductAttentiveSpanExtractor(input_dim=self.bert.config.hidden_size)

        self.cross_entropy_critic = nn.CrossEntropyLoss()
        self.mse_critic = nn.MSELoss()
        self.activation = nn.Sigmoid()

        # Copy so the shared default list (or a caller's list) is never
        # aliased by this instance.
        self.mse_layers = list(mse_layers)

        # Trie backing the n-gram statistics features.
        self.trie, _ = HatTrie.load(trie_path)

        self._init_spans()

    def _init_spans(self) -> None:
        """Precompute ``[i, i + 6]`` index pairs for ``i`` in ``0..511``."""
        self._spans = [[i, i + 6] for i in range(512)]

    def _pad_feature(self, seq: List, tgt_len: int, val: float = 0.0, feature_dim: int = 17):
        """Right-pad ``seq`` with ``[val] * feature_dim`` rows up to ``tgt_len``.

        Returns a new list; ``seq`` is not modified.

        Raises:
            ValueError: if ``seq`` is already longer than ``tgt_len``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if len(seq) > tgt_len:
            raise ValueError(f'seq: {seq} is wrong')
        return seq + [[val] * feature_dim for _ in range(tgt_len - len(seq))]

    def _compute_feature(self, meta: List[str], offsets: List[List[int]], masked_lm_labels: torch.Tensor) -> torch.Tensor:
        """Build the padded regression-target tensor for a batch.

        Args:
            meta: one text per example.
            offsets: for each example, the indices into that text's
                statistics to keep (presumably one per token -- TODO confirm).
            masked_lm_labels: only its second dimension (target length) and
                its device are used.

        Returns:
            A float32 tensor on ``masked_lm_labels``'s device. Each example
            is one leading all-zero row, then the selected statistics rows
            (truncated to ``feature_size`` when ``only_pmi`` is set), then
            zero padding up to the target length.
        """
        tgt_len = masked_lm_labels.size(1)
        features = []
        for text, offset in zip(meta, offsets):
            statistics_info = compute_statistics_info(text, self.trie, get_t_test=self.get_t_test)
            if self.only_pmi is True:
                feature = [statistics_info[idx][:self.feature_size] for idx in offset]
            else:
                feature = [statistics_info[idx] for idx in offset]

            features.append(
                self._pad_feature([[0.0] * self.feature_size] + feature, tgt_len, feature_dim=self.feature_size)
            )

        # ``.device`` works on CPU as well; ``torch.get_device`` returns -1
        # for CPU tensors, which is not a valid ``device=`` value.
        return torch.tensor(features, dtype=torch.float32, device=masked_lm_labels.device)

    def forward(self, model_input, labels, is_eval: bool = False):
        """Return the MLM loss, plus the MSE losses when training.

        Args:
            model_input: mapping with ``'attention_mask'``, ``'meta'`` and
                ``'offsets'`` plus the wrapped model's inputs. ``'meta'`` and
                ``'offsets'`` are popped (the mapping is modified in place)
                and the remainder is unpacked into the wrapped model.
            labels: mapping with key ``'MLM'``.
            is_eval: when exactly ``False`` an MSE loss is added for every
                layer index in ``self.mse_layers``.
        """
        masked_lm_labels = labels['MLM']
        meta, offsets = model_input.pop('meta'), model_input.pop('offsets')
        mse_label = self._compute_feature(meta, offsets, masked_lm_labels)
        del meta, offsets

        attention_mask = model_input['attention_mask']

        # ``labels`` is deliberately not forwarded: the loss the wrapped model
        # would compute was never read, and ``mlm_loss`` recomputes it below.
        mlm_out: MaskedLMOutput = self.bert(
            **model_input,
            output_hidden_states=True,
            return_dict=True
        )

        loss = self.mlm_loss(mlm_out.hidden_states[-1], masked_lm_labels)

        if is_eval is False:
            for layer in self.mse_layers:
                # Out-of-place add so the running loss tensor is not mutated.
                loss = loss + self.mse_loss(mlm_out.hidden_states[layer], mse_label, attention_mask)

        return loss

    def mlm_loss(self, hiddens, labels):
        """Cross-entropy between ``self.bert.cls(hiddens)`` and ``labels``.

        Scores are flattened to ``(-1, vocab_size)`` and labels to ``(-1,)``.
        """
        pred_scores = self.bert.cls(hiddens)
        masked_lm_loss = self.cross_entropy_critic(
            pred_scores.view(-1, self.bert.config.vocab_size),
            labels.view(-1)
        )
        return masked_lm_loss

    def aggregate_hiddens(self, hiddens: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Run the span extractor over a window around every position.

        ``hiddens`` (batch, seq_len, hidden) is zero-padded by 3 positions on
        each side; position ``i`` is paired with the index pair
        ``[i, i + 6]`` in the padded sequence. The extractor output is
        multiplied by ``mask``.

        NOTE(review): whether the span end is inclusive depends on
        ``DotProductAttentiveSpanExtractor`` -- confirm in ``modules.span``.
        """
        batch_size, seq_len, hidden_size = hiddens.size()
        zero_pads = hiddens.new_zeros((batch_size, 3, hidden_size))

        origin_hiddens = hiddens
        hiddens = torch.cat([zero_pads, hiddens, zero_pads], dim=1)
        assert hiddens.size(1) == seq_len + 6

        # Build exactly ``seq_len`` spans. Expanding the fixed 512-row table
        # to (batch, seq_len, 2) only worked when seq_len happened to be 512.
        starts = torch.arange(seq_len, device=hiddens.device)
        spans = torch.stack([starts, starts + 6], dim=-1).long().unsqueeze(0)
        spans = spans.expand([batch_size, seq_len, 2])

        return self.span_extractor(hiddens, origin_hiddens, spans) * mask

    def mse_loss(self, hiddens: torch.Tensor, labels: torch.Tensor, attention_mask: torch.Tensor = None, alpha: float = 1.0):
        """Masked MSE between projected hidden states and ``labels``.

        Args:
            hiddens: (batch, seq_len, hidden_size) hidden states.
            labels: regression targets, same leading shape with
                ``feature_size`` as the last dimension.
            attention_mask: optional mask with ``batch * seq_len`` elements
                (any layout that reshapes to ``(batch, seq_len, 1)``);
                zero entries exclude a position. ``None`` disables masking.
            alpha: scalar multiplier applied to the loss.

        Returns:
            Scalar loss tensor (mean over all elements, masked ones count
            as zero error).
        """
        batch_size = hiddens.size(0)
        features = self.activation(self.hidden_transfer(hiddens))

        if attention_mask is not None:
            # reshape() instead of squeeze(): squeeze also dropped a batch or
            # sequence dimension of size 1, which then broadcast wrongly.
            mask = attention_mask.reshape(batch_size, -1, 1).to(features.dtype)
            # Mask AFTER the sigmoid. Masking before it turned padded
            # positions into sigmoid(0) = 0.5, which was then penalised
            # against the zero padding in ``labels``.
            features = features * mask
            labels = labels * mask

        return self.mse_critic(features, labels) * alpha

    @classmethod
    def from_pretrained(
            cls,
            model_args: ModelArguments,
            data_args: DataTrainingArguments,
            train_args: TrainingArguments,
            *args, **kwargs
    ):
        """Build the wrapper from a pretrained checkpoint.

        ``*args``/``**kwargs`` are forwarded to
        ``AutoModelForMaskedLM.from_pretrained``. If the checkpoint location
        is a directory containing ``model.pt``, those extra weights are
        loaded non-strictly on top.
        """
        hf_model = AutoModelForMaskedLM.from_pretrained(*args, **kwargs)
        model = cls(hf_model, model_args, data_args, train_args)
        # The location may be positional or keyword; ``args[0]`` alone raised
        # IndexError for the latter.
        path = args[0] if args else kwargs.get('pretrained_model_name_or_path')
        if path is not None and os.path.exists(os.path.join(path, 'model.pt')):
            logger.info('loading extra weights from local files')
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints that come from a trusted source.
            model_dict = torch.load(os.path.join(path, 'model.pt'), map_location="cpu")
            load_result = model.load_state_dict(model_dict, strict=False)
            print(load_result, flush=True)
        return model

    @classmethod
    def from_config(
            cls,
            config: PretrainedConfig,
            model_args: ModelArguments,
            data_args: DataTrainingArguments,
            train_args: TrainingArguments,
    ):
        """Build the wrapper around a masked-LM created from ``config``."""
        hf_model = AutoModelForMaskedLM.from_config(config)
        return cls(hf_model, model_args, data_args, train_args)

    def save_pretrained(self, output_dir: str):
        """Save the wrapped model to ``output_dir``.

        Only the wrapped model is written (via its own ``save_pretrained``).
        NOTE(review): the extra parameters (``hidden_transfer``,
        ``span_extractor``) are not saved here, although ``from_pretrained``
        looks for a ``model.pt`` -- confirm this is intended.
        """
        self.bert.save_pretrained(output_dir)
        # Count only parameters of the ``bert`` sub-module ('bert.' with the
        # dot); the count is reported through the warning below.
        n_hf_weights = sum(1 for k in self.state_dict() if k.startswith('bert.'))
        warnings.warn(f'omiting {n_hf_weights} transformer weights')