import string
import torch
import torch.nn as nn

from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
from pruner.parameters import DEVICE


class ColBERTPruner(BertPreTrainedModel):
    """BERT encoder with a linear projection and a per-token scoring head.

    Each input token is encoded by BERT, projected down to ``dim`` features,
    optionally zeroed out (padding and punctuation) and L2-normalised.  A small
    MLP then turns every token vector into a single logit.
    """

    def __init__(self, config, doc_maxlen, mask_punctuation, dim=128):
        """Build the encoder, the projection and the scoring MLP.

        Args:
            config: BERT configuration handed to ``BertPreTrainedModel`` and
                ``BertModel``.
            doc_maxlen: Stored on the instance as ``self.doc_maxlen``; it is
                not read anywhere else in this class.
            mask_punctuation: When truthy, punctuation tokens are added to
                ``self.skiplist`` so that ``mask`` filters them out.
            dim: Output size of the projection and width of the scoring MLP.
        """
        super().__init__(config)

        self.doc_maxlen = doc_maxlen
        self.dim = dim

        self.mask_punctuation = mask_punctuation
        self.skiplist = {}

        if self.mask_punctuation:
            # NOTE(review): the tokenizer checkpoint is hard-coded and does not
            # follow ``config``; confirm the vocabulary matches the encoder.
            self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
            # Keys are both the punctuation character itself and the id of the
            # first token the tokenizer produces for it.
            self.skiplist = {w: True
                             for symbol in string.punctuation
                             for w in [symbol, self.tokenizer.encode(symbol, add_special_tokens=False)[0]]}

        # Submodules are created in this order: bert, linear, prediction_layer.
        self.bert = BertModel(config)
        self.linear = nn.Linear(config.hidden_size, dim, bias=False)
        self.prediction_layer = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, 1),
        )
        self.init_weights()

    def forward(self, ids, mask):
        """Return one logit per token for a batch of token ids.

        Args:
            ids: Token ids, passed to ``doc`` as ``input_ids``.
            mask: Attention mask, passed to ``doc`` as ``attention_mask``.

        Returns:
            Tensor of shape (B, L, 1) holding logits (no sigmoid applied).
        """
        token_features = self.doc(input_ids=ids, attention_mask=mask)  # B, L, dim
        return self.predict_scores(token_features)  # B, L, 1 (logits before sigmoid)

    def predict_scores(self, token_features):
        """Run the scoring MLP on token features; returns (B, L, 1) logits."""
        return self.prediction_layer(token_features)  # B, L, 1 (logits before sigmoid)

    def doc(self, input_ids, attention_mask, keep_dims=True):
        """Encode tokens into masked, L2-normalised ``dim``-sized vectors.

        Args:
            input_ids: Token ids for the batch.
            attention_mask: Attention mask forwarded to BERT.
            keep_dims: Accepted for interface compatibility; currently ignored.

        Returns:
            Tensor of shape (B, L, dim).  Rows of tokens rejected by ``mask``
            are multiplied by zero before normalisation.
        """
        # Follow the device the weights actually live on rather than a global
        # constant, so the inputs always end up next to the parameters.
        device = self.linear.weight.device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        D_bert = self.bert(input_ids, attention_mask=attention_mask)[0]
        D = self.linear(D_bert)

        # (B, L) booleans -> (B, L, 1) float32 so it broadcasts over features.
        keep = torch.tensor(self.mask(input_ids), device=device).unsqueeze(2).float()

        D = D * keep
        D = torch.nn.functional.normalize(D, p=2, dim=2)

        return D

    def mask(self, input_ids):
        """Return a nested list of booleans marking the tokens to keep.

        A token is kept when its id is not a key of ``self.skiplist`` and is
        not 0.  NOTE(review): 0 is presumably the padding id — confirm against
        the tokenizer in use.
        """
        mask = [[(x not in self.skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
        return mask