r"""
A *Reader* is a PyTorch :class:`~torch.utils.data.Dataset` which simply reads
data from disk and returns it almost as is. Readers defined here are used by
datasets in :mod:`virtex.data.datasets`.
"""
from collections import defaultdict
import glob
import json
import os
import pickle
import random
from typing import Dict, List, Tuple

import base64
import lmdb
import torch
import numpy as np
import pandas as pd
from loguru import logger
from torch.utils.data import Dataset

# Simplified type aliases for better readability.
# An image identifier is a plain integer.
ImageID = int
# Captions are represented as a list of strings.
Captions = List[str]


class VQA_pt_Reader(Dataset):
    r"""
    A reader for per-image object predictions stored in a headerless,
    tab-separated file named ``predictions.tsv``.

    Each row is read as two columns: column 0 holds the image ID and column 1
    holds a JSON-encoded list with one dict per detected object. The
    ``"feature"`` entry of every object dict is a base64 string wrapping a raw
    ``float32`` buffer.

    Parameters
    ----------
    root: str, optional (default = "data/vqa")
        Directory containing ``predictions.tsv``.
    """

    def __init__(self, root: str = "data/vqa"):
        predictions_path = os.path.join(root, "predictions.tsv")
        self.predictions = pd.read_csv(predictions_path, sep='\t', header=None)

    def __len__(self):
        return len(self.predictions)

    def add_box_position_to_feature(self, feature, boxes, h, w):
        r"""
        Append six box-geometry values to a 1-D box feature.

        Parameters
        ----------
        feature: numpy.ndarray
            Feature vector of a single box.
        boxes: sequence
            Four box coordinates; indices 0 and 2 are divided by ``w``,
            indices 1 and 3 by ``h``.
        h, w: number
            Image height and width used for normalization.

        Returns
        -------
        numpy.ndarray
            ``feature`` followed by the four normalized coordinates and the
            normalized absolute extents along both axes.
        """
        xyxydxdy = np.array([
            boxes[0] / w, boxes[1] / h,
            boxes[2] / w, boxes[3] / h,
            abs(boxes[0] - boxes[2]) / w,
            abs(boxes[1] - boxes[3]) / h
        ], dtype=np.float32)

        return np.concatenate((feature, xyxydxdy), axis=0)

    def per_obj_to_cls(self, pred):
        r"""
        Turn a list of per-object dicts into one dict of per-key lists.

        Parameters
        ----------
        pred: list of dict
            One dict per object. All dicts are expected to carry the keys of
            the first one; an unseen key raises ``KeyError``.

        Returns
        -------
        dict
            Maps every key to the list of its values in object order.
            ``"feature"`` values are base64-decoded into ``float32`` vectors
            and combined into a single array. For an empty ``pred`` the result
            only contains an empty 2-D ``float32`` ``"feature"`` array.
        """
        # An image without any predicted object has no first element to take
        # the keys from; return an empty feature matrix instead of crashing.
        if not pred:
            return {"feature": np.empty((0, 0), dtype=np.float32)}

        result = {key: [] for key in pred[0]}
        for obj in pred:
            for key, value in obj.items():
                if key == "feature":
                    value = np.frombuffer(base64.b64decode(value), np.float32)
                result[key].append(value)

        # ``np.frombuffer`` yields read-only views; building a new array also
        # gives the caller a writable copy.
        result["feature"] = np.array(result["feature"])
        return result

    def __getitem__(self, idx: int) -> Dict:
        r"""
        Return the regrouped predictions of row ``idx`` together with its
        image ID (as a string) under the ``"img_id"`` key.
        """
        # Address cells directly rather than materializing the row twice.
        img_id = str(self.predictions.iloc[idx, 0])
        pred_per_img = json.loads(self.predictions.iloc[idx, 1])

        instance = self.per_obj_to_cls(pred_per_img)
        instance["img_id"] = img_id
        return instance

class LmdbReader(Dataset):
    r"""
    A reader interface to fetch pickled datapoints from a serialized LMDB
    file. Datapoints are looked up by image ID: the ID string, encoded as
    ASCII, is the LMDB key, and the stored value is unpickled and returned
    as is.

    .. note::

        Items are indexed by string image ID, not by integer position, while
        ``len()`` reports the number of entries in the LMDB file.

    Parameters
    ----------
    lmdb_path: str
        Path to LMDB file with datapoints.
    """

    # Maximum size of the LMDB memory map: 2 * 2**40 bytes.
    _MAP_SIZE = 1099511627776 * 2

    def __init__(self, lmdb_path: str):
        self.lmdb_path = lmdb_path

        # Create an LMDB transaction right here and keep it for all reads.
        env = self._open_env()
        self.db_txn = env.begin()

        # Number of entries stored in the LMDB file.
        self.num_data = env.stat()["entries"]

    def _open_env(self):
        r"""Open the LMDB file at :attr:`lmdb_path` in read-only mode."""
        return lmdb.open(
            self.lmdb_path, subdir=False, readonly=True, lock=False,
            readahead=False, map_size=self._MAP_SIZE,
        )

    def __getstate__(self):
        r"""
        This magic method allows an object of this class to be picklable,
        useful for dataloading with multiple CPU workers. :attr:`db_txn` is
        not picklable, so it is removed from the state, and re-instantiated in
        :meth:`__setstate__`.
        """
        # Work on a copy: clearing ``db_txn`` in the live ``__dict__`` would
        # leave the object being pickled without a usable transaction.
        state = self.__dict__.copy()
        state["db_txn"] = None
        return state

    def __setstate__(self, state):
        r"""Restore attributes and open a fresh read transaction."""
        self.__dict__.update(state)
        self.db_txn = self._open_env().begin()

    def __len__(self):
        return self.num_data

    def __getitem__(self, img_id: str):
        return self.get_image_feature(img_id)

    def get_image_feature(self, img_id: str):
        r"""
        Fetch and unpickle the datapoint stored under ``img_id``.

        Parameters
        ----------
        img_id: str
            Image ID; must be ASCII-encodable as it is used as the LMDB key.

        Returns
        -------
        object
            Whatever object was pickled under this key.

        Raises
        ------
        KeyError
            If no entry exists for ``img_id``.
        """
        datapoint_pickled = self.db_txn.get(img_id.encode("ascii"))
        if datapoint_pickled is None:
            # Fail with the offending key rather than an opaque error from
            # unpickling ``None``.
            raise KeyError(f"No entry for image ID {img_id!r} in {self.lmdb_path}")

        # NOTE: unpickling can execute arbitrary code; only read LMDB files
        # that come from a trusted source.
        return pickle.loads(datapoint_pickled)

