
import logging

import numpy as np
import itertools

from copy import deepcopy

import torch
from torch.utils.data import TensorDataset

from .datasets_factory import DATASET_CLASS

# Logger for this module, named after the module itself.
LOGGER = logging.getLogger(name=__name__)


# Public module constant; not referenced by the code in this section.
DATA_HANDLER_INDEX: int = 0

class MetaHandler:
    """
    Combine several dataset handlers behind a single data-handler interface.

    One handler is built per entry of ``nsp_args.data_set`` (looked up in
    ``DATASET_CLASS``). Calls that concern a specific example are routed to the
    handler that produced it (``example.dataset_index``). Counters and limits
    are aggregated over all handlers. The remaining properties and methods are
    delegated to the handler currently selected by ``index_to_map``.

    For the moment, we assume only one shared LM head across all datasets.
    Sequence/token classification head sizes are combined across handlers as
    controlled by ``merge_cls_heads`` / ``merge_tok_heads``: column-wise max
    when merging, concatenation when stacking. The correct numbers need to be
    set in the external script (this could later be changed so that the
    number of heads is inferred).
    """

    def __init__(self, nsp_args):
        self.features = []
        self.examples = []
        self.__num_labels_cls = []  # one list of cls head sizes per handler
        self.__num_labels_tok = []  # one list of tok head sizes per handler
        self.num_datahandlers = len(nsp_args.data_set)
        self.dataset_handlers = []
        self.nsp_args = deepcopy(nsp_args)
        # nsp_args carries one file per dataset; entry i belongs to data_set[i].
        self.train_files = self.nsp_args.train_file
        self.predict_files = self.nsp_args.predict_file
        self.valid_golds = self.nsp_args.valid_gold
        mask_files = getattr(self.nsp_args, 'dataset_parallel_mask_file', None) or []
        for i, dataset in enumerate(nsp_args.data_set):
            # As before, self.nsp_args ends up holding the last dataset's files.
            self.nsp_args.train_file = self.train_files[i]
            self.nsp_args.predict_file = self.predict_files[i]
            self.nsp_args.valid_gold = self.valid_golds[i]
            # Give every handler its own snapshot of the arguments. The loop
            # keeps mutating self.nsp_args, so sharing it would leak later
            # datasets' file names into any handler that keeps a reference.
            new_dataset = DATASET_CLASS[dataset](deepcopy(self.nsp_args))
            new_dataset.datahandler_index = i
            # for parallel data_handler
            if hasattr(new_dataset, 'mask_file') and len(mask_files) > i:
                new_dataset.mask_file = mask_files[i]
            self.dataset_handlers.append(new_dataset)
            self.__num_labels_cls.append(new_dataset.num_labels_cls)
            self.__num_labels_tok.append(new_dataset.num_labels_tok)

        # Index into self.dataset_handlers used by the delegating members;
        # the per-example methods reassign it from example.dataset_index.
        self.index_to_map = 0

        self._train_dataloader = None
        self._eval_dataloader = None

        self.merge_cls_heads = True  # if true, merge cls labels across data_handlers, else stack heads
        self.merge_tok_heads = True  # if true, merge tok labels across data_handlers, else stack heads
        # need to be adapted after merging heads!
        self.plus_generation = nsp_args.plus_generation
        self.plus_classify_tokens = nsp_args.plus_classify_tokens
        self.plus_classify_sequence = nsp_args.plus_classify_sequence
        self.map_to_cls_index = 0  # index if multiple cls heads exist
        self.map_to_tok_index = 0  # index if multiple tok heads exist
        self.map_to_gen_index = 0  # index if multiple gen heads exist
        self.datahandler_index = -1

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return self.__class__.__name__

    @property
    def data_handler(self):
        """The handler at position ``self.index_to_map``."""
        return self.dataset_handlers[self.index_to_map]

    @property
    def tokenizer(self):
        return self.data_handler.tokenizer

    @tokenizer.setter
    def tokenizer(self, tokenizer):
        """Set one tokenizer for all handlers, or one per handler if a list is given."""
        if isinstance(tokenizer, list):
            assert len(tokenizer) == len(self.dataset_handlers)
            for data_handler, tok in zip(self.dataset_handlers, tokenizer):
                data_handler.tokenizer = tok
        else:
            for data_handler in self.dataset_handlers:
                data_handler.tokenizer = tokenizer

    @property
    def truncate_end(self):
        return self.data_handler.truncate_end

    @truncate_end.setter
    def truncate_end(self, truncate_end):
        for data_handler in self.dataset_handlers:
            data_handler.truncate_end = truncate_end

    @property
    def truncation_strategy(self):
        return self.data_handler.truncation_strategy

    @truncation_strategy.setter
    def truncation_strategy(self, truncation_strategy):
        for data_handler in self.dataset_handlers:
            data_handler.truncation_strategy = truncation_strategy

    @property
    def max_seq_length(self):
        """Largest ``max_seq_length`` over all handlers."""
        return max(data_handler.max_seq_length for data_handler in self.dataset_handlers)

    @max_seq_length.setter
    def max_seq_length(self, max_seq_length):
        for data_handler in self.dataset_handlers:
            data_handler.max_seq_length = max_seq_length

    @property
    def train_dataloader(self):
        return self._train_dataloader

    @train_dataloader.setter
    def train_dataloader(self, dataloader):
        self._train_dataloader = dataloader

    @property
    def eval_dataloader(self):
        return self._eval_dataloader

    @eval_dataloader.setter
    def eval_dataloader(self, dataloader):
        self._eval_dataloader = dataloader

    @property
    def num_invalid(self):
        return sum(data_handler.num_invalid for data_handler in self.dataset_handlers)

    @property
    def trunc_part_a(self):
        return sum(data_handler.trunc_part_a for data_handler in self.dataset_handlers)

    @property
    def trunc_part_b(self):
        return sum(data_handler.trunc_part_b for data_handler in self.dataset_handlers)

    @property
    def max_a(self):
        return max(data_handler.max_a for data_handler in self.dataset_handlers)

    @property
    def max_b(self):
        return max(data_handler.max_b for data_handler in self.dataset_handlers)

    @staticmethod
    def _combine_head_sizes(per_handler_sizes, merge):
        """
        Combine per-handler head sizes into the head sizes of the model.

        Example: handler A declares [3, 2] and handler B declares [2, 4].

        - merge=True  -> [3, 4]: column-wise max; head j is shared by all handlers.
        - merge=False -> [3, 2, 2, 4]: heads stacked in handler order.

        When merging, a handler that declares fewer heads simply does not
        contribute to the trailing columns. No handlers gives [].
        """
        if merge:
            gap = object()  # marks a head position a handler does not have
            return [max(size for size in column if size is not gap)
                    for column in itertools.zip_longest(*per_handler_sizes, fillvalue=gap)]
        return list(itertools.chain.from_iterable(per_handler_sizes))

    @property
    def num_labels_cls(self):
        """Sequence classification head sizes (see ``_combine_head_sizes``)."""
        return self._combine_head_sizes(self.__num_labels_cls, self.merge_cls_heads)

    @property
    def num_labels_tok(self):
        """Token classification head sizes (see ``_combine_head_sizes``)."""
        return self._combine_head_sizes(self.__num_labels_tok, self.merge_tok_heads)

    @property
    def write_predictions(self):
        return self.data_handler.write_predictions

    @property
    def write_tok_predictions(self):
        return self.data_handler.write_tok_predictions

    @property
    def write_cls_predictions(self):
        return self.data_handler.write_cls_predictions

    @property
    def write_eval(self):
        return self.data_handler.write_eval

    def read_examples(self, is_training=False):
        """
        Re-read the examples of every handler and concatenate them, in handler
        order, into ``self.examples``.
        """
        # Start from scratch, like data_handler.initialize() does for each
        # handler; otherwise repeated calls would append to stale examples.
        self.examples = []
        for i, data_handler in enumerate(self.dataset_handlers):
            assert data_handler.tokenizer is not None
            assert i == data_handler.datahandler_index
            data_handler.initialize()  # reset list, counters, etc.
            data_handler.read_examples(is_training)
            self.examples.extend(data_handler.examples)
        LOGGER.info("Total number of examples: %s", len(self.examples))

    def get_token_classification_ids(self, current_example, input_ids):
        self.index_to_map = current_example.dataset_index
        return self.data_handler.get_token_classification_ids(current_example, input_ids)

    def possible_mask_locations(self, part_a, part_b, is_training=False, example=None):
        # Without an example, the currently selected handler is used.
        if example is not None:
            self.index_to_map = example.dataset_index
        return self.data_handler.possible_mask_locations(
            part_a, part_b, is_training=is_training, example=example)

    def validate_example(self, example, is_training=False):
        self.index_to_map = example.dataset_index
        return self.data_handler.validate_example(example, is_training=is_training)

    def apply_tokenization(self, example, is_training=False):
        self.index_to_map = example.dataset_index
        return self.data_handler.apply_tokenization(example, is_training=is_training)

    def arrange_classify_output(self, current_example, max_classify_index):
        """Delegate to the example's handler; None if it does not implement this."""
        self.index_to_map = current_example.dataset_index  # check belonging data_handler
        try:
            return self.data_handler.arrange_classify_output(current_example, max_classify_index)
        except NotImplementedError as err:  # can happen
            LOGGER.debug('%s Skip.\n%s', err, current_example)
            return None

    def arrange_token_classify_output(self, current_example, classification_tokens, input_ids):
        """Delegate to the example's handler; None if it does not implement this."""
        self.index_to_map = current_example.dataset_index  # check belonging data_handler
        try:
            return self.data_handler.arrange_token_classify_output(
                current_example, classification_tokens, input_ids)
        except NotImplementedError as err:  # can happen
            LOGGER.debug('%s Skip.\n%s', err, current_example)
            return None

    def arrange_generated_output(self, current_example, generated_ids, input_ids):
        """Delegate to the example's handler; None if it does not implement this."""
        self.index_to_map = current_example.dataset_index  # check belonging data_handler
        try:
            return self.data_handler.arrange_generated_output(current_example, generated_ids, input_ids)
        except NotImplementedError as err:  # can happen
            LOGGER.debug('%s Skip.\n%s', err, current_example)
            return None

    def combine_classification_tokens(self, classification_tokens):
        raise NotImplementedError('combine_classification_tokens() is not implemented.')

    # NOTE(review): the methods below use whichever handler index_to_map
    # currently points to (i.e. the one selected by the last routed example).

    def evaluate(self, output_prediction_file, mode):
        """Delegate to the current handler; None if it does not implement this."""
        try:
            return self.data_handler.evaluate(output_prediction_file, mode)
        except NotImplementedError as err:  # shouldn't happen....
            LOGGER.debug('%s Skip. %s', err, output_prediction_file)
            return None

    def select_deciding_score(self, results_collection):
        """Delegate to the current handler; None if it does not implement this."""
        try:
            return self.data_handler.select_deciding_score(results_collection)
        except NotImplementedError as err:
            LOGGER.debug('%s Skip.', err)
            return None

    def select_scores_to_plot(self, results_collection):
        """Delegate to the current handler; None if it does not implement this."""
        try:
            return self.data_handler.select_scores_to_plot(results_collection)
        except NotImplementedError as err:
            LOGGER.debug('%s Skip.', err)
            return None

    def create_tensor_dataset(self):
        """
        Convert ``self.features`` into a TensorDataset. The features are filled
        elsewhere, e.g. via convert_examples_to_features of a subclass instance
        of :py:class:Masking.

        :return: a TensorDataset whose tensors are, in order: input_ids,
            input_mask, segment_ids, gen_label_ids, classify_id_cls,
            classify_id_tokens and the example index (0 .. len(features) - 1)
        """
        # NOTE(review): classification labels are stacked exactly as each
        # feature provides them. This fits merged heads (merge_cls_heads /
        # merge_tok_heads True). Stacked heads (False) would also need every
        # feature's labels moved into its own handler's head slots (padding
        # the other slots with -1), which is not implemented here.
        # pylint: disable=not-callable
        # pylint: disable=no-member
        LOGGER.debug("self.features: %s", self.features)
        all_input_ids = torch.tensor([f.input_ids for f in self.features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in self.features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in self.features], dtype=torch.long)
        all_gen_label_ids = torch.tensor([f.gen_label_ids for f in self.features], dtype=torch.long)
        all_classify_ids_cls = torch.tensor([f.classify_id_cls for f in self.features])
        all_classify_ids_tokens = torch.tensor([f.classify_id_tokens for f in self.features])
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        return TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_gen_label_ids,
                             all_classify_ids_cls, all_classify_ids_tokens, all_example_index)