import torch

from pruner.modeling.colbert import ColBERTPruner
from pruner.modeling.tokenization import DocTokenizer
from pruner.utils.amp import MixedPrecisionManager
from pruner.parameters import DEVICE


class ModelInference:
    """Inference-time wrapper around an evaluation-mode ``ColBERTPruner``.

    Bundles the pruner with a ``DocTokenizer`` (built from
    ``pruner.doc_maxlen``) and a ``MixedPrecisionManager`` so documents can be
    encoded and per-token pruner scores computed without gradient tracking.
    """

    # Score assigned to tokens found in ``pruner.skiplist`` when the pruner
    # has ``mask_punctuation`` enabled.
    _IGNORED_SCORE = 0.0

    def __init__(self, pruner: "ColBERTPruner", amp: bool = False):
        """Store the pruner and build its tokenizer and mixed-precision manager.

        Args:
            pruner: Model to run; it must not be in training mode.
            amp: Forwarded unchanged to ``MixedPrecisionManager``.

        Raises:
            AssertionError: If ``pruner.training`` is not ``False``.
        """
        # Raised explicitly instead of via ``assert`` so the precondition is
        # still enforced under ``python -O``; the exception type is unchanged.
        if pruner.training is not False:
            raise AssertionError(
                "ModelInference requires a pruner in eval mode "
                "(pruner.training must be False)"
            )

        self.pruner = pruner
        self.doc_tokenizer = DocTokenizer(pruner.doc_maxlen)

        self.amp_manager = MixedPrecisionManager(amp)

    def doc(self, *args, to_cpu=False, **kw_args):
        """Run ``pruner.doc`` without gradients inside the AMP context.

        Args:
            *args: Positional arguments forwarded to ``pruner.doc``.
            to_cpu: When true, the result is moved to the CPU before returning.
            **kw_args: Keyword arguments forwarded to ``pruner.doc``.

        Returns:
            Whatever ``pruner.doc`` returns (moved to CPU if requested).
        """
        with torch.no_grad(), self.amp_manager.context():
            features = self.pruner.doc(*args, **kw_args)
        return features.cpu() if to_cpu else features

    @staticmethod
    def _unpad_to_lists(values, mask):
        """Return one Python list per row of ``values``, keeping only the
        positions where ``mask`` is non-zero."""
        # Transfer the whole batch to the CPU once instead of once per row.
        cpu_values = values.detach().cpu()
        keep = mask.bool().cpu()
        return [row[selected].tolist() for row, selected in zip(cpu_values, keep)]

    def scoreFromText(self, docs, return_tokembs=False):
        """Compute a pruner score for every non-padding token of each document.

        Args:
            docs: Documents, passed as-is to ``doc_tokenizer.tensorize``.
            return_tokembs: When true, additionally return token embeddings.

        Returns:
            ``(scores_2dlist, tokens_2dlist)`` where ``scores_2dlist[i][j]`` is
            the normalized score (a float in [0, 1]) of token ``j`` of document
            ``i`` and ``tokens_2dlist[i][j]`` is the matching token string.
            Positions where the attention mask is zero are dropped from both.
            If ``pruner.mask_punctuation`` is set, tokens contained in
            ``pruner.skiplist`` receive a score of ``0.0``.
            With ``return_tokembs=True`` a third element is appended: the
            (still padded) feature tensor truncated to its first ``pruner.dim``
            channels.
        """
        input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
        # Presumably (B, L, F) with F >= pruner.dim, since the features are
        # sliced to ``:pruner.dim`` below -- TODO confirm against pruner.doc.
        token_features = self.doc(input_ids, attention_mask)

        with torch.no_grad():
            # Expected to be (B, L, 1): a single logit per token.
            logits = self.pruner.predict_scores(token_features)

            # Normalize with a two-way softmax against a constant zero logit
            # (mathematically the logistic sigmoid of the logit).
            paired = torch.cat((torch.zeros_like(logits), logits), dim=-1)  # B, L, 2
            probabilities = torch.nn.functional.softmax(paired, dim=-1)[:, :, 1]  # B, L

        # Drop the [PAD] positions and convert to plain Python lists.
        # scores_2dlist: List[List[float]], input_ids_2dlist: List[List[int]]
        scores_2dlist = self._unpad_to_lists(probabilities, attention_mask)
        input_ids_2dlist = self._unpad_to_lists(input_ids, attention_mask)

        # tokens_2dlist: List[List[str]], special tokens are kept.
        convert = self.doc_tokenizer.tok.convert_ids_to_tokens
        tokens_2dlist = [
            convert(ids=ids, skip_special_tokens=False) for ids in input_ids_2dlist
        ]

        # Sanity check: every id must have produced exactly one token.
        for ids, tokens in zip(input_ids_2dlist, tokens_2dlist):
            assert len(ids) == len(tokens)

        if self.pruner.mask_punctuation:
            # Overwrite the scores of skip-listed tokens.
            skiplist = self.pruner.skiplist
            ignored = self._IGNORED_SCORE
            scores_2dlist = [
                [
                    ignored if token in skiplist else score
                    for score, token in zip(doc_scores, doc_tokens)
                ]
                for doc_scores, doc_tokens in zip(scores_2dlist, tokens_2dlist)
            ]

        if not return_tokembs:
            return scores_2dlist, tokens_2dlist

        # Keep only the first ``pruner.dim`` channels of the features; per the
        # original author's note these are the final features produced by the
        # linear layer followed by l2-normalization.
        token_embeddings = token_features[:, :, :self.pruner.dim]

        return scores_2dlist, tokens_2dlist, token_embeddings
        
