import os 


"""Tweaked AllenNLP dataset reader."""
import logging
import re
from random import random
from typing import Dict, List
import gc
from overrides import overrides
from torch.utils import data
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import json
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence

import pickle
import random
 

class TextDataset(data.Dataset):
    """
        In-memory dataset of pre-tokenized examples.

        Examples are kept in a plain list. Tokenization is expected to have
        happened during pre-processing, so each example already carries token
        ids (e.g. under 'sent_input_ids'). If the raw input is text, run the
        tokenization step in tokenization.py first.
    """

    def __init__(self,  json_fp = None,
                        pick_fp = None,
                        json_list_fp = None,
                        max_len = 128,
                         ):
        """
            Load the examples from one source and drop the over-long ones.

            The sources are checked in the order pick_fp, json_fp,
            json_list_fp; the first one that is not None is used.

            Args:
                json_fp: path to a JSON file holding the whole list of examples.
                pick_fp: path to a pickle file holding the list of examples.
                json_list_fp: path to a JSON-lines file, one example per line.
                max_len: examples whose 'sent_input_ids' has max_len or more
                    entries are discarded.

            Raises:
                ValueError: if no source path is given.
        """
        if pick_fp is not None:
            # NOTE(security): unpickling can execute arbitrary code; only
            # load pickle files that come from a trusted source.
            with open(pick_fp, "rb") as fh:
                dataset = pickle.load(fh)
        elif json_fp is not None:
            # json.load (not json.loads) is the call that reads a file object.
            with open(json_fp, "r", encoding="utf-8") as fh:
                dataset = json.load(fh)
        elif json_list_fp is not None:
            # Blank lines (e.g. a trailing newline) are skipped, not parsed.
            with open(json_list_fp, "r", encoding="utf-8") as fh:
                dataset = [json.loads(line) for line in fh if line.strip()]
        else:
            raise ValueError(
                "one of pick_fp, json_fp or json_list_fp must be provided"
            )

        keys = ['sent_input_ids']
        self.dataset = self.data_filter( dataset, keys, max_len)
        print( "data size: ",  len( self.dataset )  )

    def data_filter(self, data, keys, max_len):
        """
            Return the examples whose fields named in `keys` are all shorter
            than max_len, and print how many examples were dropped.
        """
        out_data = [
            info for info in data
            if not any( len(info[k]) >= max_len for k in keys )
        ]
        print("filter_data size: ", len(data) - len(out_data))
        return out_data

    def __len__(self):
        """Number of examples kept after filtering."""
        return len( self.dataset )

    def __getitem__(self, idx):
        """Return the stored example at position idx (not a copy)."""
        return self.dataset[idx]


"""
def Padding(array_index, seq_len, PAD=0):
    out = np.pad(array_index,(0, max(seq_len-len(array_index),0) ),'constant',constant_values=(0,PAD))
    lst + [PAD] * (seq_len - len(lst) )
    return out 
"""

def pad_list( lst, seq_len, pad_ids=0):
    """
        Return a new list of exactly seq_len items built from lst.

        Shorter inputs are right-padded with pad_ids; longer inputs are
        truncated to their first seq_len items. The input is never modified,
        so it is safe to pass lists that are owned by a dataset.
    """
    missing = seq_len - len(lst)
    if missing >= 0:
        # Build a fresh list instead of extending the caller's list in place.
        return list(lst) + [pad_ids] * missing
    return list(lst[:seq_len])


"""
    target input format:
        sent_input_ids, ## (bs, seq_len)
        target_word_offsets, ## (bs)
        def_input_ids,
            ## (batch_size, num_sent, seq_len)
            ## num_sent_id = 0: answer; num_sent_id = 1: hard_example;
        mlm_labels: (bs, 15),
        mlm_offsets: (bs, 15), 
"""


class MyCollator(object):
    """
        Collate a batch of example dicts: pad the id lists dynamically to the
        longest item in the batch (capped) and return torch.long tensors.
    """
    def __init__(self,  PAD_ids=0, maxlen=256, mask_clean_input=0, use_roberta=False ):
        """
            Args:
                PAD_ids: stored, but then overwritten by the value selected
                    through use_roberta (kept for interface compatibility).
                maxlen: stored on the instance; not used by __call__.
                mask_clean_input: accepted but unused.
                use_roberta: selects the mask/pad ids (50264/1 when True,
                    103/0 otherwise) -- presumably the RoBERTa and BERT
                    vocabulary ids; verify against the tokenizer in use.
        """
        self.PAD_ids = PAD_ids
        self.maxlen = maxlen

        # mlm_labels / mlm_offsets are always padded or cut to this length.
        self.max_mlm_len = 15
        self.max_sent_len = 128
        ## should filter sent len in reading, otherwise offsets may point
        ## past the truncated sequence (out-of-range problem)

        # probability with which each token of the clean text is masked
        self.mask_clean_ratio=0.1
        # upper bound on the number of masked tokens per clean sentence
        self.max_clean_mask_count = 8

        if use_roberta:
            self.mask_id = 50264
            self.PAD_ids = 1
        else:
            self.mask_id = 103
            self.PAD_ids = 0

    def __call__(self, batch ):
        """
            Build the padded tensors for one batch.

            Each item of batch must provide 'sent_input_ids',
            'target_word_offsets', 'def_input_ids' (two id lists),
            'mlm_labels' and 'mlm_offsets'. The items themselves are left
            untouched: every list is copied before it is padded.

            Returns a dict of torch.long tensors:
                mlm_input_ids:       (bs, sent_len)
                target_word_offsets: (bs)
                def_input_ids:       (bs, 2, def_len)
                mlm_labels:          (bs, max_mlm_len)
                mlm_offsets:         (bs, max_mlm_len)
                clean_input_ids:     (bs, sent_len)
            sent_len / def_len are the longest lengths found in the batch,
            capped at max_sent_len.
        """
        #### should be re-written for each task, hard code here
        max_sent_len, max_def_len = 0, 0
        max_mlm_len = self.max_mlm_len
        for info in batch:
            max_sent_len = max(max_sent_len, len(info['sent_input_ids']) )
            max_def_len = max(max_def_len, len(info['def_input_ids'][0]), len(info['def_input_ids'][1]) )
        max_sent_len = min(max_sent_len, self.max_sent_len)
        max_def_len = min(max_def_len, self.max_sent_len)

        out_batch = dict()
        out_batch['mlm_input_ids'] = [pad_list( info['sent_input_ids'][:], max_sent_len, self.PAD_ids ) for info in batch ]
        out_batch['target_word_offsets'] = [ info['target_word_offsets'][0]  for info in batch ]
        # Copy both definitions before padding: padding the stored lists in
        # place would leave the dataset padded to this batch's length and
        # inflate max_def_len for every later batch.
        out_batch['def_input_ids'] = [
            [
                pad_list( list(info['def_input_ids'][0]), max_def_len, self.PAD_ids ),
                pad_list( list(info['def_input_ids'][1]), max_def_len, self.PAD_ids ),
            ]
            for info in batch
        ]
        out_batch['mlm_labels'] = [pad_list( info['mlm_labels'][:], max_mlm_len, self.PAD_ids ) for info in batch ]
        out_batch['mlm_offsets'] = [pad_list( info['mlm_offsets'][:], max_mlm_len, self.PAD_ids ) for info in batch ]

        batch_clean_input_ids = []
        for info in batch:
            # Rebuild the clean sentence by writing each label back at its offset.
            clean_input_ids = info['sent_input_ids'][:]
            for mlm_offset, mlm_label in zip(info['mlm_offsets'], info['mlm_labels']):
                clean_input_ids[mlm_offset] = mlm_label

            # Randomly pick positions to mask, never the target word itself.
            # random.random() is drawn for every position, in order.
            target_offset = info['target_word_offsets'][0]
            clean_input_mask_offsets = [
                i for i in range(len(clean_input_ids))
                if random.random() < self.mask_clean_ratio and i != target_offset
            ]
            for mlm_offset in clean_input_mask_offsets[:self.max_clean_mask_count] :
                clean_input_ids[mlm_offset] = self.mask_id

            batch_clean_input_ids.append( pad_list(clean_input_ids, max_sent_len, self.PAD_ids ) )
        out_batch['clean_input_ids'] = batch_clean_input_ids

        for k in out_batch:
            out_batch[k] = torch.tensor(out_batch[k],dtype=torch.long)
        return out_batch
 
               


# Script entry point: intentionally a no-op, this module only provides definitions.
if __name__=="__main__":
    pass