import numpy as np
import json
import array
import collections
import torch.utils.data
import lmdb
from unilm.utils import deserialize_str


class PretrainDatasetForUnilm(torch.utils.data.Dataset):
    """Map-style dataset over a binary file of fixed-size feature records.

    The whole data file is read into memory as one ``bytes`` object; each item
    is decoded on demand from its slice by a ``BinaryDataset`` whose layout is
    loaded from ``feature_format_file``.

    Args:
        feature_data_file: Path to the binary file holding the records.
        feature_format_file: Path to the JSON layout file understood by
            ``BinaryDataset.recover_from_config_file``.
        list_style: Passed through to ``BinaryDataset.recover_from_bytes``;
            ``True`` yields a list of arrays, ``False`` an ordered dict keyed
            by feature name.
        data_type: ``array`` module type code of the stored elements.

    Raises:
        ValueError: If the layout describes an empty record, or the data file
            size is not a whole multiple of the record size.
    """

    def __init__(self, feature_data_file, feature_format_file, list_style=True, data_type='h'):
        with open(feature_data_file, mode='rb') as reader:
            self.binary_features = reader.read()
        self.constructor = BinaryDataset.recover_from_config_file(
            feature_format_file, data_type=data_type)
        # Bytes per stored element for this type code.
        self.data_type_size = array.array(data_type).itemsize
        # Bytes per record (constructor.instance_size counts elements).
        self.instance_size = self.constructor.instance_size * self.data_type_size
        if self.instance_size <= 0:
            raise ValueError(
                "format file %r describes an empty record" % (feature_format_file,))
        if len(self.binary_features) % self.instance_size != 0:
            raise ValueError(
                "size of %r (%d bytes) is not a multiple of the record size (%d bytes)"
                % (feature_data_file, len(self.binary_features), self.instance_size))
        self.list_style = list_style

    def __len__(self):
        """Return the number of complete records currently held."""
        return len(self.binary_features) // self.instance_size

    def __getitem__(self, idx):
        """Decode and return record ``idx``; negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        position = idx + length if idx < 0 else idx
        # Raising IndexError (instead of decoding an empty slice to None) is
        # what lets plain ``for item in dataset`` loops terminate.
        if position < 0 or position >= length:
            raise IndexError(
                "index %d is out of range for dataset of size %d" % (idx, length))
        offset = position * self.instance_size
        bytes_data = self.binary_features[offset:offset + self.instance_size]
        return self.constructor.recover_from_bytes(bytes_data, list_style=self.list_style)

    def drop_features(self, num_keep):
        """Truncate the in-memory data to its first ``num_keep`` records."""
        offset = num_keep * self.instance_size
        self.binary_features = self.binary_features[:offset]

    def get_keys(self):
        """Return the feature names in the order defined by the layout."""
        return list(self.constructor.format_offset)


class DocDB(object):
    """Read-only accessor for examples stored as JSON in an LMDB database.

    The example count is read from the ``__size__`` entry when the database is
    opened; individual examples are looked up under ``table_<index>`` keys and
    decoded with ``deserialize_str`` followed by ``json.loads``.

    Args:
        db_path: Path passed to ``lmdb.open``.

    Raises:
        KeyError: If the database has no ``__size__`` entry.
    """

    def __init__(self, db_path):
        self.db_path = db_path
        self.env = lmdb.open(db_path, readonly=True, lock=False, readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            raw_size = txn.get(b'__size__')
            # txn.get yields None for an absent key; fail here with a clear
            # message instead of inside deserialize_str.
            if raw_size is None:
                raise KeyError("LMDB database %r has no '__size__' entry" % (db_path,))
            self.size = int(deserialize_str(raw_size))

    def __getitem__(self, example_index):
        """Return the decoded example stored under ``table_<example_index>``.

        Raises:
            IndexError: If no entry exists for ``example_index``.
        """
        key = b"table_%d" % example_index
        with self.env.begin(write=False) as txn:
            raw_example = txn.get(key)
            if raw_example is None:
                raise IndexError("no example stored under key %r" % (key,))
            example = json.loads(deserialize_str(raw_example))
        return example

    def __len__(self):
        """Return the example count recorded in the database."""
        return self.size


class BinaryDataset(object):
    """Fixed-layout converter between named numeric features and raw bytes.

    A record is the concatenation of every feature, flattened, stored as
    elements of a single ``array`` type code. The layout is described by three
    ordered dicts keyed by feature name:

    * ``format_offset`` - index of the feature's first element in the record
    * ``format_size``   - number of elements the feature occupies
    * ``format_shape``  - shape used to restore the feature when decoding

    ``instance_size`` is the record length in elements (not bytes).

    Args:
        data_type: ``array`` module type code used for every element.
    """

    def __init__(self, data_type='h'):
        self.format_offset = collections.OrderedDict()
        self.format_shape = collections.OrderedDict()
        self.format_size = collections.OrderedDict()
        # True once the layout is fixed, either inferred by the first call to
        # binary() or loaded by recover_from_config_file().
        self.format_flag = False
        self.instance_size = 0
        self.data_type_size = array.array(data_type).itemsize
        self.data_type = data_type
        print("  *********  Data type = %s ********* " % self.data_type)
        print("  *********  Data type size = %d ********* " % self.data_type_size)

    def _create_rules(self, features):
        """Infer the record layout from one example ``features`` dict.

        Features are laid out back to back in the dict's iteration order.

        Raises:
            TypeError: If ``features`` is not a dict.
            RuntimeError: If the layout has already been fixed.
        """
        if not isinstance(features, dict):
            raise TypeError("features must be a dict, got %s" % type(features).__name__)
        if self.format_flag:
            raise RuntimeError("record layout has already been created")
        offset = 0
        for name, feature in features.items():
            v = np.array(feature)
            self.format_shape[name] = v.shape
            self.format_size[name] = v.size
            self.format_offset[name] = offset
            offset += v.size
        self.instance_size = offset

    @classmethod
    def recover_from_config_file(cls, config_file, data_type='h'):
        """Build an instance whose layout is read from a JSON config file.

        The file must contain ``"offset"``, ``"shape"`` and ``"size"`` mappings
        keyed by feature name, as written by ``save_config``. Features are
        registered in ascending offset order.

        Returns:
            A new instance with its layout fixed to the one in the file.
        """
        recover = cls(data_type=data_type)
        with open(config_file, mode="r", encoding="utf-8") as f_in:
            config = json.load(f_in)
        for key, offset in sorted(config["offset"].items(), key=lambda x: x[1]):
            recover.format_offset[key] = offset
            recover.format_shape[key] = config["shape"][key]
            recover.format_size[key] = config["size"][key]
            recover.instance_size += config["size"][key]
        # The layout is now authoritative. Without this, the first binary()
        # call would re-infer offsets from the caller's dict order and leave
        # them inconsistent with the key order loaded above.
        recover.format_flag = True
        return recover

    def save_config(self, config_file):
        """Write the current layout to ``config_file`` as indented JSON."""
        to_save = {
            "offset": self.format_offset,
            "shape": self.format_shape,
            "size": self.format_size,
        }
        with open(config_file, mode="w", encoding="utf-8") as writer:
            writer.write(json.dumps(to_save, indent=2))

    def recover_from_bytes(self, bytes_data, list_style=True):
        """Decode one record from ``bytes_data``.

        Args:
            bytes_data: Exactly one record's worth of bytes.
            list_style: If true return a list of arrays in layout order,
                otherwise an ordered dict mapping feature name to array.

        Returns:
            The decoded features, or ``None`` when ``bytes_data`` does not have
            the exact record length. A feature whose stored shape is empty is
            returned as a flat array rather than reshaped.
        """
        if len(bytes_data) != self.data_type_size * self.instance_size:
            return None
        data = array.array(self.data_type)
        data.frombytes(bytes_data)
        assert len(data) == self.instance_size

        rets = [] if list_style else collections.OrderedDict()
        for name in self.format_offset:
            offset = self.format_offset[name]
            size = self.format_size[name]
            shape = self.format_shape[name]
            f = np.array(data[offset:offset + size])
            if len(shape) > 0:
                f = f.reshape(shape)
            if list_style:
                rets.append(f)
            else:
                rets[name] = f

        return rets

    def recover_from_file(self, file_name, list_style=True):
        """Decode every complete record stored in ``file_name``.

        A trailing partial record is skipped. In dict style each record also
        receives a ``"feature_id"`` entry holding its position in the result.

        Raises:
            ValueError: If the layout is empty but the file is not, since the
                data could never be split into records.
        """
        num_bytes = self.instance_size * self.data_type_size
        instances = []

        with open(file_name, mode='rb') as f_in:
            all_data = f_in.read()

        if all_data and num_bytes == 0:
            raise ValueError("cannot decode %r with an empty record layout" % (file_name,))

        for offset in range(0, len(all_data), num_bytes or 1):
            bytes_data = all_data[offset:offset + num_bytes]
            inst = self.recover_from_bytes(bytes_data, list_style)
            if inst is not None:
                if not list_style:
                    inst["feature_id"] = len(instances)
                instances.append(inst)

        return instances

    def binary(self, features):
        """Serialize one ``features`` dict into a record's bytes.

        If no layout has been fixed yet it is inferred from ``features``;
        otherwise ``features`` must match the existing layout.

        Raises:
            RuntimeError: If a feature's element count, or the record total,
                does not match the layout.
        """
        if not self.format_flag:
            self._create_rules(features)
            self.format_flag = True

        data = array.array(self.data_type)

        for name in self.format_offset:
            v = np.array(features[name])
            expected = self.format_size[name]
            if v.size != expected:
                raise RuntimeError(
                    "feature %r has %d elements, layout expects %d" % (name, v.size, expected))
            # tolist() yields plain Python numbers, so one bulk extend replaces
            # the per-element append.
            data.extend(v.reshape(-1).tolist())

        if len(data) != self.instance_size:
            raise RuntimeError(
                "record has %d elements, layout expects %d" % (len(data), self.instance_size))

        return data.tobytes()
