import os
import pickle

import h5py

from coli.torch_span.parser import SpanParser
from coli.data_utils.dataset import SentenceBucketsBase
from coli.basic_tools.logger import default_logger


class TensorCacheManagerBase(object):
    """Append-only tensor cache stored in one resizable HDF5 dataset.

    Rows are appended along axis 0 of a dataset inside the HDF5 file at
    ``path``.  The write position (``dataset_pointer``) and the
    ``sentence_map`` dictionary are persisted in a pickle sidecar file
    named ``path + ".meta.pkl"`` by :meth:`save`.

    Instances can be used as context managers; leaving the ``with`` block
    closes the HDF5 file (it does not call :meth:`save`).
    """

    def __init__(self, path, dataset_name, rest_dims, mode="a"):
        """Open (or create) the cache file and its dataset.

        :param path: path of the HDF5 file.
        :param dataset_name: name of the dataset inside the file.
        :param rest_dims: iterable with the sizes of every axis after the
            first (growable) one.
        :param mode: mode string forwarded to ``h5py.File``.
        """
        self.path = path
        self.f = h5py.File(path, mode)
        try:
            self.dataset = self.f.get(dataset_name)
            if self.dataset is None:
                # Fresh dataset: start with 100 rows, unlimited along axis 0.
                self.dataset = self.f.create_dataset(
                    dataset_name, shape=(100, *rest_dims),
                    maxshape=(None, *rest_dims),
                    dtype="f")
                self.dataset_pointer = 0
                self.sentence_map = {}
            else:
                self.dataset_pointer, self.sentence_map = self._load_meta()
        except Exception:
            # Do not leak the open HDF5 handle when setup fails half-way.
            self.f.close()
            raise

    def _load_meta(self):
        """Return ``(dataset_pointer, sentence_map)`` read from the sidecar file.

        Falls back to ``(0, {})`` when the sidecar is missing or unreadable,
        so an existing dataset without usable metadata is overwritten from
        the start.
        """
        try:
            # SECURITY: pickle can execute arbitrary code while loading;
            # only open cache files that come from a trusted source.
            with open(self.path + ".meta.pkl", "rb") as f:
                pointer, sentence_map = pickle.load(f)
        except Exception:
            # Best-effort by design, but unlike a bare ``except`` this lets
            # KeyboardInterrupt / SystemExit propagate.
            return 0, {}
        return pointer, sentence_map

    def add_tensor(self, tensor):
        """Append ``tensor`` to the dataset and return the rows it occupies.

        :param tensor: array-like exposing ``shape``; it is assigned to a
            slice of the dataset, so its trailing axes presumably have to
            match ``rest_dims`` (not checked here).
        :return: ``(start, end)`` half-open row range of the stored data.
        """
        n_rows = tensor.shape[0]
        start = self.dataset_pointer
        if n_rows == 0:
            # Nothing to store: report an empty range, leave the file alone.
            return start, start
        end = start + n_rows
        if self.dataset.shape[0] < end:
            # Grow axis 0 just enough to hold the new rows.
            self.dataset.resize(end, axis=0)
        self.dataset[start:end] = tensor
        self.dataset_pointer = end
        return start, end

    def save(self):
        """Flush the HDF5 file and write the pickle sidecar metadata."""
        self.f.flush()
        with open(self.path + ".meta.pkl", "wb") as f:
            pickle.dump((self.dataset_pointer, self.sentence_map), f)

    def close(self):
        """Close the underlying HDF5 file (metadata is *not* saved)."""
        f = getattr(self, "f", None)
        if f is not None:
            f.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


class SpanCacheManager(TensorCacheManagerBase):
    """Cache of per-sentence span feature tensors, addressed by sentence ID.

    ``sentence_map`` maps a sentence ID to the ``(start, end)`` row range of
    its span features inside the dataset.
    """

    def __init__(self, path, mode="a", dataset_name="bilm_cache",
                 span_dim=1024
                 ):
        """Open the cache at ``path`` with rows of width ``span_dim``.

        :param path: path of the HDF5 cache file.
        :param mode: file mode forwarded to the base class.
        :param dataset_name: dataset name inside the HDF5 file.
        :param span_dim: size of the second axis of the dataset.
        """
        super(SpanCacheManager, self).__init__(path, dataset_name, (span_dim,), mode)

    def add_sentences(self,
                      parser: SpanParser,
                      sentence_bucket: SentenceBucketsBase,
                      train_mode,
                      logger=default_logger
                      ):
        """Run ``parser`` over ``sentence_bucket`` and cache its span features.

        Each item yielded by ``parser.get_parsed(..., return_span_features=True)``
        is stored with :meth:`add_tensor`, and its row range is recorded
        under the sentence's ``extra["ID"]``.  A later sentence with the
        same ID replaces the earlier mapping entry.

        :param parser: parser providing ``get_parsed``.
        :param sentence_bucket: sentences handed to ``parser.get_parsed``.
        :param train_mode: forwarded as ``train_mode`` to ``get_parsed``.
        :param logger: receives a progress message every 500 sentences.
        """
        parsed = parser.get_parsed(
            sentence_bucket, return_span_features=True, train_mode=train_mode)
        for total_processed, (sent_feature, span_features) in enumerate(parsed, 1):
            sent_id = sent_feature.original_obj.extra["ID"]
            self.sentence_map[sent_id] = self.add_tensor(span_features)
            if total_processed % 500 == 0:
                logger.info("Span Cache: {} processed".format(total_processed))

    @classmethod
    def generate_cache_file(cls, parser,
                            cache_path, train_bucket, dev_buckets,
                            training_features=True
                            ):
        """Build a complete cache at ``cache_path`` from the given buckets.

        The cache is written to ``cache_path + ".tmp"`` first and only moved
        into place (together with its ``.meta.pkl`` sidecar) after all
        sentences were processed and saved, so an interrupted run does not
        leave a partial file at ``cache_path``.

        :param parser: parser used to compute span features; its
            ``hparams.d_model`` gives the row width.
        :param cache_path: final path of the HDF5 cache file.
        :param train_bucket: training sentences, parsed with
            ``train_mode=training_features``.
        :param dev_buckets: iterable of dev buckets parsed with
            ``train_mode=False``, or ``None``.
        :param training_features: ``train_mode`` used for ``train_bucket``.
        """
        tmp_path = cache_path + ".tmp"
        cache_manager = SpanCacheManager(tmp_path,
                                         span_dim=parser.hparams.d_model)
        try:
            cache_manager.add_sentences(parser, train_bucket, training_features)
            if dev_buckets is not None:
                for dev_bucket in dev_buckets:
                    cache_manager.add_sentences(parser, dev_bucket, False)
            cache_manager.save()
        finally:
            # Release the HDF5 handle before moving the file: it would
            # otherwise leak, and an open file cannot be renamed on Windows.
            cache_manager.f.close()
        # os.replace overwrites an existing destination on every platform.
        os.replace(tmp_path, cache_path)
        os.replace(tmp_path + ".meta.pkl", cache_path + ".meta.pkl")

    def get(self, sent_id):
        """Return the cached rows for ``sent_id``, or ``None`` if unknown."""
        start_end_pos = self.sentence_map.get(sent_id)
        if start_end_pos is None:
            return None
        start_pos, end_pos = start_end_pos
        return self.dataset[start_pos:end_pos]

    def __contains__(self, item):
        """Return whether ``item`` is a cached sentence ID."""
        return item in self.sentence_map
