import sys
from os import path
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
import argparse
import json
import os
import pickle
import platform
from typing import Any, List

import lmdb
from tqdm import tqdm
from torch.utils.data import DataLoader

from lmdb_readers import VQA_pt_Reader


# fmt: off
# Command-line interface of the serialization script.
# NOTE: the summary must be given as `description`. The first positional
# parameter of `ArgumentParser` is `prog`, so passing the sentence
# positionally would print it in the usage line in place of the script name.
parser = argparse.ArgumentParser(
    description="Serialize a VINVL feature file to LMDB."
)
parser.add_argument(
    "-d", "--data-root", default="/data/VQA/features/train/inference/vinvl_vg_x152c4",
    help="Root directory of the features to serialize. The LMDB file is "
         "written to the `lmdb/` sub-directory of this path.",
)
parser.add_argument(
    "-b", "--batch-size", type=int, default=1024,
    help="Batch size to process and serialize data. Set as per CPU memory.",
)
parser.add_argument(
    "-j", "--cpu-workers", type=int, default=32,
    help="Number of CPU workers for data loading.",
)


def collate_fn(instances: List[Any]):
    r"""
    Identity collation for the data loader.

    The list of instances is handed back untouched (the very same list
    object), so each batch is a plain list of the dataset's items instead of
    being merged into a single structure.
    """
    batch = instances
    return batch


if __name__ == "__main__":

    _A = parser.parse_args()
    # The database is a single file (`subdir=False` below) that lives in an
    # `lmdb/` directory under the data root.
    _A.output_path = os.path.join(_A.data_root, "lmdb")
    _A.output_file = os.path.join(_A.output_path, "features.lmdb")
    os.makedirs(_A.output_path, exist_ok=True)

    # Batches arrive as plain lists of instances (see `collate_fn`), in
    # dataset order and including the final partial batch.
    dloader = DataLoader(
        VQA_pt_Reader(_A.data_root),
        batch_size=_A.batch_size,
        num_workers=_A.cpu_workers,
        shuffle=False,
        drop_last=False,
        collate_fn=collate_fn,
    )

    # Set a sufficiently large map size for LMDB (based on platform):
    # 2 TiB on Linux, 30 GB elsewhere.
    map_size = 2 * 1024 ** 4 if platform.system() == "Linux" else 30_000_000_000

    # Open the LMDB database as a context manager so that the environment is
    # closed even if loading or serialization fails part-way through.
    with lmdb.open(
        _A.output_file, map_size=map_size, subdir=False, meminit=False, map_async=True
    ) as db:
        # Serialize each instance (as a dictionary) with `pickle.dumps`. The
        # key is the instance's `img_id`, encoded as ASCII bytes.
        for batch in tqdm(dloader):
            # One write transaction per batch: it is committed when the
            # `with` body finishes and aborted if an exception is raised.
            with db.begin(write=True) as txn:
                for instance in batch:
                    txn.put(
                        instance["img_id"].encode("ascii"),
                        pickle.dumps(instance, protocol=-1),
                    )

        # Ask LMDB to flush its buffers to disk before the environment closes.
        db.sync()
