import torch

from transformers import BertTokenizerFast
from colbert.modeling.tokenization.utils import _split_into_batches, _sort_by_length

from colbert.modeling.tokenization.utils import _split_into_batches_doc #!@ custom

class DocTokenizer:
    """Tokenizer for the document side of ColBERT.

    Wraps the ``bert-base-uncased`` fast tokenizer and adds a document marker
    ``[D]`` whose id is borrowed from the ``[unused1]`` vocabulary entry.
    """

    def __init__(self, doc_maxlen):
        """Load the BERT tokenizer and cache special-token strings and ids.

        Args:
            doc_maxlen: passed as ``max_length`` to the tokenizer in
                ``tensorize``.
        """
        self.tok = BertTokenizerFast.from_pretrained('bert-base-uncased')
        self.doc_maxlen = doc_maxlen

        # '[D]' is only the display string; the id reuses the '[unused1]' slot.
        self.D_marker_token, self.D_marker_token_id = '[D]', self.tok.convert_tokens_to_ids('[unused1]')
        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id

        # Sanity check that the vocabulary maps '[unused1]' to the expected id.
        assert self.D_marker_token_id == 2

    def tokenize(self, batch_text, add_special_tokens=False):
        """Split each text into token strings.

        Args:
            batch_text: list or tuple of strings.
            add_special_tokens: if True, wrap each sequence as
                ``[CLS] [D] ... [SEP]``.

        Returns:
            A list holding one list of token strings per input text.
        """
        assert isinstance(batch_text, (list, tuple)), (type(batch_text))

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        prefix, suffix = [self.cls_token, self.D_marker_token], [self.sep_token]
        return [prefix + lst + suffix for lst in tokens]

    def encode(self, batch_text, add_special_tokens=False):
        """Convert each text into a list of token ids.

        Args:
            batch_text: list or tuple of strings.
            add_special_tokens: if True, wrap each id sequence with the ids of
                ``[CLS] [D]`` in front and ``[SEP]`` at the end.

        Returns:
            A list holding one list of token ids per input text.
        """
        assert isinstance(batch_text, (list, tuple)), (type(batch_text))

        ids = self.tok(batch_text, add_special_tokens=False)['input_ids']

        if not add_special_tokens:
            return ids

        prefix, suffix = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id]
        return [prefix + lst + suffix for lst in ids]

    @staticmethod
    def _build_pruning_mask(mask, positions):
        """Return a 0/1 tensor shaped like ``mask`` marking the tokens to keep.

        Args:
            mask: 2-D attention mask produced by the tokenizer
                (rows = documents).
            positions: one sequence of token positions per row of ``mask``.

        Returns:
            A copy of ``mask`` when no row carries any positions (pruning
            disabled). Otherwise, a zero tensor with ones at the listed
            positions; rows with an empty list keep no tokens.
        """
        # Decide for the whole batch, not just the first row, so one empty
        # entry cannot silently switch pruning off for every document.
        if not any(len(keep) > 0 for keep in positions):
            return mask.clone()

        pruning_mask = torch.zeros_like(mask)
        for row, keep in enumerate(positions):
            if len(keep) > 0:
                pruning_mask[row, keep] = 1
        return pruning_mask

    def tensorize(self, batch_tuples, bsize=None):
        """Tokenize passages into padded id/mask tensors plus a pruning mask.

        Args:
            batch_tuples: list or tuple of ``(passage_text, positions)``
                pairs. ``positions`` lists the token positions to keep for
                pruned document indexing; pass empty lists everywhere to
                disable pruning.
            bsize: if truthy, the tensors are re-ordered and split into
                sub-batches via ``_sort_by_length`` and
                ``_split_into_batches_doc``.

        Returns:
            ``(ids, mask, pruning_mask)`` when ``bsize`` is falsy, otherwise
            ``(batches, reverse_indices)`` as returned by those helpers.
        """
        assert isinstance(batch_tuples, (list, tuple)), (type(batch_tuples))

        # Prepend a '.' placeholder that is expected to land at position 1
        # (right after [CLS]); it is overwritten with the [D] marker id below.
        batch_text = ['. ' + x[0] for x in batch_tuples]
        obj = self.tok(batch_text, padding='longest', truncation='longest_first',
                       return_tensors='pt', max_length=self.doc_maxlen)

        ids, mask = obj['input_ids'], obj['attention_mask']
        ids[:, 1] = self.D_marker_token_id

        pruning_mask = self._build_pruning_mask(mask, [x[1] for x in batch_tuples])

        if bsize:
            ids, mask, pruning_mask, reverse_indices = _sort_by_length(ids, mask, pruning_mask, bsize)
            batches = _split_into_batches_doc(ids, mask, pruning_mask, bsize=bsize)
            return batches, reverse_indices

        return ids, mask, pruning_mask
