import string

import torch
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast

from colbert.parameters import DEVICE


class ColBERT(BertPreTrainedModel):
    """ColBERT late-interaction scorer with pruning / pseudo-query support.

    Queries and documents share one ``BertModel`` followed by a bias-free
    linear projection to ``dim`` features, and token embeddings are
    L2-normalized. Documents additionally carry a ``pruning_mask``:

    * :meth:`doc` zeroes the embeddings of tokens whose ``pruning_mask`` is 0;
    * :meth:`doc_with_pseudo_query` returns both all kept token embeddings and
      the subset that is also selected by ``pruning_mask`` ("pseudo-query"
      tokens);
    * when ``pseudo_query_indicator`` is truthy, ``pruning_mask`` is also fed
      to BERT as ``token_type_ids``.
    """

    def __init__(self, config,
        pseudo_query_indicator, #!@ custom
        query_maxlen, doc_maxlen, mask_punctuation, dim=128, similarity_metric='cosine'):
        """Build the encoder.

        Args:
            config: BERT configuration used for the base class and for the
                inner ``BertModel``.
            pseudo_query_indicator: if truthy, the document ``pruning_mask`` is
                passed to BERT as ``token_type_ids``.
            query_maxlen: stored as an attribute; not used inside this class.
            doc_maxlen: stored as an attribute; not used inside this class.
            mask_punctuation: if truthy, every punctuation character and its
                first ``bert-base-uncased`` token id are put in
                ``self.skiplist`` so :meth:`mask` drops those tokens.
            dim: size of the projected token embeddings.
            similarity_metric: ``'cosine'`` or ``'l2'`` (see :meth:`score`).
        """
        super().__init__(config)

        self.query_maxlen = query_maxlen
        self.doc_maxlen = doc_maxlen
        self.similarity_metric = similarity_metric
        self.dim = dim

        self.mask_punctuation = mask_punctuation
        # Holds both punctuation strings and their token ids as keys;
        # ``mask`` only ever looks up token ids.
        self.skiplist = {}

        if self.mask_punctuation:
            self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
            self.skiplist = {w: True
                             for symbol in string.punctuation
                             for w in [symbol, self.tokenizer.encode(symbol, add_special_tokens=False)[0]]}

            # TODO: also mask WordPiece continuation tokens (those starting
            # with "##") so that only the first sub-token of each word takes
            # part in matching. Sketch:
            #     for token, token_id in self.tokenizer.get_vocab().items():
            #         if token.startswith("##"):
            #             self.skiplist[token_id] = True

        self.bert = BertModel(config)
        self.linear = nn.Linear(config.hidden_size, dim, bias=False)

        #!@ custom
        self.pseudo_query_indicator = pseudo_query_indicator

        self.init_weights()

    #!@ custom: the upstream ``forward(Q, D)`` had no ``inbatch_negatives`` flag.
    def forward(self, Q, D, inbatch_negatives=False):
        """Encode and score a batch.

        ``Q`` is unpacked into :meth:`query` and ``D`` into :meth:`doc`, so
        ``D`` must provide ``(input_ids, attention_mask, pruning_mask)``.
        ``inbatch_negatives`` is forwarded to :meth:`score`.
        """
        return self.score(self.query(*Q), self.doc(*D), inbatch_negatives=inbatch_negatives)

    def query(self, input_ids, attention_mask):
        """Return L2-normalized (along dim 2) projected query token embeddings."""
        input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE)
        Q = self.bert(input_ids, attention_mask=attention_mask)[0]
        Q = self.linear(Q)

        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def _encode_doc(self, input_ids, attention_mask, pruning_mask):
        """Run BERT + projection on a document batch (shared by the doc encoders).

        Returns:
            ``(input_ids, D, pruning_mask)`` where the id and mask tensors have
            been moved to ``DEVICE`` and ``D`` is the projected, not yet masked
            or normalized, token embedding tensor.
        """
        input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE)
        pruning_mask = pruning_mask.to(DEVICE) #!@ custom

        if self.pseudo_query_indicator: #!@ custom
            # Expose the pruning / pseudo-query marks to BERT as segment ids.
            D = self.bert(input_ids, attention_mask=attention_mask,
                          token_type_ids=pruning_mask.long())[0]
        else: #!@ original
            D = self.bert(input_ids, attention_mask=attention_mask)[0]

        return input_ids, self.linear(D), pruning_mask

    #!@ custom ; Add ``pruning_mask`` parameter
    def doc(self, input_ids, attention_mask, pruning_mask, keep_dims=True):
        """Encode documents, zeroing skipped, padded and pruned tokens.

        A token survives only if :meth:`mask` keeps it AND its
        ``pruning_mask`` entry is non-zero.

        Args:
            keep_dims: if True, return the padded, L2-normalized embedding
                tensor. If False, move it to CPU as float16 and return a list
                holding, per document, only the surviving token embeddings.
        """
        input_ids, D, pruning_mask = self._encode_doc(input_ids, attention_mask, pruning_mask)

        mask = torch.tensor(self.mask(input_ids), device=DEVICE).unsqueeze(2).float()
        mask = mask * pruning_mask.unsqueeze(2).float() #!@ custom

        D = D * mask
        D = torch.nn.functional.normalize(D, p=2, dim=2)

        if not keep_dims:
            D, mask = D.cpu().to(dtype=torch.float16), mask.cpu().bool().squeeze(-1)
            D = [d[m] for d, m in zip(D, mask)]

        return D

    def doc_with_pseudo_query(self, input_ids, attention_mask, pruning_mask):
        """Encode documents and return (all kept tokens, pseudo-query tokens).

        Unlike :meth:`doc`, ``pruning_mask`` does not zero embeddings here; it
        only selects the pseudo-query subset.

        Returns:
            ``(all_token_embs, pseudo_queries_embs)``: two lists with one CPU
            float16 tensor per document. The first keeps every token that
            :meth:`mask` keeps. The second additionally requires a non-zero
            ``pruning_mask`` entry.
        """
        input_ids, D, pruning_mask = self._encode_doc(input_ids, attention_mask, pruning_mask)
        mask = torch.tensor(self.mask(input_ids), device=DEVICE)

        # Zero skipped / padded tokens, then L2-normalize each token vector.
        D = D * mask.unsqueeze(2).float()
        D = torch.nn.functional.normalize(D, p=2, dim=2)

        # Move to CPU and store as float16.
        D = D.cpu().to(dtype=torch.float16)

        mask = mask.cpu().bool()
        all_token_embs = [d[m] for d, m in zip(D, mask)]

        pseudo_query_mask = mask & pruning_mask.cpu().bool() #!@ custom
        pseudo_queries_embs = [d[m] for d, m in zip(D, pseudo_query_mask)]

        return all_token_embs, pseudo_queries_embs

    @staticmethod
    def _l2_maxsim(Q, D):
        """MaxSim relevance with negative squared L2 distance.

        Q and D are broadcast against each other on every leading axis; the
        last two axes are (tokens, dim). Each query token takes its best
        (largest) score over the document tokens, and these are summed.
        """
        QD = -1.0 * ((Q.unsqueeze(-2) - D.unsqueeze(-3)) ** 2).sum(-1)
        return QD.max(-1).values.sum(-1)

    def score(self, Q, D, inbatch_negatives=False):
        """Compute MaxSim relevance between query and document embeddings.

        Shapes as documented for this model's training batches:
            Q: (2 * bsize, query_maxlen, dim), e.g. bsize=3 -> [Q1, Q2, Q3, Q1, Q2, Q3]
            D: (2 * bsize, doc_maxlen, dim),   e.g. bsize=3 -> [P1, P2, P3, N1, N2, N3]

        * ``'cosine'``: dot products between token embeddings, max over doc
          tokens, sum over query tokens. Not available in training mode.
        * ``'l2'`` with ``inbatch_negatives`` and more than one query: scores
          every query against every document and returns a
          ``(n_queries, n_docs)`` matrix.
        * ``'l2'`` otherwise: Q and D are broadcast pairwise along dim 0 (e.g.
          one query vs. ``n_candidates`` docs in re-ranking; see
          ``colbert/evaluation/slow.py: slow_rerank``), returning one score
          per broadcast pair.

        Raises:
            NotImplementedError: cosine metric while ``self.training``.
        """
        if self.similarity_metric == 'cosine':

            if self.training:
                raise NotImplementedError("cosine similarity is not supported in training mode")

            QD = (Q @ D.permute(0, 2, 1)) # (2 * bsize, query_maxlen, doc_maxlen)
            maxsim = QD.max(2).values # (2 * bsize, query_maxlen)
            relevance = maxsim.sum(1) # (2 * bsize)
            return relevance

        assert self.similarity_metric == 'l2', f"unsupported similarity_metric: {self.similarity_metric!r}"

        if inbatch_negatives and Q.size(0) > 1:
            # Q -> (bsize, 1, query_maxlen, dim), D -> (1, (k+1) * bsize, doc_maxlen, dim)
            # result: (bsize, (k+1) * bsize)
            return self._l2_maxsim(Q.unsqueeze(1), D.unsqueeze(0))

        # NOTE(review): with inbatch_negatives=True and a single query (e.g. a
        # training batch of size 1), this pairwise branch returns shape
        # (n_docs,) rather than (1, n_docs). Confirm callers never train with bsize=1.
        # Q: (1 or n_candidates, query_maxlen, dim); D: (n_candidates, doc_maxlen, dim)
        return self._l2_maxsim(Q, D) # (n_candidates,)

    def mask(self, input_ids):
        """Return a nested list of bools, one per token: True means keep.

        A token is dropped if its id is 0 (presumably the padding id) or
        appears in ``self.skiplist``.
        """
        skiplist = self.skiplist
        return [[(x not in skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
