# coding: utf-8

import os
import json
import pickle
import torch
import constant as C
import numpy as np
from torch_geometric.data import Data, InMemoryDataset
from allennlp.modules.elmo import batch_to_ids

# Device on which every tensor built in this module is allocated:
# CUDA when torch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class GraphDataset(InMemoryDataset):
    ''' Training dataset that samples a sentence-centred subgraph on demand.

    The pickled graph info must contain the sparse matrices 'lbl_adj',
    'sent_adj' and 'interaction'. Node indices are remapped so that labels
    occupy [0, #labels) and sentences occupy [#labels, #labels + #sentences).

    Args:
        root: dataset root directory, forwarded to InMemoryDataset.
        sample_n: dict whose 'sent' entry caps the number of neighbouring
            sentences kept per target. Defaults to {'sent': 100}.
        graph_info_fn: path of the pickled graph info read by process().
        transform: forwarded to InMemoryDataset.
        pre_transform: forwarded to InMemoryDataset.
        partial_n: if truthy, the length reported by __len__.
        interaction_keep_prob: if truthy, probability of keeping the
            sentence->label edges of a sample (see get()).
    '''

    def __init__(self, root, sample_n=None,
                 graph_info_fn='tmp/graph_info.pkl',
                 transform=None, pre_transform=None, partial_n=None, interaction_keep_prob=None):
        self.interaction_keep_prob = interaction_keep_prob
        self.partial_n = partial_n
        # Build the default per instance; a dict literal in the signature
        # would be shared by every instance of the class.
        self.sample_n = {'sent': 100} if sample_n is None else sample_n

        # Must be set before the parent constructor runs: it may call process().
        self.graph_info_fn = graph_info_fn
        super(GraphDataset, self).__init__(root, transform, pre_transform)

        # Load the processed graph info.
        with open(self.processed_paths[0], 'rb') as f:
            self.graph_info = pickle.load(f)

        # Row slicing is used in get(), so keep both matrices in CSR format.
        self.graph_info['sent_adj'] = self.graph_info['sent_adj'].tocsr()
        self.graph_info['interaction'] = self.graph_info['interaction'].tocsr()

        # IMPORTANT: remap the indices, so that the indices of labels are
        # 0 -> #labels, and the indices of samples are
        # #labels -> #labels+#samples.
        self.n_lbl = self.graph_info['lbl_adj'].shape[0]
        self.n_sent = self.graph_info['sent_adj'].shape[0]
        self.sent_idx_bias = self.n_lbl

    def __len__(self):
        ''' Return partial_n when it is truthy, else the number of sentences. '''
        if self.partial_n:
            return self.partial_n
        return self.n_sent

    def get(self, idx):
        ''' Sample the subgraph around one target sentence.

        Args:
            idx: index to the target sample (a row of 'sent_adj').
        Returns:
            A torch_geometric.data.Data object holding:
              edge_index   - (2, #edges) edges in local node numbering,
                             neighbour/label -> target, no reverse edges;
              embeds_index - global node index of every local node;
              from_mask / to_mask - True where the edge endpoint is a label
                             (global index < #labels);
              target_idx   - global node index of the target sentence;
              target_lbls  - label indices of the target, listed even when
                             the label edges themselves are dropped.
        '''
        max_sent = self.sample_n['sent']
        target_idx = idx + self.sent_idx_bias

        # Neighbouring sentences: the stored entries of the target's row.
        sent_info = self.graph_info['sent_adj'].getrow(idx).tocoo()
        sent_from = sent_info.col
        if sent_from.size > max_sent:
            # Keep the neighbours with the largest edge weights (tf-idf
            # similarity according to the original comments).
            top_inds = np.argpartition(sent_info.data, -max_sent)[-max_sent:]
            sent_from = sent_from[top_inds]
        sent_from += self.sent_idx_bias
        sent_to = np.empty_like(sent_from)
        sent_to.fill(target_idx)

        # Target sentence's interactions with labels (label -> target edges).
        interaction_info = self.graph_info['interaction'].getrow(idx).tocoo()
        inter_from = interaction_info.col
        inter_to = np.empty_like(inter_from)
        inter_to.fill(target_idx)
        target_lbls = list(inter_from)

        # Randomly drop all label edges of this sample. Only done when the
        # target still has sentence neighbours.
        if self.interaction_keep_prob and sent_to.size > 0:
            if not np.random.binomial(1, self.interaction_keep_prob):
                # Empty slices keep the integer dtype; plain [] would make
                # np.concatenate promote every node index to float64.
                inter_from = inter_from[:0]
                inter_to = inter_to[:0]

        # Concatenate, and NO reverse edges.
        from_final = np.concatenate([sent_from, inter_from])
        to_final = np.concatenate([sent_to, inter_to])

        # is_label masks
        from_mask = from_final < self.sent_idx_bias
        to_mask = to_final < self.sent_idx_bias

        # Remap the global node indices to 0..n-1.
        inds = list(set(from_final).union(set(to_final)))
        inds_map = {node: pos for pos, node in enumerate(inds)}
        from_final = [inds_map[node] for node in from_final]
        to_final = [inds_map[node] for node in to_final]

        # Build the tensors.
        edge_index = torch.tensor([from_final, to_final],
                                  device=device, dtype=torch.long)
        embeds_index = torch.tensor(inds, dtype=torch.long, device=device)
        from_mask = torch.tensor(from_mask, device=device)
        to_mask = torch.tensor(to_mask, device=device)

        return Data(edge_index=edge_index, embeds_index=embeds_index,
                    from_mask=from_mask, to_mask=to_mask, num_nodes=len(inds),
                    target_idx=target_idx, target_lbls=target_lbls)

    def sample_inds(self, from_len, sample_num):
        ''' Draw sample_num positions uniformly from range(from_len).

        Sampling is done with replacement only when from_len < sample_num.
        '''
        replace = from_len < sample_num
        return np.random.choice(np.arange(from_len), sample_num, replace)

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return ['processed_dataset.pth']

    def download(self):
        pass

    def process(self):
        ''' Simply copy the graph info to processed dir.
        '''
        # Load preprocessed graph info (in scipy.sparse)
        with open(self.graph_info_fn, 'rb') as f:
            graph_info = pickle.load(f)

        self.sent_idx_bias = self.n_lbl = graph_info['lbl_adj'].shape[0]
        self.n_sent = graph_info['sent_adj'].shape[0]

        # Store in CSR so later loads are already in row-sliceable form.
        graph_info['sent_adj'] = graph_info['sent_adj'].tocsr()
        graph_info['interaction'] = graph_info['interaction'].tocsr()

        with open(self.processed_paths[0], 'wb') as f:
            pickle.dump(graph_info, f)
        


class LabelGraphDataset(InMemoryDataset):
    ''' Training dataset that samples a label-centred subgraph on demand.

    The pickled graph info must contain the sparse matrices 'lbl_adj',
    'sent_adj' and 'interaction' (sentences as rows, labels as columns).
    Labels occupy node indices [0, #labels); sentences are shifted by
    #labels.

    Args:
        root: dataset root directory, forwarded to InMemoryDataset.
        sample_n: dict with 'neighbor_sent' (number of sentences sampled per
            label) and 'lbl' (number of neighbouring labels sampled).
            Defaults to {'neighbor_sent': 10, 'lbl': 100}.
        graph_info_fn: path of the pickled graph info read by process().
        transform: forwarded to InMemoryDataset.
        pre_transform: forwarded to InMemoryDataset.
        partial_n: if truthy, the length reported by __len__.
        interaction_keep_prob: if truthy, probability of keeping the
            sentence->label edges of a sample (see get()).
    '''

    def __init__(self, root, sample_n=None,
                 graph_info_fn='tmp/graph_info.pkl',
                 transform=None, pre_transform=None, partial_n=None, interaction_keep_prob=None):
        self.interaction_keep_prob = interaction_keep_prob
        self.partial_n = partial_n
        # Build the default per instance instead of sharing one dict object.
        if sample_n is None:
            sample_n = {'neighbor_sent': 10, 'lbl': 100}
        self.sample_n = sample_n

        # Must be set before the parent constructor runs: it may call process().
        self.graph_info_fn = graph_info_fn
        super(LabelGraphDataset, self).__init__(root, transform, pre_transform)

        # Load the processed graph info.
        with open(self.processed_paths[0], 'rb') as f:
            self.graph_info = pickle.load(f)

        # IMPORTANT: remap the indices, so that the indices of labels are
        # 0 -> #labels, and the indices of samples are
        # #labels -> #labels+#samples.
        self.n_lbl = self.graph_info['lbl_adj'].shape[0]
        self.n_sent = self.graph_info['sent_adj'].shape[0]
        self.sent_idx_bias = self.n_lbl

        # get() slices 'interaction' by column and 'lbl_adj' by row.
        self.graph_info['interaction'] = self.graph_info['interaction'].tocsc()
        self.graph_info['lbl_adj'] = self.graph_info['lbl_adj'].tocsr()

    def __len__(self):
        ''' Return partial_n when it is truthy, else the number of labels. '''
        if self.partial_n:
            return self.partial_n
        return self.n_lbl

    def get(self, idx):
        ''' Sample the subgraph around one target label.

        Args:
            idx: index to the target label.
        Returns:
            A torch_geometric.data.Data object holding:
              edge_index   - (2, #edges) edges in local node numbering,
                             sentence/label -> target, no reverse edges;
              embeds_index - global node index of every local node;
              from_mask / to_mask - True where the edge endpoint is a label
                             (global index < #labels);
              target_idx   - index of the target label;
              target_lbls  - NOT labels here: the slot is reused to carry
                             the (biased) indices of the sampled sentences
                             connected to the label. It is an empty array
                             when those edges were dropped or none exist.
        '''
        target_idx = idx

        # Sentences interacting with the target label
        # (inter_from is sents, inter_to is the target label).
        interaction_info = self.graph_info['interaction'].getcol(idx).tocoo()
        inter_from = interaction_info.row

        # Labels adjacent to the target label.
        lbl_info = self.graph_info['lbl_adj'].getrow(idx).tocoo()
        lbl_from = lbl_info.col

        # Re-sample the neighbouring sentences. Skipped for a label without
        # any sentence: np.random.choice raises on an empty population.
        if inter_from.size > 0:
            inter_inds = self.sample_inds(inter_from.size, self.sample_n['neighbor_sent'])
            inter_from = inter_from[inter_inds]
        inter_to = np.empty_like(inter_from)
        inter_to.fill(idx)

        if lbl_from.size > 0:
            lbl_inds = self.sample_inds(lbl_from.size, self.sample_n['lbl'])
            lbl_from = lbl_from[lbl_inds]
        lbl_to = np.empty_like(lbl_from)
        lbl_to.fill(idx)

        # Add index bias to sentence nodes.
        inter_from += self.sent_idx_bias

        # Randomly drop all sentence->label edges of this sample.
        if self.interaction_keep_prob:
            if not np.random.binomial(1, self.interaction_keep_prob):
                # Empty slices keep the integer dtype; plain [] would make
                # np.concatenate promote every node index to float64.
                inter_from = inter_from[:0]
                inter_to = inter_to[:0]

        # Concatenate, and NO reverse edges.
        from_final = np.concatenate([inter_from, lbl_from])
        to_final = np.concatenate([inter_to, lbl_to])

        # is_label masks
        from_mask = from_final < self.sent_idx_bias
        to_mask = to_final < self.sent_idx_bias

        # Remap the global node indices to 0..n-1.
        inds = list(set(from_final).union(set(to_final)))
        inds_map = {node: pos for pos, node in enumerate(inds)}
        from_final = [inds_map[node] for node in from_final]
        to_final = [inds_map[node] for node in to_final]

        # Build the tensors.
        edge_index = torch.tensor([from_final, to_final],
                                  device=device, dtype=torch.long)
        embeds_index = torch.tensor(inds, dtype=torch.long, device=device)
        from_mask = torch.tensor(from_mask, device=device)
        to_mask = torch.tensor(to_mask, device=device)

        # Reuse the target_lbls slot to pass the indices of the sentences
        # related to the label.
        return Data(edge_index=edge_index, embeds_index=embeds_index,
                    from_mask=from_mask, to_mask=to_mask, num_nodes=len(inds),
                    target_idx=target_idx, target_lbls=inter_from)

    def sample_inds(self, from_len, sample_num):
        ''' Draw sample_num positions uniformly from range(from_len).

        Sampling is done with replacement only when from_len < sample_num.
        '''
        replace = from_len < sample_num
        return np.random.choice(np.arange(from_len), sample_num, replace)

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return ['processed_dataset.pth']

    def download(self):
        pass

    def process(self):
        ''' Simply copy the graph info to processed dir.
        '''
        # Load preprocessed graph info (in scipy.sparse)
        with open(self.graph_info_fn, 'rb') as f:
            graph_info = pickle.load(f)

        # Convert the freshly loaded dict. self.graph_info does not exist
        # yet, because process() runs inside the parent constructor.
        graph_info['interaction'] = graph_info['interaction'].tocsc()
        graph_info['lbl_adj'] = graph_info['lbl_adj'].tocsr()

        with open(self.processed_paths[0], 'wb') as f:
            pickle.dump(graph_info, f)



class DataLoader(torch.utils.data.DataLoader):
    ''' torch DataLoader that merges per-sample subgraphs into one graph. '''

    def __init__(self, dataset, batch_size=1, shuffle=False):
        super(DataLoader, self).__init__(
            dataset, batch_size, shuffle, collate_fn=self.collate_func)

    def collate_func(self, batch):
        ''' Merge a list of Data objects into a single batched Data object.

        Edge indices of each sample are shifted by the number of nodes of
        all preceding samples; tensors are concatenated, while target_idx
        and target_lbls become plain per-sample lists.
        '''
        node_offset = 0
        shifted_edges = []
        for sample in batch:
            shifted_edges.append(sample.edge_index + node_offset)
            node_offset += sample.num_nodes

        return Data(
            edge_index=torch.cat(shifted_edges, dim=1),
            embeds_index=torch.cat([s.embeds_index for s in batch], dim=0),
            from_mask=torch.cat([s.from_mask for s in batch], dim=0),
            to_mask=torch.cat([s.to_mask for s in batch], dim=0),
            num_nodes=node_offset,
            target_idx=[s.target_idx for s in batch],
            target_lbls=[s.target_lbls for s in batch])




class TestGraphDataset_on_trainset(InMemoryDataset):
    ''' Evaluation dataset over the training graph.

    Samples a sentence-centred subgraph made of sentence->sentence edges
    only; label edges are ignored and labels are returned as targets.

    Args:
        root: dataset root directory, forwarded to InMemoryDataset.
        sample_n: dict whose 'sent' entry caps the number of neighbouring
            sentences kept per target. Defaults to {'sent': 100}.
        graph_info_fn: path of the pickled graph info read by process().
        transform: forwarded to InMemoryDataset.
        pre_transform: forwarded to InMemoryDataset.
    '''

    def __init__(self, root, sample_n=None,
                 graph_info_fn='tmp/graph_info.pkl',
                 transform=None, pre_transform=None):
        # Build the default per instance instead of sharing one dict object.
        self.sample_n = {'sent': 100} if sample_n is None else sample_n

        # Must be set before the parent constructor runs: it may call process().
        self.graph_info_fn = graph_info_fn
        super(TestGraphDataset_on_trainset, self).__init__(root, transform, pre_transform)

        # Load the processed graph info.
        with open(self.processed_paths[0], 'rb') as f:
            self.graph_info = pickle.load(f)
        # Row slicing is used in get(), so keep both matrices in CSR format.
        self.graph_info['sent_adj'] = self.graph_info['sent_adj'].tocsr()
        self.graph_info['interaction'] = self.graph_info['interaction'].tocsr()

        # IMPORTANT: remap the indices, so that the indices of labels are
        # 0 -> #labels, and the indices of samples are
        # #labels -> #labels+#samples.
        self.n_lbl = self.graph_info['lbl_adj'].shape[0]
        self.n_sent = self.graph_info['sent_adj'].shape[0]
        self.sent_idx_bias = self.n_lbl

    def __len__(self):
        return self.n_sent

    def get(self, idx):
        ''' Sample the subgraph around one target sentence.
        During testing process, ignore the edges involving labels.

        Args:
            idx: index to the target sample (a row of 'sent_adj').
        Returns:
            A torch_geometric.data.Data object holding:
              edge_index   - (2, #edges) neighbour -> target edges in local
                             node numbering, no reverse edges;
              embeds_index - global node index of every local node;
              from_mask / to_mask - True where the edge endpoint is a label
                             (global index < #labels);
              target_idx   - global node index of the target sentence;
              target_lbls  - label indices of the target sentence.
        '''
        max_sent = self.sample_n['sent']
        target_idx = idx + self.sent_idx_bias

        # Target sentence's interactions with labels; used as targets only.
        interaction_info = self.graph_info['interaction'].getrow(idx).tocoo()
        target_labels = list(interaction_info.col)

        # Neighbouring sentences: the stored entries of the target's row.
        sent_info = self.graph_info['sent_adj'].getrow(idx).tocoo()
        sent_from = sent_info.col
        if sent_from.size > max_sent:
            # Keep the neighbours with the largest edge weights (tf-idf
            # similarity according to the original comments).
            top_inds = np.argpartition(sent_info.data, -max_sent)[-max_sent:]
            sent_from = sent_from[top_inds]

        # Add index bias to sentence nodes. Every edge points at the target,
        # so the destination array is a constant fill.
        sent_from += self.sent_idx_bias
        sent_to = np.empty_like(sent_from)
        sent_to.fill(target_idx)

        # NO REVERSE EDGES.
        from_final = sent_from
        to_final = sent_to

        # is_label masks
        from_mask = from_final < self.sent_idx_bias
        to_mask = to_final < self.sent_idx_bias

        # Remap the global node indices to 0..n-1.
        inds = list(set(from_final).union(set(to_final)))
        inds_map = {node: pos for pos, node in enumerate(inds)}
        from_final = [inds_map[node] for node in from_final]
        to_final = [inds_map[node] for node in to_final]

        # Build the tensors.
        edge_index = torch.tensor([from_final, to_final],
                                  device=device, dtype=torch.long)
        embeds_index = torch.tensor(inds, dtype=torch.long, device=device)
        from_mask = torch.tensor(from_mask, device=device)
        to_mask = torch.tensor(to_mask, device=device)

        return Data(edge_index=edge_index, embeds_index=embeds_index,
                    from_mask=from_mask, to_mask=to_mask, num_nodes=len(inds),
                    target_idx=target_idx, target_lbls=target_labels)

    def sample_inds(self, from_len, sample_num):
        ''' Draw sample_num positions uniformly from range(from_len).

        Sampling is done with replacement only when from_len < sample_num.
        '''
        replace = from_len < sample_num
        return np.random.choice(np.arange(from_len), sample_num, replace)

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return ['processed_dataset.pth']

    def download(self):
        pass

    def process(self):
        ''' Simply copy the graph info to processed dir.
        '''
        # Load preprocessed graph info (in scipy.sparse)
        with open(self.graph_info_fn, 'rb') as f:
            graph_info = pickle.load(f)
        # Store in CSR so later loads are already in row-sliceable form.
        graph_info['sent_adj'] = graph_info['sent_adj'].tocsr()
        graph_info['interaction'] = graph_info['interaction'].tocsr()
        with open(self.processed_paths[0], 'wb') as f:
            pickle.dump(graph_info, f)



class TestGraphDataset(InMemoryDataset):
    ''' Evaluation dataset over the test graph.

    Samples a sentence-centred subgraph made of sentence->sentence edges
    only. The pickled graph info must additionally contain 'n_test', which
    is the dataset length: only indices [0, n_test) are served.

    Args:
        root: dataset root directory, forwarded to InMemoryDataset.
        sample_n: dict whose 'sent' entry caps the number of neighbouring
            sentences kept per target. Defaults to {'sent': 100}.
        graph_info_fn: path of the pickled graph info read by process().
        transform: forwarded to InMemoryDataset.
        pre_transform: forwarded to InMemoryDataset.
    '''

    def __init__(self, root, sample_n=None,
                 graph_info_fn='tmp/graph_info_test.pkl',
                 transform=None, pre_transform=None):
        # Build the default per instance instead of sharing one dict object.
        self.sample_n = {'sent': 100} if sample_n is None else sample_n

        # Must be set before the parent constructor runs: it may call process().
        self.graph_info_fn = graph_info_fn
        super(TestGraphDataset, self).__init__(root, transform, pre_transform)

        # Load the processed graph info.
        with open(self.processed_paths[0], 'rb') as f:
            self.graph_info = pickle.load(f)
        # Row slicing is used in get(), so keep both matrices in CSR format.
        self.graph_info['sent_adj'] = self.graph_info['sent_adj'].tocsr()
        self.graph_info['interaction'] = self.graph_info['interaction'].tocsr()

        # IMPORTANT: remap the indices, so that the indices of labels are
        # 0 -> #labels, and the indices of samples are
        # #labels -> #labels+#samples.
        self.n_lbl = self.graph_info['lbl_adj'].shape[0]
        self.n_sent = self.graph_info['sent_adj'].shape[0]
        self.n_test = self.graph_info['n_test']
        self.sent_idx_bias = self.n_lbl

    def __len__(self):
        return self.n_test

    def get(self, idx):
        ''' Sample the subgraph around one target sentence.
        During testing process, ignore the edges involving labels.

        Args:
            idx: index to the target sample (a row of 'sent_adj').
        Returns:
            A torch_geometric.data.Data object holding:
              edge_index   - (2, #edges) neighbour -> target edges in local
                             node numbering, no reverse edges;
              embeds_index - global node index of every local node;
              from_mask / to_mask - True where the edge endpoint is a label
                             (global index < #labels);
              target_idx   - global node index of the target sentence;
              target_lbls  - label indices of the target sentence.
        '''
        max_sent = self.sample_n['sent']
        target_idx = idx + self.sent_idx_bias

        # Target sentence's interactions with labels; used as targets only.
        interaction_info = self.graph_info['interaction'].getrow(idx).tocoo()
        target_labels = list(interaction_info.col)

        # Neighbouring sentences: the stored entries of the target's row.
        sent_info = self.graph_info['sent_adj'].getrow(idx).tocoo()
        sent_from = sent_info.col
        if sent_from.size > max_sent:
            # Keep the neighbours with the largest edge weights (tf-idf
            # similarity according to the original comments).
            top_inds = np.argpartition(sent_info.data, -max_sent)[-max_sent:]
            sent_from = sent_from[top_inds]

        # Add index bias to sentence nodes. Every edge points at the target,
        # so the destination array is a constant fill.
        sent_from += self.sent_idx_bias
        sent_to = np.empty_like(sent_from)
        sent_to.fill(target_idx)

        # No reverse edges.
        from_final = sent_from
        to_final = sent_to

        # is_label masks
        from_mask = from_final < self.sent_idx_bias
        to_mask = to_final < self.sent_idx_bias

        # Remap the global node indices to 0..n-1.
        inds = list(set(from_final).union(set(to_final)))
        inds_map = {node: pos for pos, node in enumerate(inds)}
        from_final = [inds_map[node] for node in from_final]
        to_final = [inds_map[node] for node in to_final]

        # Build the tensors.
        edge_index = torch.tensor([from_final, to_final],
                                  device=device, dtype=torch.long)
        embeds_index = torch.tensor(inds, dtype=torch.long, device=device)
        from_mask = torch.tensor(from_mask, device=device)
        to_mask = torch.tensor(to_mask, device=device)

        return Data(edge_index=edge_index, embeds_index=embeds_index,
                    from_mask=from_mask, to_mask=to_mask, num_nodes=len(inds),
                    target_idx=target_idx, target_lbls=target_labels)

    def sample_inds(self, from_len, sample_num):
        ''' Draw sample_num positions uniformly from range(from_len).

        Sampling is done with replacement only when from_len < sample_num.
        '''
        replace = from_len < sample_num
        return np.random.choice(np.arange(from_len), sample_num, replace)

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        # List only the file process() actually writes. The former second
        # entry ('test_batches.pth') was never produced, so the set of
        # processed files could never be complete.
        return ['processed_dataset_test.pth']

    def download(self):
        pass

    def process(self):
        ''' Simply copy the graph info to processed dir.
        '''
        # Load preprocessed graph info (in scipy.sparse)
        with open(self.graph_info_fn, 'rb') as f:
            graph_info = pickle.load(f)

        self.sent_idx_bias = self.n_lbl = graph_info['lbl_adj'].shape[0]
        self.n_sent = graph_info['sent_adj'].shape[0]
        self.n_test = graph_info['n_test']

        # Store in CSR so later loads are already in row-sliceable form.
        graph_info['sent_adj'] = graph_info['sent_adj'].tocsr()
        graph_info['interaction'] = graph_info['interaction'].tocsr()

        with open(self.processed_paths[0], 'wb') as f:
            pickle.dump(graph_info, f)


class SentenceDataset(torch.utils.data.Dataset):
    ''' Parallel lists of sentences and their mention positions.

    Args:
        sent_and_pos: one of
            - str: path of a pickle holding a 3-tuple whose first element
              is a sequence of items with item[0] = sentence and
              item[1] = mention position (train case);
            - dict: {'test': path, 'train': path} of such pickles; the test
              samples are ordered before the train samples (test case);
            - anything else: unpacked directly as (sents, pos).
    '''

    def __init__(self, sent_and_pos):
        super(SentenceDataset, self).__init__()
        if isinstance(sent_and_pos, str):
            # Train case.
            sents, pos = self._load_pickle(sent_and_pos)
        elif isinstance(sent_and_pos, dict):
            # Test case: order the test samples before the train samples.
            sents, pos = self._load_pickle(sent_and_pos['test'])
            train_sents, train_pos = self._load_pickle(sent_and_pos['train'])
            sents += train_sents
            pos += train_pos
        else:
            sents, pos = sent_and_pos

        self.sents = sents
        self.mention_pos = pos

    @staticmethod
    def _load_pickle(path):
        ''' Read one preprocessed pickle and split it into (sents, pos). '''
        with open(path, 'rb') as f:
            data, _, _ = pickle.load(f)
        return [item[0] for item in data], [item[1] for item in data]

    def __len__(self):
        return len(self.sents)

    def __getitem__(self, idx):
        ''' Fetch one or several samples.

        Args:
            idx: an int (Python or NumPy), a slice, a tensor of indices
                (0-d or 1-d), or any iterable of indices.
        Returns:
            (sentence, position) for a scalar index, otherwise
            (list of sentences, list of positions).
        '''
        if torch.is_tensor(idx):
            # 0-d tensor -> Python int, 1-d tensor -> list of Python ints.
            idx = idx.tolist()
        if isinstance(idx, (int, np.integer, slice)):
            return self.sents[idx], self.mention_pos[idx]

        retr_sents = []
        retr_pos = []
        for id_ in idx:
            retr_sents.append(self.sents[id_])
            retr_pos.append(self.mention_pos[id_])
        return retr_sents, retr_pos
        


class HFetSentenceDataset(torch.utils.data.Dataset):
    ''' Sentences with mention annotations, served as padded tensor batches.

    Raw input is one JSON object per line (see processLine). The converted
    instances are cached as a pickle under the 'tmp' directory.
    '''

    def __init__(self, fn, lbl2id_fn='data/ontology/onto_ontology.txt', tmp_fn='hfetsents.pkl'):
        ''' Initialize the dataset by preprocessing into the desired format.

        Args:
            fn: a str path to a JSON-lines file (train case), a dict
                {'test': path, 'train': path} whose test samples are placed
                before the train samples (test case), or a list of already
                loaded JSON lines.
            lbl2id_fn: text file with one label per line; the line number
                becomes the label id.
            tmp_fn: suffix of the cache file name inside 'tmp'.
        Raises:
            TypeError: if fn is of an unsupported type.
            AssertionError: if a line cannot be converted by processLine.
        '''
        super(HFetSentenceDataset, self).__init__()

        with open(lbl2id_fn) as f:
            self.lbl2id = {j: i for i, j in enumerate(f.read().strip().split('\n'))}

        # NOTE(review): for a dict the cache name is built from the dict
        # keys, not from the file paths, and for a list from the joined
        # lines themselves - confirm this is the intended cache key.
        fnstr = fn if isinstance(fn, str) else ''.join(list(fn))
        tmp_fn = os.path.join('tmp', fnstr.replace('/', '') + tmp_fn)

        if not os.path.exists('tmp'):
            os.mkdir('tmp')

        # If already processed, load from the tmp cache file.
        if os.path.exists(tmp_fn):
            with open(tmp_fn, 'rb') as f:
                self.data = pickle.load(f)
            print('loaded')
            return

        # Otherwise, load the raw lines and process them.
        if isinstance(fn, str):
            # Train case.
            with open(fn) as f:
                lines = f.read().strip().split('\n')
        elif isinstance(fn, dict):
            # Test case: order the test samples before the train samples.
            with open(fn['test']) as f:
                lines = f.read().strip().split('\n')
            with open(fn['train']) as f:
                lines.extend(f.read().strip().split('\n'))
        elif isinstance(fn, list):
            lines = fn
        else:
            raise TypeError('fn must be a str, dict or list, got %s'
                            % type(fn).__name__)

        self.data = []
        for i, line in enumerate(lines):
            print(i, end='\r')
            newdata = HFetSentenceDataset.processLine(line)
            if not newdata:
                raise AssertionError(
                    'line %d is not a complete mention record' % i)
            self.data.append(newdata)
        print('Finished sentence preprocessing. Saving to file...')
        with open(tmp_fn, 'wb') as f:
            pickle.dump(self.data, f)
        print('Done')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ''' Numberize the selected instances and pad them into one batch.

        Args:
            idx: an int, a slice, a list/tuple of indices or a tensor of
                indices.
        Returns:
            A tuple of tensors on the module-level device:
            (batch_char_ids, batch_labels, batch_men_mask, batch_ctx_mask,
             batch_dist, batch_gathers). Labels, masks and distances have
            one row per annotation; batch_gathers maps each of those rows
            to the position of its sentence in the batch.
        Raises:
            ValueError: if idx selects no instance.
        '''
        def _mask_to_distance(mask, mask_len, decay=.1):
            # Closeness of each token to the mention: 0 inside the mention,
            # 1 right next to it, decreasing by `decay` per token further
            # away and clipped at 0.
            start = mask.index(1)
            end = mask_len - list(reversed(mask)).index(1)
            dist = [0] * mask_len
            for i in range(start):
                dist[i] = max(0, 1 - (start - i - 1) * decay)
            for i in range(end, mask_len):
                dist[i] = max(0, 1 - (i - end) * decay)
            return dist

        if torch.is_tensor(idx):
            idx = list(idx.cpu())
        if isinstance(idx, int):
            insts = [self.data[idx]]
        elif isinstance(idx, slice):
            insts = self.data[idx]
        else:
            insts = [self.data[id_] for id_ in idx]

        retr = [HFetSentenceDataset.numberize(inst, self.lbl2id) for inst in insts]
        if not retr:
            raise ValueError('index selects no instance: %r' % (idx,))

        # The last field of every numberized instance is its token count.
        max_seq_len = max(x[-1] for x in retr)

        batch_char_ids = []
        batch_labels = []
        batch_men_mask = []
        batch_dist = []
        batch_ctx_mask = []
        batch_gathers = []
        for inst_idx, inst in enumerate(retr):
            char_ids, labels, men_mask, ctx_mask, men_ids, anno_num, seq_len = inst
            pad_len = max_seq_len - seq_len
            # Pad the sentence with all-PAD character rows.
            batch_char_ids.append(char_ids + [[C.PAD_INDEX] * C.ELMO_MAX_CHAR_LEN
                                              for _ in range(pad_len)])
            # Multi-hot label vector per annotation.
            for ls in labels:
                batch_labels.append([1 if l in ls else 0
                                     for l in range(len(self.lbl2id))])
            for mask in men_mask:
                batch_men_mask.append(mask + [C.PAD_INDEX] * pad_len)
                batch_dist.append(_mask_to_distance(mask, seq_len)
                                  + [C.PAD_INDEX] * pad_len)
            for mask in ctx_mask:
                batch_ctx_mask.append(mask + [C.PAD_INDEX] * pad_len)
            batch_gathers.extend([inst_idx] * anno_num)

        batch_char_ids = torch.tensor(batch_char_ids, dtype=torch.long, device=device)
        batch_labels = torch.tensor(batch_labels, dtype=torch.float, device=device)
        batch_men_mask = torch.tensor(batch_men_mask, dtype=torch.float, device=device)
        batch_ctx_mask = torch.tensor(batch_ctx_mask, dtype=torch.float, device=device)
        batch_gathers = torch.tensor(batch_gathers, dtype=torch.long, device=device)
        batch_dist = torch.tensor(batch_dist, dtype=torch.float, device=device)

        return (batch_char_ids, batch_labels, batch_men_mask, batch_ctx_mask,
                batch_dist, batch_gathers)

    @staticmethod
    def processLine(line, mention_id=0):
        ''' Convert one raw JSON line into a tokens/annotations instance.

        Args:
            line: JSON text of an object with the keys 'left_context_token'
                and 'right_context_token' (token lists), 'mention_span'
                (space separated mention string) and 'y_str' (labels).
            mention_id: id stored, as a string, in the annotation.
        Returns:
            {'tokens': [...], 'annotations': [one annotation dict]}, or None
            when the JSON value is not an object or misses a required key.
        '''
        data_dict = json.loads(line)
        if not isinstance(data_dict, dict):
            return None
        required = ('left_context_token', 'right_context_token',
                    'mention_span', 'y_str')
        if any(key not in data_dict for key in required):
            return None

        left_context_token = data_dict['left_context_token']
        right_context_token = data_dict['right_context_token']
        mention_span = data_dict['mention_span']
        mention_tokens = mention_span.split(' ')

        start = len(left_context_token)
        return {
            'tokens': left_context_token + mention_tokens + right_context_token,
            'annotations': [{'mention_id': str(mention_id),
                             'mention': mention_span,
                             'start': start,
                             'end': start + len(mention_tokens),
                             'labels': data_dict['y_str']}],
        }

    @staticmethod
    def numberize(inst, label_stoi):
        ''' Turn an instance into character ids, label ids and masks.

        Args:
            inst: dict produced by processLine.
            label_stoi: mapping from label string to label id; labels that
                are not in the mapping are dropped.
        Returns:
            (char_ids, labels_nbz, men_mask, ctx_mask, men_ids, anno_num,
             seq_len), where the middle four entries hold one item per
            annotation and the masks are 0/1 lists of length seq_len.
        '''
        tokens = inst['tokens']
        tokens = [C.TOK_REPLACEMENT.get(t, t) for t in tokens]
        seq_len = len(tokens)
        char_ids = batch_to_ids([tokens])[0].tolist()
        labels_nbz, men_mask, ctx_mask, men_ids = [], [], [], []
        annotations = inst['annotations']
        anno_num = len(annotations)
        for annotation in annotations:
            start = annotation['start']
            end = annotation['end']
            # Normalize a misspelled label before the lookup.
            labels = [l.replace('geograpy', 'geography')
                      for l in annotation['labels']]

            men_ids.append(annotation['mention_id'])
            labels_nbz.append([label_stoi[l] for l in labels if l in label_stoi])
            # Mention mask marks [start, end); context mask is its complement.
            men_mask.append([1 if start <= i < end else 0
                             for i in range(seq_len)])
            ctx_mask.append([1 if i < start or i >= end else 0
                             for i in range(seq_len)])
        return (char_ids, labels_nbz, men_mask, ctx_mask, men_ids, anno_num,
                seq_len)



# Manual smoke test: build the HFet sentence dataset from the OntoNotes
# training file and print the batch made of its first 100 instances.
if __name__ == '__main__':

    # Disabled check of SentenceDataset indexing (index list and slice):
    # sentset = SentenceDataset('tmp/prep.pkl')
    # print(sentset[[0, 2, 3]])
    # print(sentset[:10])



    # Disabled check of GraphDataset batching through the DataLoader
    # defined in this module:
    # graphdataset = GraphDataset('.', graph_info_fn='tmp/graph_info.pkl')

    # # from torch_geometric.data import DataLoader
    # loader = DataLoader(graphdataset, batch_size=32, shuffle=False)
    # for i, datapack in enumerate(loader):
    #     print(i, end='\r')
    #     print(datapack.target_idx, datapack.target_lbls)



    # Reads the raw file on the first run, or the cache under 'tmp'
    # afterwards.
    sentset = HFetSentenceDataset(fn='data/ontonotes/g_train.json')
    # Disabled test-case variant (test samples ordered before train samples):
    # testset = HFetSentenceDataset(fn={'train': 'data/ontonotes/g_train.json', 'test': 'data/ontonotes/g_test.json'})
    print(sentset[0:100])