

from transformers import BertTokenizer
import torch
import random 


class Tokenization():
    """
    Word-level wrapper around a HuggingFace ``BertTokenizer``.

    Rewritten for the BERT_realization setup. Example: if the word "good" is
    replaced, the model input looks like

        [CLS] I 'm good at it. [SEP] candidate [SEP] ...

    and the position of "good" is indicated by segmentation/offset indices.
    """

    def __init__(self, lower_case=True, bert_folder=None, MASK_ratio=0.15):
        """
        Args:
            lower_case: lower-case inputs (special tokens excepted) before
                tokenization; when ``bert_folder`` is None it also selects
                'bert-base-uncased' (True) or 'bert-base-cased' (False).
            bert_folder: optional name/path passed to
                ``BertTokenizer.from_pretrained``; takes precedence over
                ``lower_case`` for choosing the vocabulary.
            MASK_ratio: fraction of words selected for masking in
                ``tokenization_with_MASK``.
        """
        # NOTE(review): lower_case defaults to True, so the uncased vocabulary
        # is used and inputs are lower-cased; confirm this matches the casing
        # of the pre-training data.
        if bert_folder is not None:
            self.tokenizer = BertTokenizer.from_pretrained(bert_folder)
        elif lower_case:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        else:
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        self.lower_case = lower_case
        self.cls_idx = self.tokenizer.cls_token_id
        self.sep_idx = self.tokenizer.sep_token_id
        self.pad_idx = self.tokenizer.pad_token_id
        self.mask_idx = self.tokenizer.mask_token_id
        # Maximum number of sub-word pieces kept per word in the flat output
        # of bert_tokenization.
        self.max_word_len = 5

        self.vocab_size = self.tokenizer.vocab_size
        # Masked-LM caps: at most this many words, and at most this many
        # sub-word pieces in total, are masked per sequence.
        self.max_predictions_words = 5
        self.max_predictions_bpes = 20
        self.masked_lm_prob = MASK_ratio

    def __call__(self, sents):
        """
        Tokenize a batch of sentences with the underlying tokenizer.

        Args:
            sents: list of strings.

        Returns:
            The result of ``self.tokenizer(sents, return_tensors="pt",
            padding=True)``.
        """
        if self.lower_case:
            # Lower-case the text, then restore the special tokens so the
            # tokenizer still recognizes them. Built in one pass: the previous
            # version re-copied the list on every iteration.
            sents = [
                sent.lower()
                .replace("[mask]", "[MASK]")
                .replace("[sep]", "[SEP]")
                .replace("[cls]", "[CLS]")
                for sent in sents
            ]
        return self.tokenizer(sents, return_tensors="pt", padding=True)

    def bert_tokenization(self, word_lst, target_word=-1, return_flat=True):
        """
        Tokenize a pre-split list of words and wrap it in [CLS] ... [SEP].

        Args:
            word_lst: list of words (strings).
            target_word: index into ``word_lst`` of the word whose sub-word
                positions should be reported; a negative value disables this.
            return_flat: if False, return per-word nested lists; if True,
                return flat lists in which every word is truncated to
                ``self.max_word_len`` pieces.

        Returns:
            return_flat=False:
                {"input_ids": [[cls], ids(word_0), ..., [sep]],
                 "token_type_ids": same nesting, all zeros}
                plus "target_word_offsets" (positions in the concatenated,
                untruncated sequence) when target_word >= 0.
            return_flat=True:
                {"input_ids": flat ids, "offsets": flat token-type ids (zeros)}
                plus "target_word_offset" (positions in the flat, truncated
                sequence) when target_word >= 0.
            The target range is [] when target_word is not a valid index.
        """
        if self.lower_case:
            # Keep special tokens upper-case so they map to their vocab ids.
            word_lst = [w if w in ("[MASK]", "[SEP]", "[CLS]") else w.lower()
                        for w in word_lst]

        # The flat output truncates each word, so the target offsets must be
        # measured on the truncated pieces; slicing with None keeps everything.
        keep = self.max_word_len if return_flat else None

        input_ids, token_type_ids = [[self.cls_idx]], [[0]]
        target_word_range = []
        for idx, word in enumerate(word_lst):
            word_ids = self.word_encoder(word)
            if idx == target_word:
                start = sum(len(ids[:keep]) for ids in input_ids)
                target_word_range = list(range(start, start + len(word_ids[:keep])))
            input_ids.append(word_ids)
            token_type_ids.append([0] * len(word_ids))

        # add SEP token
        input_ids.append([self.sep_idx])
        token_type_ids.append([0])

        if not return_flat:
            out_info = {"input_ids": input_ids, "token_type_ids": token_type_ids, }
            if target_word >= 0:
                out_info['target_word_offsets'] = target_word_range
            return out_info

        input_ids_flat = [i for ids in input_ids for i in ids[:self.max_word_len]]
        token_type_ids_flat = [t for types in token_type_ids
                               for t in types[:self.max_word_len]]

        out_info = {"input_ids": input_ids_flat, "offsets": token_type_ids_flat, }
        if target_word >= 0:
            out_info['target_word_offset'] = target_word_range
        return out_info

    def generate_mask_token(self, word_ids, ):
        """
        Choose the replacement id for one sub-word piece selected for masking.

        With probability 0.8 the [MASK] id is returned, 0.1 the original id
        ``word_ids`` is kept, and 0.1 a uniformly random id in
        [0, vocab_size) is returned.
        """
        if random.random() < 0.8:
            return self.mask_idx
        if random.random() < 0.5:
            return word_ids
        return random.randint(0, self.vocab_size - 1)

    def tokenization_with_MASK(self, word_lst, target_word=-1):
        """
        Tokenize ``word_lst`` and apply whole-word masked-LM corruption.

        ``min(max_predictions_words, max(1, round(n_words * masked_lm_prob)))``
        words are sampled uniformly ([CLS]/[SEP] excluded). Every sub-word
        piece of a sampled word is replaced via ``generate_mask_token``,
        provided the running number of masked pieces stays within
        ``max_predictions_bpes``; a word that would exceed it is left intact.

        Returns:
            The non-flat dict from ``bert_tokenization`` with "input_ids"
            replaced by the flat (untruncated) corrupted id list, plus
            "mask_labels" (original ids of the masked pieces) and
            "mask_offsets" (their positions in "input_ids").
            "token_type_ids" is left nested per word.
        """
        out_info = self.bert_tokenization(word_lst, target_word=target_word, return_flat=False)
        nested_ids = out_info['input_ids']
        # Number of real words: everything except [CLS] and [SEP].
        tokens_len = len(nested_ids) - 2
        num_to_predict = min(self.max_predictions_words,
                             max(1, int(round(tokens_len * self.masked_lm_prob))))

        # Word positions 1..tokens_len, i.e. skipping [CLS] and [SEP].
        word_idx = list(range(1, tokens_len + 1))
        random.shuffle(word_idx)
        masked_words = set(word_idx[:num_to_predict])

        mask_offsets, input_ids, label_ids = [], [], []
        for idx, word_ids in enumerate(nested_ids):
            if idx in masked_words and len(mask_offsets) + len(word_ids) <= self.max_predictions_bpes:
                for sub_id in word_ids:
                    mask_offsets.append(len(input_ids))
                    input_ids.append(self.generate_mask_token(sub_id))
                    label_ids.append(sub_id)
            else:
                input_ids.extend(word_ids)

        out_info["input_ids"] = input_ids
        out_info["mask_labels"] = label_ids
        out_info["mask_offsets"] = mask_offsets
        return out_info

    def word_encoder(self, word):
        """Encode one word into its sub-word ids, without special tokens."""
        ## good -> list
        return self.tokenizer.encode(word, add_special_tokens=False)

    def ids_decoder(self, batch_ids):
        """Decode a batch of id lists back to strings, e.g. [[101]] -> ['[CLS]']."""
        return self.tokenizer.batch_decode(batch_ids)




if __name__ == "__main__":
    # Smoke test: mask one sentence, then batch-tokenize a set of definitions.
    example = {"answer": "Abandon (a person, cause, or organization) in a way considered disloyal or treacherous.",
                "candidates": ["A dry, barren area of land, especially one covered with sand, that is characteristically desolate, waterless, and without vegetation.", "A commotion; a fuss."],
                "word": "desert",
                "sent": "we feel our public representatives have deserted us",
                "word_idx": 6,
                }

    tokenizer = Tokenization(lower_case=True)
    out_info = tokenizer.tokenization_with_MASK(example['sent'].strip().split(' '),
                                                target_word=example['word_idx'])
    print(out_info)
    # Show the corrupted id sequence as text.
    print(tokenizer.ids_decoder([out_info['input_ids']]))

    defs = [example['answer']] + example['candidates']
    sent_info = tokenizer(defs)
    print(sent_info)