import sys
import os
import os.path as osp
import traceback
import glob
import argparse
from tqdm import tqdm
import json
from PIL import Image
from abc import ABC, abstractmethod
import multiprocessing as mp
import numpy as np
import pandas as pd
import networkx as nx
import torch
import torch.nn.functional as F
# from torch_geometric.utils.convert import from_networkx
from torch_geometric.data import Data
from torch_geometric.data import Dataset as DT, download_url


from GPT_GNN.config import *
from GPT_GNN.utils import Word, get_pct_offsets, to_graph, to_graph_beta, to_graph_relations


class Dataset(ABC):
    """Base class for document datasets that are converted into graph files.

    A subclass fills in the per-split directory attributes declared in
    ``__init__`` and implements :meth:`load_words` and :meth:`get_label`.
    The base class then turns every annotation file of a split into a graph
    JSON file and finally into PyTorch Geometric ``Data`` objects (through
    ``AligNet``).
    """

    def __init__(self, name, home, img_extension):
        """Store the dataset identity and initialise every split path to ''.

        Args:
            name: Dataset name, forwarded to ``AligNet``.
            home: Root directory of the dataset.
            img_extension: Image file extension (with the dot) that replaces
                ``.json`` when deriving an image path from an annotation name.
        """
        self.name = name
        self.home = home
        self.img_extension = img_extension
        # Split locations; an empty string means "split not provided".
        self.train_home = ''
        self.val_home = ''
        self.test_home = ''
        self.train_ocr = ''
        self.train_images = ''
        self.train_graphs = ''
        self.train_alignets = ''
        self.val_ocr = ''
        self.val_images = ''
        self.val_graphs = ''
        self.val_alignets = ''
        self.test_ocr = ''
        self.test_images = ''
        self.test_graphs = ''
        self.test_alignets = ''
        self.num_classes = 0
        self.labels = []
        self.metric = None

    @abstractmethod
    def load_words(self, filepath):
        """Parse one annotation file.

        Returns:
            A ``(words, image_width, image_height)`` tuple, which is how
            :meth:`load_training_graph` unpacks the result.
        """
        pass

    @abstractmethod
    def get_label(self, s):
        """Map a raw label value to the label handed to the graph builders."""
        pass

    def load_relations(self, filepath):
        """Parse one annotation file into ``(entities, relations)``.

        Only needed for ``task_type != 'node'``. Datasets without relation
        annotations inherit this default, which fails with an explicit
        message instead of an ``AttributeError``.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not provide relation annotations; "
            "only task_type='node' is supported for this dataset"
        )

    def _configured_splits(self):
        """Return the distinct, configured splits in train/test/val order.

        Each entry is a ``(home, ocr, images, graphs)`` tuple. Splits whose
        home or OCR directory is empty are dropped, and splits that point at
        exactly the same directories as an earlier one are listed once.
        """
        candidates = [
            (self.train_home, self.train_ocr, self.train_images, self.train_graphs),
            (self.test_home, self.test_ocr, self.test_images, self.test_graphs),
            (self.val_home, self.val_ocr, self.val_images, self.val_graphs),
        ]
        distinct = []
        for split in candidates:
            home_dir, ocr_dir = split[0], split[1]
            if not home_dir or not ocr_dir:
                continue  # the subclass does not define this split
            if split not in distinct:
                distinct.append(split)
        return distinct

    def process(self, n_pool, graph_type, task_type, n_comms, hierarchical):
        """Build the graph files for every configured split.

        Unset splits are skipped (``os.listdir('')`` would fail) and splits
        sharing the same directories are processed only once.

        Args:
            n_pool: Number of worker processes.
            graph_type: ``'alignet'`` selects ``to_graph``; any other value
                selects ``to_graph_beta``.
            task_type: ``'node'`` builds word graphs; any other value builds
                relation graphs.
            n_comms: Forwarded to ``to_graph_beta``.
            hierarchical: Forwarded to ``to_graph_beta``.
        """
        for split in self._configured_splits():
            self.load_training_graphs(*split, n_pool, graph_type, task_type, n_comms, hierarchical)

    def to_json(self, g, image_path, file_name):
        """Flatten a graph into a JSON-serialisable dictionary.

        Node attributes are collected into parallel lists (in ``g.nodes``
        iteration order); edge attributes are stored as nested
        ``{src: {tgt: value}}`` dictionaries.

        Args:
            g: Graph exposing ``nodes``, ``edges`` and ``g[src][tgt]`` access.
            image_path: Stored verbatim under ``'image_path'``.
            file_name: Stored verbatim under ``'file_name'``.

        Returns:
            The dictionary described above.
        """
        j = {'text': [],
             'node_features': [],
             'node_masks': [],
             'node_vt_idxs': [],
             'comms': [],
             'edge_attrs': {},
             'edge_labels': {},
             'edge_weights': {},
             'peers': {},
             'blocks': [],
             'segments': [],
             'y': [],
             'image_path': image_path,
             'file_name': file_name}
        # (output key, node attribute) pairs copied for every node.
        node_fields = (
            ('text', 'text'),
            ('node_features', 'x'),
            ('node_masks', 'mask'),
            ('node_vt_idxs', 'visual_token_idx'),
            ('comms', 'comm'),
            ('blocks', 'block'),
            ('segments', 'segment'),
            ('y', 'y'),
        )
        for node in g.nodes:
            attributes = g.nodes[node]
            for key, attribute in node_fields:
                j[key].append(attributes[attribute])
        for edge in g.edges:
            src, tgt = edge[0], edge[1]
            attributes = g[src][tgt]
            j['edge_attrs'].setdefault(src, {})[tgt] = attributes['edge_attr']
            j['edge_labels'].setdefault(src, {})[tgt] = attributes['label']
            j['edge_weights'].setdefault(src, {})[tgt] = attributes['weight']
            j['peers'].setdefault(src, {})[tgt] = int(attributes['peer'])
        return j

    def convert_to_pyg(self):
        """Run ``AligNet.process`` once for every distinct, non-empty split home."""
        done = []
        for home_dir in (self.train_home, self.val_home, self.test_home):
            if not home_dir or home_dir in done:
                continue
            done.append(home_dir)
            alignet = AligNet(root=home_dir, name=self.name)
            alignet.process()

    @staticmethod
    def map_y(y):
        """Collapse label pairs: 1/2 -> 1, 3/4 -> 2, 5/6 -> 3, anything else -> 0."""
        if y in [1, 2]: return 1
        if y in [3, 4]: return 2
        if y in [5, 6]: return 3
        return 0

    @staticmethod
    def from_json(filepath):
        """Rebuild a NetworkX graph from a file written with :meth:`to_json`.

        Nodes are numbered by their position in ``'node_features'``. Only the
        ``x``, ``mask``, ``y``, ``comm``, ``block``, ``segment`` and
        ``image_path`` node attributes and the ``edge_attr``, ``label`` and
        ``peer`` edge attributes are restored.
        """
        g = nx.Graph(name=filepath)
        with open(filepath, 'r') as f:
            j = json.load(f)
        for idx, node in enumerate(j['node_features']):
            g.add_node(idx, x=node, mask=j['node_masks'][idx], y=j['y'][idx], comm=j['comms'][idx],
                       block=j['blocks'][idx], segment=j['segments'][idx], image_path=j['image_path'])
        for src, targets in j['edge_attrs'].items():
            for tgt, edge_attr in targets.items():
                # JSON object keys are strings; node ids are integers.
                g.add_edge(int(src), int(tgt), edge_attr=edge_attr,
                           label=j['edge_labels'][src][tgt], peer=j['peers'][src][tgt])
        return g

    @staticmethod
    def to_data(filepath):
        """Convert a graph JSON file into a PyTorch Geometric ``Data`` object.

        The first ``pos_size`` feature columns become ``pos`` and the rest
        become ``x``; ``None`` feature values are replaced by ``0.0``.
        ``pos_size`` and ``edge_label_to_idx`` are module-level names that are
        not defined in this file (presumably supplied by the config star
        import).

        Raises:
            IndexError: When the file holds no nodes, because the feature
                tensor is then one-dimensional. ``AligNet.process`` relies on
                this to skip such files.
        """
        with open(filepath, 'r') as f:
            record = json.load(f)
        block = torch.LongTensor(record['blocks'])
        segment = torch.LongTensor(record['segments'])
        y = torch.LongTensor(record['y'])
        features = [[0.0 if value is None else value for value in row] for row in record['node_features']]
        x = torch.FloatTensor(features)
        pos = x[:, :pos_size]
        x = x[:, pos_size:]
        mask = torch.LongTensor(record['node_masks'])
        comm = torch.LongTensor(record['comms'])
        edges = [(int(s), int(t), a) for s, targets in record['edge_attrs'].items() for t, a in targets.items()]
        # reshape keeps edge_index at shape (2, E) even when there are no edges.
        edge_index = torch.LongTensor([[s, t] for s, t, _ in edges]).reshape(-1, 2).t()
        edge_attrs = torch.FloatTensor([a for _, _, a in edges])
        edge_labels = torch.FloatTensor(
            [edge_label_to_idx[a] for targets in record['edge_labels'].values() for a in targets.values()])
        edge_peers = torch.LongTensor(
            [a for targets in record['peers'].values() for a in targets.values()])
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attrs, edge_label=edge_labels, y=y, pos=pos,
                    mask=mask, peer=edge_peers, comm=comm, block=block, segment=segment, filepath=filepath)

    def load_training_graph(self, task_type, ocr_dir, images_dir, graphs_dir, filename, graph_type, n_comms, hierarchical):
        """Build and save the graph(s) of a single annotation file.

        The output is written to ``<graphs_dir>/<filename>.json``. When the
        graph builder yields several graphs, each one overwrites that same
        file, so the last one wins.

        Returns:
            ``True`` on success, ``False`` when the file was skipped because
            any exception occurred (the traceback is printed).
        """
        # Computed outside the ``try`` so the error message can always use it.
        filepath = os.path.join(ocr_dir, filename)
        try:
            imagepath = os.path.join(images_dir, filename.replace('.json', self.img_extension))
            if task_type == 'node':
                words, image_width, image_height = self.load_words(filepath)
                w = image_width if image_width is not None else 0
                h = image_height if image_height is not None else 0
                df = get_pct_offsets(words, imagepath, w, h, True)
                if graph_type == 'alignet':
                    gs = list(to_graph(df, self.get_label, True))
                else:
                    gs = list(to_graph_beta(df, self.get_label, True, w, h, graph_type, n_comms, hierarchical))
            else:
                entities, relations = self.load_relations(filepath)
                gs = to_graph_relations(entities, relations, imagepath)
            for g in gs:
                graph = self.to_json(g, imagepath, filename)
                with open(os.path.join(graphs_dir, filename + '.json'), 'w') as f:
                    json.dump(graph, f)
            return True
        except Exception:
            # Best effort: one bad page must not abort the whole split.
            print('Unable to process file. Skipping', filepath, 'due to', traceback.format_exc())
            return False

    def load_training_graphs(self, home_dir, ocr_dir, images_dir, graphs_dir, n_pool, graph_type, task_type, n_comms, hierarchical):
        """Convert every annotation file of one split into graph files.

        Every entry of ``ocr_dir`` is handed to :meth:`load_training_graph`
        in a pool of ``n_pool`` processes; afterwards ``AligNet.process`` is
        run on ``home_dir``.

        Returns:
            ``True`` once the split has been processed.
        """
        # Without the output directory every single file would fail to save.
        os.makedirs(graphs_dir, exist_ok=True)
        files = [(task_type, ocr_dir, images_dir, graphs_dir, f, graph_type, n_comms, hierarchical)
                 for f in os.listdir(ocr_dir)]
        with mp.Pool(n_pool) as p:
            gs_train = p.starmap(self.load_training_graph, tqdm(files, total=len(files)))
        print(len(gs_train))
        alignet = AligNet(root=home_dir, name=self.name)
        alignet.process()
        return True

    def order_nodes(self, graph):
        """Return ``graph`` unchanged."""
        return graph


class FUNSD(Dataset):
    """FUNSD form dataset: PNG images with per-entity JSON annotations.

    The validation split points at the same directories as the training
    split; the test split lives under ``testing_data``.
    """

    def __init__(self, home):
        """Set up the label vocabulary and the split directories under ``home``."""
        super().__init__("FUNSD", home, '.png')
        self.label_to_idx = {
            "B-ANSWER": 5,
            "B-HEADER": 1,
            "B-QUESTION": 3,
            "I-ANSWER": 6,
            "I-HEADER": 2,
            "I-QUESTION": 4,
            "O": 0
        }
        self.idx_to_label = {index: name for name, index in self.label_to_idx.items()}
        self.num_classes = len(self.idx_to_label)
        # Label names ordered by their index.
        self.labels = [self.idx_to_label[index] for index in range(len(self.label_to_idx))]
        training_root = os.path.join(self.home, 'training_data')
        testing_root = os.path.join(self.home, 'testing_data')
        split_roots = (('train', training_root), ('val', training_root), ('test', testing_root))
        for split, root in split_roots:
            setattr(self, f'{split}_home', root)
            setattr(self, f'{split}_images', os.path.join(root, 'images'))
            setattr(self, f'{split}_ocr', os.path.join(root, 'annotations'))
            setattr(self, f'{split}_graphs', os.path.join(root, 'graphs'))
            setattr(self, f'{split}_alignets', os.path.join(root, 'alignets'))
        self.metric = "seqeval"

    def get_label(self, s):
        """Return the index of label ``s`` (stringified and stripped), or 0 if unknown."""
        return self.label_to_idx.get(str(s).strip(), 0)

    def load_words(self, filepath):
        """Read the words of one FUNSD annotation file.

        Every entry of ``form`` is one block. Its first word is tagged
        ``B-<LABEL>`` and the others ``I-<LABEL>``; any tag containing
        ``OTHER`` becomes ``O``. Words without a ``text`` key are skipped.

        Returns:
            ``(words, 0, 0)``.
        """
        with open(filepath, 'r') as f:
            form = json.load(f)['form']
        words = []
        for block_num, item in enumerate(form):
            entity_label = item['label'].upper() if 'label' in item else ''
            # NOTE(review): the entity-level box is not used below; the unpack
            # is kept so entries with a malformed box still raise as before.
            _left, _top, _right, _bottom = item['box']
            for position, word in enumerate(item['words']):
                segment = "B" if position == 0 else "I"
                tag = segment + '-' + entity_label
                if 'OTHER' in tag:
                    tag = 'O'
                if 'text' not in word:
                    continue
                left, top, right, bottom = word['box']
                words.append(Word(word['text'], left, top, right, bottom, block_num, segment, tag))
        return words, 0, 0

    def load_relations(self, filepath):
        """Read the entities and links of one FUNSD annotation file.

        Each ``form`` entry becomes one ``Word`` whose block and segment are
        both its position in the list. Links already collected from earlier
        entities are not added again.

        Returns:
            ``(entities, relations)``.
        """
        with open(filepath, 'r') as f:
            form = json.load(f)['form']
        entities, relations = [], []
        for block_num, entity in enumerate(form):
            label = entity['label'].upper() if 'label' in entity else ''
            # Filtered against the links gathered so far, then added in bulk.
            unseen = [link for link in entity['linking'] if link not in relations]
            relations.extend(unseen)
            left, top, right, bottom = entity['box']
            entities.append(Word(entity['text'], left, top, right, bottom, block_num, block_num, label))
        return entities, relations


class IDL(Dataset):
    """IDL dataset: TIFF images with block-structured OCR JSON (train split only)."""

    def __init__(self, home):
        """Configure the single ``train`` split under ``home``."""
        super().__init__("IDL", home, '.tif')
        self.labels = ['SEGMENT']
        self.num_classes = len(self.labels)
        train_root = os.path.join(self.home, 'train')
        self.train_home = train_root
        self.train_images = os.path.join(train_root, 'images')
        self.train_ocr = os.path.join(train_root, 'ocr')
        self.train_graphs = os.path.join(train_root, 'graphs')
        self.train_alignets = os.path.join(train_root, 'alignets')
        self.metric = None

    def get_label(self, s):
        """Every label maps to 0."""
        return 0

    def load_words(self, filepath):
        """Read the words of one OCR file, grouped by LINE block.

        The file holds a two-element sequence whose second element carries
        the ``Blocks`` list. Each LINE block is one text block; the WORD
        blocks listed in its ``CHILD`` relationships are emitted in order,
        the first tagged ``B-SEGMENT`` and the rest ``I-SEGMENT``.

        Returns:
            ``(words, 1, 1)``.
        """
        with open(filepath, 'r') as f:
            data = json.load(f)
        # The first element is read but not used here.
        _source, doc = data[0], data[1]
        blocks_by_type = {'PAGE': {}, 'LINE': {}, 'WORD': {}}
        for block in doc['Blocks']:
            kind = block['BlockType']
            if kind in blocks_by_type:
                blocks_by_type[kind][block['Id']] = block
        line_blocks = blocks_by_type['LINE']
        word_blocks = blocks_by_type['WORD']

        words = []
        for block_num, line in enumerate(line_blocks.values()):
            for relationship in line['Relationships']:
                if relationship['Type'] == 'CHILD':
                    children = [word_blocks[wid] for wid in relationship['Ids'] if wid in word_blocks]
                    for position, child in enumerate(children):
                        segment = 'B' if position == 0 else 'I'
                        tag = segment + '-SEGMENT'
                        bbox = child['Geometry']['BoundingBox']
                        left = bbox['Left']
                        top = bbox['Top']
                        right = bbox['Left'] + bbox['Width']
                        bottom = bbox['Top'] + bbox['Height']
                        words.append(Word(child['Text'], left, top, right, bottom, block_num, segment, tag))
        return words, 1, 1


class DocVQA(Dataset):
    """DocVQA dataset: PNG images with line/word OCR; all splits share one directory."""

    def __init__(self, home):
        """Point the train, val and test splits at the same folders under ``home``."""
        super().__init__("DocVQA", home, '.png')
        self.home = home
        for split in ('train', 'val', 'test'):
            setattr(self, f'{split}_home', self.home)
            setattr(self, f'{split}_images', os.path.join(self.home, 'images'))
            setattr(self, f'{split}_ocr', os.path.join(self.home, 'ocr'))
            setattr(self, f'{split}_graphs', os.path.join(self.home, 'graphs'))
            setattr(self, f'{split}_alignets', os.path.join(self.home, 'alignets'))
        self.num_classes = 0
        self.labels = []
        self.metric = None

    def get_label(self, s):
        """Every label maps to 0."""
        return 0

    def load_words(self, filepath):
        """Read the words of the first recognition result of one OCR file.

        Each line is one block. A word's box is built from elements 0, 1, 4
        and 5 of its ``boundingBox`` list; segment and label are ``None``.

        Returns:
            ``(words, 0, 0)``.
        """
        with open(filepath, 'r') as f:
            page = json.load(f)['recognitionResults'][0]
        words = []
        for block_num, line in enumerate(page['lines']):
            for entry in line['words']:
                text, box = entry['text'], entry['boundingBox']
                words.append(Word(text, box[0], box[1], box[4], box[5], block_num, None, None))
        return words, 0, 0
    

class RVL(Dataset):
    """RVL-CDIP dataset: JPEG images with token/bbox OCR in train/val/test folders."""

    def __init__(self, home):
        """Configure the ``train``, ``val`` and ``test`` splits under ``home``."""
        super().__init__("RVL", home, '.jpeg')
        self.home = home
        self.train_home = os.path.join(self.home, 'train')
        self.val_home = os.path.join(self.home, 'val')
        self.test_home = os.path.join(self.home, 'test')
        self.train_images = os.path.join(self.home, 'train', 'images')
        self.train_ocr = os.path.join(self.home, 'train', 'ocr')
        self.train_graphs = os.path.join(self.home, 'train', 'graphs')
        self.train_alignets = os.path.join(self.home, 'train', 'alignets')
        self.val_images = os.path.join(self.home, 'val', 'images')
        self.val_ocr = os.path.join(self.home, 'val', 'ocr')
        self.val_graphs = os.path.join(self.home, 'val', 'graphs')
        self.val_alignets = os.path.join(self.home, 'val', 'alignets')
        self.test_images = os.path.join(self.home, 'test', 'images')
        self.test_ocr = os.path.join(self.home, 'test', 'ocr')
        self.test_graphs = os.path.join(self.home, 'test', 'graphs')
        self.test_alignets = os.path.join(self.home, 'test', 'alignets')
        # Earlier revisions misspelled these attributes, which left the
        # base-class ``*_alignets`` values empty; the old names stay as aliases.
        self.train_alinets = self.train_alignets
        self.val_alinets = self.val_alignets
        self.test_alinets = self.test_alignets
        self.num_classes = 16
        self.labels = list(range(self.num_classes))
        self.metric = "seqeval"

    def get_label(self, s):
        """Every label maps to 0."""
        return 0

    def load_words(self, filepath):
        """Read the tokens of one OCR file as a single block.

        A token is dropped when its box spans the entire image (treated as a
        misrecognition) or when an earlier token already used exactly the
        same box. Segment and label are ``None``.

        Returns:
            ``(words, image_width, image_height)`` taken from ``image_dims``.

        Raises:
            AssertionError: When ``tokens`` and ``bboxes`` differ in length.
        """
        with open(filepath, 'r') as f:
            data = json.load(f)
        assert len(data['tokens']) == len(data['bboxes'])
        image_width, image_height = data['image_dims']
        block_num = 0
        seen_boxes = set()
        words = []
        for word, bbox in zip(data['tokens'], data['bboxes']):
            box = (bbox[0], bbox[1], bbox[2], bbox[3])
            # skip misrecognized words (box covering the whole page)
            if box == (0, 0, image_width, image_height):
                continue
            # skip overlapping words (identical box already emitted)
            key = "{}_{}_{}_{}".format(*box)
            if key in seen_boxes:
                continue
            seen_boxes.add(key)
            words.append(Word(word, box[0], box[1], box[2], box[3], block_num, None, None))
        return words, image_width, image_height
    

class SROIE(Dataset):
    """SROIE receipt dataset: JPEG images with tagged word annotations.

    The validation split points at the same directories as the training split.
    """

    def __init__(self, home):
        """Configure the split directories under ``home`` and the label vocabulary."""
        super().__init__("SROIE", home, '.jpg')
        self.home = home
        self.train_home = os.path.join(self.home, 'train')
        self.val_home = os.path.join(self.home, 'train')
        self.test_home = os.path.join(self.home, 'test')
        self.train_images = os.path.join(self.home, 'train', 'images')
        self.train_ocr = os.path.join(self.home, 'train', 'tagged')
        self.train_graphs = os.path.join(self.home, 'train', 'graphs')
        self.train_alignets = os.path.join(self.home, 'train', 'alignets')
        self.val_images = os.path.join(self.home, 'train', 'images')
        self.val_ocr = os.path.join(self.home, 'train', 'tagged')
        self.val_graphs = os.path.join(self.home, 'train', 'graphs')
        self.val_alignets = os.path.join(self.home, 'train', 'alignets')
        self.test_images = os.path.join(self.home, 'test', 'images')
        self.test_ocr = os.path.join(self.home, 'test', 'tagged')
        self.test_graphs = os.path.join(self.home, 'test', 'graphs')
        self.test_alignets = os.path.join(self.home, 'test', 'alignets')
        self.labels = ["O", "B-COMPANY", "I-COMPANY", "B-DATE", "I-DATE", "B-ADDRESS", "I-ADDRESS", "B-TOTAL", "I-TOTAL"]
        self.num_classes = len(self.labels)
        self.idx_to_label = dict(enumerate(self.labels))
        self.label_to_idx = {label: index for index, label in enumerate(self.labels)}
        # Misspelled historical name, kept as an alias for existing callers.
        self.label_to_odx = self.label_to_idx
        self.metric = "seqeval"

    def get_label(self, s):
        """Return the index of label ``s`` in ``self.labels``, or -1 if unknown.

        The lookup ignores case and surrounding whitespace. The index is
        looked up with the normalised value; using the raw value raised
        ``ValueError`` for lower-case labels that had passed the membership
        test.
        """
        key = str(s).strip().upper()
        try:
            return self.labels.index(key)
        except ValueError:
            return -1

    def load_words(self, filepath):
        """Read the tagged words of one annotation file.

        A label starting with ``B`` opens a new block and a new segment. An
        ``O`` word that directly follows a non-``O`` word also opens a
        segment, but stays in the current block. Every other word continues
        the current segment.
        NOTE(review): whether ``O`` runs should also get their own block
        number cannot be told from this file; the block numbering is kept.

        Returns:
            ``(words, 0, 0)``.
        """
        with open(filepath, 'r') as f:
            data = json.load(f)
        words, bboxes, labels = data['words'], data['bbox'], data['labels']
        block_num = 0
        output = []
        for idx, (word, bbox, label) in enumerate(zip(words, bboxes, labels)):
            if label.startswith("B"):
                block_num += 1
                segment = "B"
            elif label == "O" and idx > 0 and labels[idx - 1] != "O":
                segment = "B"
            else:
                segment = "I"
            output.append(Word(word, bbox[0], bbox[1], bbox[2], bbox[3], block_num, segment, label))
        return output, 0, 0


class CORD(Dataset):
    """CORD receipt dataset: PNG images with line-level category annotations."""

    def __init__(self, home):
        """Configure the train/validation/test directories and the category vocabulary."""
        super().__init__("CORD", home, '.png')
        self.home = home
        self.train_home = os.path.join(self.home, 'train')
        self.val_home = os.path.join(self.home, 'validation')
        self.test_home = os.path.join(self.home, 'test')
        self.train_images = os.path.join(self.home, 'train', 'images')
        self.train_ocr = os.path.join(self.home, 'train', 'ocr')
        self.train_graphs = os.path.join(self.home, 'train', 'graphs')
        self.train_alignets = os.path.join(self.home, 'train', 'alignets')
        self.val_images = os.path.join(self.home, 'validation', 'images')
        self.val_ocr = os.path.join(self.home, 'validation', 'ocr')
        self.val_graphs = os.path.join(self.home, 'validation', 'graphs')
        self.val_alignets = os.path.join(self.home, 'validation', 'alignets')
        self.test_images = os.path.join(self.home, 'test', 'images')
        self.test_ocr = os.path.join(self.home, 'test', 'ocr')
        self.test_graphs = os.path.join(self.home, 'test', 'graphs')
        self.test_alignets = os.path.join(self.home, 'test', 'alignets')
        self.labels = ['menu.cnt',
                       'menu.discountprice',
                       'menu.etc',
                       'menu.itemsubtotal',
                       'menu.nm',
                       'menu.num',
                       'menu.price',
                       'menu.sub_cnt',
                       'menu.sub_etc',
                       'menu.sub_nm',
                       'menu.sub_price',
                       'menu.sub_unitprice',
                       'menu.unitprice',
                       'menu.vatyn',
                       'sub_total.discount_price',
                       'sub_total.etc',
                       'sub_total.othersvc_price',
                       'sub_total.service_price',
                       'sub_total.subtotal_price',
                       'sub_total.tax_price',
                       'total.cashprice',
                       'total.changeprice',
                       'total.creditcardprice',
                       'total.emoneyprice',
                       'total.menuqty_cnt',
                       'total.menutype_cnt',
                       'total.total_etc',
                       'total.total_price',
                       'void_menu.nm',
                       'void_menu.price']
        self.num_classes = len(self.labels)
        self.idx_to_label = dict(enumerate(self.labels))
        self.label_to_idx = {label: index for index, label in enumerate(self.labels)}
        self.metric = "seqeval"

    def get_label(self, s):
        """Return ``s`` unchanged; ``load_words`` already stores label indices."""
        return s

    def load_words(self, filepath):
        """Read the words of one CORD annotation file.

        Every entry of ``valid_line`` is one block and all of its words carry
        the index of the line's category. A category that is not in the
        vocabulary is retried with its last ``.`` replaced by ``_`` (for
        example ``menu.sub.nm`` becomes ``menu.sub_nm``).

        The word box is the axis-aligned bounding box of the four ``quad``
        corners. The previous code paired ``x1``/``y1`` with ``x4``/``y4``;
        with CORD's corner order (top-left, top-right, bottom-right,
        bottom-left) that pairs two left-hand corners and yields boxes of
        roughly zero width.

        Returns:
            ``(words, 0, 0)``.

        Raises:
            KeyError: When a category cannot be mapped to a known label.
        """
        with open(filepath, 'r') as f:
            data = json.load(f)
        output = []
        for block_num, line in enumerate(data['valid_line']):
            category = line['category']
            if category not in self.label_to_idx:
                head, dot, tail = category.rpartition('.')
                if dot:
                    category = head + '_' + tail
            label = self.label_to_idx[category]
            for idx, word in enumerate(line['words']):
                segment = "B" if idx == 0 else "I"
                quad = word['quad']
                xs = [quad['x1'], quad['x2'], quad['x3'], quad['x4']]
                ys = [quad['y1'], quad['y2'], quad['y3'], quad['y4']]
                output.append(Word(word['text'], min(xs), min(ys), max(xs), max(ys), block_num, segment, label))
        return output, 0, 0
    

class BUDDIE(Dataset):
    """BUDDIE dataset: JPEG images with token-level class annotations."""

    def __init__(self, home):
        """Configure the ``train``, ``val`` and ``test`` splits under ``home``."""
        super().__init__("BUDDIE", home, '.jpg')
        self.home = home
        self.train_home = os.path.join(self.home, 'train')
        self.val_home = os.path.join(self.home, 'val')
        self.test_home = os.path.join(self.home, 'test')
        self.train_images = os.path.join(self.home, 'train', 'images')
        self.train_ocr = os.path.join(self.home, 'train', 'ocr')
        self.train_graphs = os.path.join(self.home, 'train', 'graphs')
        self.val_images = os.path.join(self.home, 'val', 'images')
        self.val_ocr = os.path.join(self.home, 'val', 'ocr')
        self.val_graphs = os.path.join(self.home, 'val', 'graphs')
        self.test_images = os.path.join(self.home, 'test', 'images')
        self.test_ocr = os.path.join(self.home, 'test', 'ocr')
        self.test_graphs = os.path.join(self.home, 'test', 'graphs')
        # Same layout as the other datasets in this module, which all set
        # these paths; they used to be left empty here.
        self.train_alignets = os.path.join(self.home, 'train', 'alignets')
        self.val_alignets = os.path.join(self.home, 'val', 'alignets')
        self.test_alignets = os.path.join(self.home, 'test', 'alignets')
        self.num_classes = 70
        self.idx_to_label = {i: i for i in range(self.num_classes)}
        self.labels = list(range(self.num_classes))
        self.metric = "seqeval"

    def get_label(self, s):
        """Return ``s`` unchanged; ``load_words`` already stores numeric labels."""
        return s

    def load_words(self, filepath):
        """Read the tokens of one annotation file.

        Consecutive tokens with the same ``class_id`` form one block. The
        first token of every block is tagged segment ``"B"``, and that now
        includes the very first token of the file, which used to be tagged
        ``"I"``. The stored label is ``class_id + 1``.

        Returns:
            ``(words, 0, 0)``.
        """
        with open(filepath, 'r') as f:
            data = json.load(f)
        tokens = data['tokens']
        output = []
        block_num = 0
        for idx, token in enumerate(tokens):
            starts_block = idx == 0 or tokens[idx - 1]['class_id'] != token['class_id']
            if starts_block and idx > 0:
                block_num += 1
            segment = "B" if starts_block else "I"
            right = token['x'] + token['width']
            bottom = token['y'] + token['height']
            output.append(Word(token['text'], token['x'], token['y'], right, bottom,
                               block_num, segment, token['class_id'] + 1))
        return output, 0, 0
    

class AligNet(DT):
    """PyTorch Geometric dataset over the graph JSON files of one split.

    Raw graphs are read from ``<root>/graphs/*.json`` and the converted
    ``Data`` objects are stored as ``<root>/alignets/data_<idx>.pt``.
    """

    def __init__(self, root, name, transform=None, pre_transform=None, pre_filter=None):
        """Initialise the dataset rooted at ``root`` and remember its ``name``."""
        super().__init__(root, transform, pre_transform, pre_filter)
        self.name = name

    @staticmethod
    def _list_files(directory, prefix, suffix):
        """Return the sorted full paths in ``directory`` matching prefix and suffix.

        A missing directory yields an empty list instead of raising, because
        the file-name properties may be queried before the directory exists.
        Sorting makes the ``data_<idx>`` numbering reproducible across runs.
        """
        if not osp.isdir(directory):
            return []
        return [osp.join(directory, f) for f in sorted(os.listdir(directory))
                if f.startswith(prefix) and f.endswith(suffix)]

    @property
    def raw_dir(self) -> str:
        """Directory holding the graph JSON files."""
        return osp.join(self.root, 'graphs')

    @property
    def processed_dir(self) -> str:
        """Directory holding the converted ``data_<idx>.pt`` files."""
        return osp.join(self.root, 'alignets')

    @property
    def raw_file_names(self):
        """Full paths of the ``.json`` files currently in ``raw_dir``."""
        return self._list_files(self.raw_dir, '', '.json')

    @property
    def processed_file_names(self):
        """Full paths of the ``data_*.pt`` files currently in ``processed_dir``."""
        return self._list_files(self.processed_dir, 'data_', '.pt')

    def download(self):
        """Nothing to download; the raw graphs are generated locally."""
        pass

    def process(self):
        """Convert every raw graph file into a saved ``Data`` object.

        Files that contain invalid JSON, or that make ``Dataset.to_data``
        raise ``IndexError``, are reported and skipped. The output index only
        advances on success, so the numbering stays contiguous.
        """
        os.makedirs(self.processed_dir, exist_ok=True)
        idx = 0
        for raw_path in tqdm(self.raw_file_names):
            try:
                # Read data from `raw_path`.
                data = Dataset.to_data(raw_path)

                torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
                idx += 1
            except json.decoder.JSONDecodeError:
                print('Invalid json. Skipping', raw_path)
            except IndexError:
                print('Too many indices for tensor of dimension 1', raw_path)

    def len(self):
        """Number of processed files currently on disk."""
        return len(self.processed_file_names)

    def get(self, idx):
        """Load and return the ``Data`` object stored as ``data_<idx>.pt``."""
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

    
def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as true; this parser accepts the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: When ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def build_arg_parser():
    """Create the command-line parser of the preprocessor."""
    parser = argparse.ArgumentParser(description='AliGATr preprocessor.')
    # 'docvqa' is handled in main() and is therefore an accepted choice.
    parser.add_argument('-d', '--dataset', type=str, help='Name of dataset to be used for pretraining',
                        choices=['idl', 'rvl-cdip', 'funsd', 'cord', 'sroie', 'buddie', 'docvqa'], default='funsd')
    parser.add_argument('-w', '--num_workers', type=int, help='Number of worker processes used to build the graphs',
                        default=8)
    parser.add_argument('-g', '--graph_type', type=str,
                        help='Whether the data should be processed into AligNet or Beta-skeleton graphs',
                        choices=['alignet', 'beta'], default='alignet')
    # The default used to be 'alignet', which is not a task type and silently
    # selected the link branch.
    parser.add_argument('-t', '--task_type', type=str,
                        help='Whether the ultimate task is node classification (entity tagging) or link detection (relation identification).',
                        choices=['node', 'link'], default='node')
    parser.add_argument('-n', '--n_comms', type=int, help='Indicates the number of communities', default=10)
    # A bare "-r" means True; "-r False" now really means False.
    parser.add_argument('-r', '--hierarchical', type=str2bool, nargs='?', const=True,
                        help='Indicates whether the graph should be constructed hierarchically', default=False)
    return parser


def main(argv=None):
    """Parse the arguments, build the selected dataset and process it."""
    args = vars(build_arg_parser().parse_args(argv))

    # Factories are lazy so only the home constant of the selected dataset is read.
    factories = {
        'funsd': lambda: FUNSD(funsd_home),
        'idl': lambda: IDL(idl_home),
        'rvl-cdip': lambda: RVL(rvl_home),
        'cord': lambda: CORD(cord_home),
        'sroie': lambda: SROIE(sroie_home),
        'buddie': lambda: BUDDIE(buddie_home),
        'docvqa': lambda: DocVQA(docvqa_home),
    }
    dataset = factories[args['dataset']]()

    # n_pool, graph_type, task_type, n_comms, hierarchical
    settings = (int(args['num_workers']), args['graph_type'], args['task_type'],
                int(args['n_comms']), args['hierarchical'])
    print(*settings)
    dataset.process(*settings)


if __name__ == "__main__":
    main()