import sys
from os import path

from torch.utils.data import Dataset

sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
import argparse
import json
import os
import pickle
import platform
from typing import Any, List
import shutil
import lmdb
from tqdm import tqdm
from torch.utils.data import DataLoader

# fmt: off
# Command-line interface of the script. Hyphenated long options are exposed as
# underscored attributes on the parsed namespace (``--data-root`` ->
# ``data_root``, and so on).
parser = argparse.ArgumentParser(prog="Serialize a VINVL feature file to LMDB.")

# Root directory used to build every input/output path of the script.
parser.add_argument(
    "-d",
    "--data-root",
    default="/data/VQA/features/",
    help="Path to the root directory of COCO dataset.",
)

# Number of records handed over per DataLoader batch.
parser.add_argument(
    "-b",
    "--batch-size",
    type=int,
    default=1024,
    help="Batch size to process and serialize data. Set as per CPU memory.",
)

# Number of DataLoader worker processes.
parser.add_argument(
    "-j",
    "--cpu-workers",
    type=int,
    default=32,
    help="Number of CPU workers for data loading.",
)

class LmdbReader(Dataset):
    r"""
    A map-style dataset that looks up raw records in a serialized LMDB file by
    image id.

    Item ``i`` is the record stored under the key ``str(img_ids[i])`` encoded
    as ASCII. The stored value is returned untouched as ``bytes`` (it is not
    decoded here; presumably it is a pickled datapoint -- confirm against the
    code that wrote the file). A key that is absent from the file yields
    ``None`` as the value.

    .. note::

        Objects of this class can be pickled (needed when a data loader ships
        the dataset to worker processes): the LMDB transaction is dropped from
        the pickled state and re-opened when the object is unpickled.

    Parameters
    ----------
    lmdb_path: str
        Path to the LMDB file with datapoints (opened with ``subdir=False``,
        i.e. a single file rather than a directory).
    img_ids: list
        Image ids to serve, in the order they should be indexed. Each id must
        be convertible with ``str()`` to an ASCII string.

    Attributes
    ----------
    num_data: int
        Number of entries in the LMDB file (independent of ``img_ids``).
    img_id: list
        The ``img_ids`` list passed to the constructor.
    """

    def __init__(self, lmdb_path: str, img_ids: list):
        self.lmdb_path = lmdb_path

        # Create an LMDB transaction right here; the transaction object is
        # what keeps the environment referenced for the lifetime of the reader.
        env = self._open_read_txn()

        # Total number of records in the file. Kept for callers that want it;
        # NOT used as the dataset length (see ``__len__``).
        self.num_data = env.stat()["entries"]

        self.img_id = img_ids

    def _open_read_txn(self):
        r"""Open the LMDB file read-only, start a transaction in
        :attr:`db_txn` and return the environment."""
        # fmt: off
        env = lmdb.open(
            self.lmdb_path, subdir=False, readonly=True, lock=False,
            readahead=False, map_size=1099511627776 * 2,
        )
        # fmt: on
        self.db_txn = env.begin()
        return env

    def __getstate__(self):
        r"""
        This magic method allows an object of this class to be picklable,
        useful for dataloading with multiple CPU workers. :attr:`db_txn` is not
        picklable, so we remove it from state, and re-instantiate it in
        :meth:`__setstate__`.
        """
        # Work on a shallow copy: blanking ``db_txn`` directly in
        # ``self.__dict__`` would also destroy the transaction of the object
        # being pickled, leaving it unusable afterwards.
        state = self.__dict__.copy()
        state["db_txn"] = None
        return state

    def __setstate__(self, state):
        r"""Restore the pickled attributes and open a fresh LMDB transaction."""
        self.__dict__.update(state)
        self._open_read_txn()

    def __len__(self):
        # Items are addressed through ``img_id``, so that list defines the
        # valid index range. Using the LMDB entry count here would raise
        # ``IndexError`` (file has more entries than ids) or silently skip ids
        # (file has fewer entries).
        return len(self.img_id)

    def __getitem__(self, i: int):
        return self.get_image_feature(i)

    def get_image_feature(self, i: int):
        r"""
        Fetch the record of the ``i``-th image id.

        Returns a dict with the image id (``"img_id"``), the LMDB key built
        from it (``"key"``, bytes) and the raw stored value (``"value"``,
        bytes, or ``None`` when the key is not present in the file).
        """
        image_id = self.img_id[i]
        encoded_id = str(image_id).encode("ascii")
        datapoint_pickled = self.db_txn.get(encoded_id)
        return {"img_id": image_id, "key": encoded_id, "value": datapoint_pickled}

def collate_fn(instances: List[Any]):
    r"""
    Identity collate function for a data loader.

    The batch is handed back exactly as received (the very same list object),
    so no stacking or conversion of the individual instances takes place.
    """
    return instances


def _load_image_ids(question_file: str) -> List[Any]:
    r"""Return the sorted, de-duplicated ``image_id`` values found under the
    ``"questions"`` key of a question JSON file."""
    with open(question_file) as f:
        questions = json.load(f)["questions"]
    return sorted({question["image_id"] for question in questions})


if __name__ == "__main__":
    # Re-split LMDB feature files according to the "oscar split" question
    # files:
    #   1. copy the original train LMDB to ``train/features_oscar_split.lmdb``;
    #   2. walk over every image of the original val LMDB and write its record
    #      either into that train copy or into a new
    #      ``val/features_oscar_split.lmdb``, depending on which oscar-split
    #      question file references the image.

    _A = parser.parse_args()

    train_dir = os.path.join(_A.data_root, "train")
    val_dir = os.path.join(_A.data_root, "val")

    train_orig_file = os.path.join(train_dir, "features.lmdb")
    valid_orig_file = os.path.join(val_dir, "features.lmdb")
    train_output_file = os.path.join(train_dir, "features_oscar_split.lmdb")
    valid_output_file = os.path.join(val_dir, "features_oscar_split.lmdb")

    os.makedirs(train_dir, exist_ok=True)

    train_os_img_ids = _load_image_ids(
        os.path.join(train_dir, "train_questions_oscar_split.json")
    )
    valid_os_img_ids = _load_image_ids(
        os.path.join(val_dir, "val_questions_oscar_split.json")
    )
    valid_orign_img_ids = _load_image_ids(
        os.path.join(val_dir, "v2_OpenEnded_mscoco_val2014_questions.json")
    )

    # Sets give O(1) membership tests in the routing loop below (the id lists
    # would be scanned linearly for every single instance otherwise).
    train_os_id_set = set(train_os_img_ids)
    valid_os_id_set = set(valid_os_img_ids)

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not train_os_id_set.isdisjoint(valid_os_id_set):
        raise ValueError("Some val img is in the train set.")

    # Set a sufficiently large map size for LMDB (based on platform).
    map_size = 1099511627776 * 2 if platform.system() == "Linux" else 30000000000

    # ------------------------ copy train orig feature ------------------------ #
    # Start the train output from a full copy of the original train features;
    # the records moved over from val are added to it below.
    if os.path.isfile(train_output_file):
        print("train file copy already exist")
    else:
        shutil.copyfile(train_orig_file, train_output_file)

        # With ``subdir=False`` LMDB keeps its lock file next to the data file
        # as "<path>-lock" and creates it on open when missing, so only copy
        # it when the source actually has one.
        train_orig_lock = train_orig_file + "-lock"
        if os.path.isfile(train_orig_lock):
            shutil.copyfile(train_orig_lock, train_output_file + "-lock")
        print("Copied train feature file")

    # ------------------------ valid orig feature iteration ------------------------ #
    os.makedirs(val_dir, exist_ok=True)

    # Both environments are context managers: they are closed even when the
    # loop below fails.
    with lmdb.open(
        train_output_file, map_size=map_size, subdir=False, meminit=False,
        map_async=True,
    ) as db_train, lmdb.open(
        valid_output_file, map_size=map_size, subdir=False, meminit=False,
        map_async=True,
    ) as db_val:

        # Read every image of the original val file (raw key/value pairs).
        dloader = DataLoader(
            LmdbReader(valid_orig_file, valid_orign_img_ids),
            batch_size=_A.batch_size,
            num_workers=_A.cpu_workers,
            shuffle=False,
            drop_last=False,
            collate_fn=collate_fn,
        )

        for batch in tqdm(dloader):
            # One write transaction per output file and batch; each commits
            # when the ``with`` body finishes and aborts if it raises.
            with db_train.begin(write=True) as txn_train, \
                    db_val.begin(write=True) as txn_valid:
                for instance in batch:
                    img_id = instance["img_id"]
                    if img_id in train_os_id_set:
                        txn_train.put(instance["key"], instance["value"])
                    elif img_id in valid_os_id_set:
                        txn_valid.put(instance["key"], instance["value"])
                    else:
                        # The original built this exception without raising
                        # it, silently dropping the image from both outputs.
                        raise ValueError(
                            f"No matching dataset split : img id ({img_id})"
                        )

        # Environments were opened with ``map_async=True``; flush before close.
        db_train.sync()
        db_val.sync()
