import logging
import os
import re

import numpy as np
import torch

from few_shot_ner.alphabet import Alphabet


# Module-level patterns and padding symbols.
# DIGIT_RE matches any single digit; token_normalizer replaces each match with "0".
DIGIT_RE = re.compile(r"\d")
# One or more literal space characters (tabs/newlines are not matched); unused in the code shown here.
SPACE_RE = re.compile("[ ]+")
# Any single character that is not an ASCII letter, digit, underscore or hyphen; unused in the code shown here.
PUNC_RE = re.compile("[^0-9a-zA-Z_-]")
# Padding token string; unused in the code shown here.
PAD = "_PAD"
# Slot label used to pad slot-id sequences; added to a freshly created NER alphabet.
PAD_SLOT = "_PAD_SLOT"


def token_normalizer(token, lower=True, normalize_digits=True):
    """Normalize a single token string.

    :param token: raw token text.
    :param lower: when True, lower-case the token.
    :param normalize_digits: when True, replace every digit with "0".
    :return: the normalized token.
    """
    normalized = token.lower() if lower else token
    return DIGIT_RE.sub("0", normalized) if normalize_digits else normalized


def tsv_tokenizer(tsv):
    """Split one ``domain<TAB>intent<TAB>token|slot token|slot ...`` line.

    :param tsv: a tab-separated line with at least three columns; the third
        column holds space-separated ``token|slot`` items.
    :return: ``(tokens, slots, intent, domain)``; tokens are lower-cased and
        every field is whitespace-stripped.
    :raises ValueError: if there are fewer than three columns, or if a
        ``token|slot`` item does not contain exactly one ``|``.
    """
    columns = tsv.split("\t")
    if len(columns) < 3:
        raise ValueError("'{}' has less than 3 columns ({:d})".format(tsv, len(columns)))
    domain, intent = columns[0].strip(), columns[1].strip()
    tokens, slots = [], []
    for item in columns[2].strip().split(' '):
        parts = item.strip().split('|')
        if len(parts) != 2:
            raise ValueError("Utterance '{}' : '{}' has {:d} slot(s) instead of 1"
                             .format(tsv, item, len(parts) - 1))
        token, slot = parts
        tokens.append(token.strip().lower())
        slots.append(slot.strip())
    return tokens, slots, intent, domain


def pad_list(arr, target_length, padding_value):
    """Return a new list holding ``arr`` cut or padded to ``target_length`` items.

    For a non-negative ``target_length``, longer inputs are truncated by
    slicing. Shorter inputs get ``padding_value`` appended until they reach
    the target length. ``arr`` itself is never modified.
    """
    missing = target_length - len(arr)
    if missing <= 0:
        return arr[:target_length]
    return arr + [padding_value] * missing


class MetaDataSet(object):
    """Few-shot slot-filling data set whose utterances are bucketed by slot label.

    Each non-empty line of the data file is expected to be
    ``domain<TAB>intent<TAB>token|slot token|slot ...``. Utterances are
    sub-word tokenized with ``tokenizer``, padded/truncated to
    ``max_token_length`` and grouped into one bucket per slot id so that
    :meth:`sample_task` can draw episodic support/query tasks.
    """

    # Partition names accepted by ``__init__``.
    train = "train"
    test = "test"
    dev = "dev"
    # Field names selecting the intent (ic), slot (ner) and domain (dc) alphabets.
    ic = "ic"
    ner = "ner"
    dc = "dc"
    _logger = logging.getLogger()

    def __init__(self,
                 alphabet_path,
                 data_path,
                 partition,
                 tokenizer,
                 max_token_length,
                 fields=frozenset(),
                 device=torch.device('cpu')):
        """Load the label alphabets and the utterances of one partition.

        :param alphabet_path: directory the alphabets are loaded from when it
            exists; for the train partition the alphabets are saved there.
        :param data_path: path of the tab-separated utterance file.
        :param partition: one of ``train``, ``dev`` or ``test``; dev and test
            close (freeze) the alphabets.
        :param tokenizer: sub-word tokenizer providing ``tokenize``,
            ``convert_tokens_to_ids`` and ``pad_token_id``.
        :param max_token_length: number of sub-word positions every utterance
            is padded or truncated to.
        :param fields: subset of ``{ner, ic, dc}`` whose alphabets are used.
        :param device: device the bucketed tensors are moved to.
        :raises ValueError: if ``partition`` is not recognised.
        """
        self._logger.info("Loading {} set".format(partition))
        self.partition = partition
        self.tokenizer = tokenizer
        self.max_token_length = max_token_length
        self.fields = fields
        self._ner_alphabet, self._ic_alphabet, self._dc_alphabet = None, None, None
        if partition not in (self.train, self.dev, self.test):
            raise ValueError("Unknown partition type '{}', must be one of '{}', '{}', '{}'"
                             .format(partition, self.train, self.test, self.dev))
        # Only the train partition may grow the alphabets; dev/test read them closed.
        self._read_alphabets(alphabet_path, freeze=partition != self.train)

        self._data = None

        if partition == self.test:
            # NOTE(review): _read_utterances is not defined on this class, so loading the
            # test partition raises AttributeError unless a subclass provides it - confirm.
            self._read_utterances(data_path, device)
        else:
            self._read_annotated_utterances(data_path, device)

        # In the train partition a frame may sit in several buckets, so this counts
        # bucket memberships rather than distinct utterances.
        self._total_size = sum(self.class_sizes)
        self._logger.info("Num data: {}".format(self._total_size))
        if self.ic in self.fields:
            self._logger.info("INTENT Alphabet Size: %d" % self._ic_alphabet.size())
        if self.ner in self.fields:
            self._logger.info("NER Alphabet Size: %d" % self._ner_alphabet.size())

        if partition == self.train:
            self._save_alphabets(alphabet_path)
            self._logger.info("Alphabets saved to {}".format(alphabet_path))

    def __len__(self):
        """Return the total number of bucket entries (see ``__init__``)."""
        return self._total_size

    def reverse_ner(self, ner_predictions, length):
        """Map k-best slot-id predictions back to slot label strings.

        :param ner_predictions: 3-D tensor indexed as ``[batch, position, k]``.
        :param length: per-example number of positions to decode.
        :return: nested list ``[batch][k][position]`` of slot labels.
        """
        batch_size, _, kbest = ner_predictions.size()
        return [[[self._ner_alphabet.get_instance(idx)
                  for idx in ner_predictions[b, :length[b], k].tolist()]
                 for k in range(kbest)]
                for b in range(batch_size)]

    def reverse_intent(self, ic_predictions):
        """Map a tensor of intent ids to intent label strings."""
        return [self._ic_alphabet.get_instance(idx) for idx in ic_predictions.tolist()]

    def get_all_intents(self):
        """Return every intent label in id order."""
        return [self._ic_alphabet.get_instance(idx) for idx in range(self.num_intents())]

    def get_all_domains(self):
        """Return every domain label in id order."""
        return [self._dc_alphabet.get_instance(idx) for idx in range(self.num_domains())]

    def num_slots(self):
        """Return the size of the NER (slot) alphabet."""
        return self._ner_alphabet.size()

    def num_intents(self):
        """Return the size of the intent alphabet."""
        return self._ic_alphabet.size()

    def num_domains(self):
        """Return the size of the domain alphabet."""
        return self._dc_alphabet.size()

    def get_intent_index(self, intent):
        """Return the id of ``intent``, or ``num_intents()`` if the alphabet raises KeyError."""
        try:
            intent_id = self._ic_alphabet.get_index(intent)
        except KeyError:
            intent_id = self.num_intents()
        return intent_id

    def get_slot_index(self, slot):
        """Return the id of ``slot``, or ``num_slots()`` if the alphabet raises KeyError.

        Presumably KeyError signals an unseen label in a closed alphabet - confirm
        against Alphabet.get_index.
        """
        try:
            slot_id = self._ner_alphabet.get_index(slot)
        except KeyError:
            slot_id = self.num_slots()
        return slot_id

    def _active_alphabets(self):
        """Return the alphabets whose field is enabled, in (ner, ic, dc) order."""
        pairs = ((self.ner, self._ner_alphabet),
                 (self.ic, self._ic_alphabet),
                 (self.dc, self._dc_alphabet))
        return [alphabet for field, alphabet in pairs if field in self.fields]

    def _read_alphabets(self, alphabet_path, freeze=False):
        """Create the enabled alphabets, load them if possible and open or close them.

        :param alphabet_path: directory to load from; when it is not a directory,
            fresh alphabets are used and PAD_SLOT is added to the NER alphabet.
        :param freeze: close the alphabets when True, open them otherwise.
        """
        if self.ner in self.fields:
            self._ner_alphabet = Alphabet(self.ner)
        if self.ic in self.fields:
            self._ic_alphabet = Alphabet(self.ic)
        if self.dc in self.fields:
            self._dc_alphabet = Alphabet(self.dc)

        if os.path.isdir(alphabet_path):
            self._logger.info("Reading alphabet from {}".format(alphabet_path))
            for alphabet in self._active_alphabets():
                alphabet.load(alphabet_path)
        elif self.ner in self.fields:
            # Fresh NER alphabet: register the padding slot before any real label
            # (bucketize_by_slot assumes it receives id 0 - confirm against Alphabet).
            self._ner_alphabet.add(PAD_SLOT)

        for alphabet in self._active_alphabets():
            if freeze:
                alphabet.close()
            else:
                alphabet.open()

    def _read_annotated_utterances(self, data_path, device):
        """Read annotated utterances, build padded frames and bucket them by slot.

        Each non-empty line must have exactly three tab-separated columns
        ``domain<TAB>intent<TAB>token|slot token|slot ...``. Frames
        ``(token_ids, length, sub_word_index, sub_word_valid_length, intent_id, slot_ids)``
        are stored in ``self.all_data`` and then passed to :meth:`bucketize_by_slot`.
        """
        self.all_data = []
        max_len = self.max_token_length
        with open(data_path, 'rt', encoding="utf-8") as fp:
            for line in fp:
                line = line.strip()
                if not line:
                    continue
                _domain, intent, tokens_slots = line.split('\t')
                pairs = [token_slot.split('|') for token_slot in tokens_slots.split()]
                tokens = ' '.join([token_normalizer(pair[0]) for pair in pairs])
                slots = [pair[1] for pair in pairs]
                tokenized_text = self.tokenizer.tokenize(tokens)

                # Position of the last sub-word of every word: a piece ends a word unless
                # the following piece is a '##' continuation. The final piece always ends
                # a word (position 0 is recorded even when tokenization yields nothing).
                word_indices = [i for i in range(len(tokenized_text) - 1)
                                if tokenized_text[i + 1][:2] != '##']
                word_indices.append(max(len(tokenized_text) - 1, 0))

                token_ids = self.tokenizer.convert_tokens_to_ids(tokenized_text)
                intent_id = self.get_intent_index(intent)
                slot_ids = [self.get_slot_index(slot) for slot in slots]

                # pad_list truncates to max_len, so clip the valid length and drop word
                # end positions that no longer exist after truncation.
                length = min(len(token_ids), max_len)
                word_indices = [i for i in word_indices if i < max_len]

                token_ids = pad_list(token_ids, max_len, self.tokenizer.pad_token_id)
                slot_ids = pad_list(slot_ids, max_len,
                                    self._ner_alphabet.instance2index[PAD_SLOT])
                sub_word_index = pad_list(word_indices, max_len, -1)
                sub_word_valid_length = len(word_indices)
                self.all_data.append((token_ids, length, sub_word_index,
                                      sub_word_valid_length, intent_id, slot_ids))

        self.bucketize_by_slot(device)

    def bucketize_by_slot(self, device):
        """Group ``self.all_data`` into one bucket per slot id and tensorize the buckets.

        In the train partition, a frame joins the bucket of every distinct slot id it
        contains. Otherwise, it joins the bucket of one randomly chosen slot id. Some
        slot ids are skipped:

        - ids 0 and 1, assumed to be the pad and 'other' labels (confirm against
          the alphabet);
        - ids outside ``[0, num_slots())``, such as the id :meth:`get_slot_index`
          returns for unknown labels.

        Frames with no remaining slot id are dropped.

        Sets ``self.class_sizes`` and ``self.eligible_slots`` (ids of non-empty
        buckets). It also sets ``self._data``, which holds one list per bucket:
        ``[tokens, lengths, mask, sub_word_ids, sub_word_valid_lengths, intents, slots]``.
        """
        num_slots = self.num_slots()
        data = [[] for _ in range(num_slots)]
        for frame in self.all_data:
            slot_ids = frame[5]
            candidate_ids = [x for x in set(slot_ids) if x not in (0, 1) and x < num_slots]
            if not candidate_ids:
                # Nothing to bucket on (np.random.choice would raise on an empty list).
                continue
            if self.partition == self.train:
                for slot_id in candidate_ids:
                    data[slot_id].append(frame)
            else:
                data[np.random.choice(candidate_ids)].append(frame)

        self.class_sizes = [len(bucket) for bucket in data]
        self._data = []

        positions = torch.arange(self.max_token_length).to(device)
        for bucket in data:
            # Frame fields 0..5 -> one int64 tensor each.
            token_inputs, lengths, sub_word_ids, sub_word_valid_lengths, intents, slots = [
                torch.tensor([frame[k] for frame in bucket], dtype=torch.int64).to(device)
                for k in range(6)
            ]
            mask = (positions < lengths.unsqueeze(1)).float()
            self._data.append([token_inputs, lengths, mask, sub_word_ids,
                               sub_word_valid_lengths, intents, slots])

        self.eligible_slots = [i for i, size in enumerate(self.class_sizes) if size > 0]

    def sample_task(self, batch_size, num_slots_in_task=4, num_supports=10, num_queries=None):
        """Sample one few-shot task (episode) from the slot buckets.

        ``num_slots_in_task`` distinct eligible slot ids are drawn without replacement.
        Each of their buckets is shuffled and split:

        - supports get up to ``num_supports`` examples;
        - queries get the remaining examples, capped at ``num_queries`` when given.

        In both sets, every slot label other than the sampled ids, 0 and 1 is
        rewritten to 1.

        :param batch_size: batch size stored on the returned DataSet.
        :param num_slots_in_task: number of target slot ids in the task.
        :param num_supports: number of support examples per target slot.
        :param num_queries: maximum number of query examples per target slot, or
            None for all remaining examples.
        :return: ``DataSet(tgt_slots, supports, queries, batch_size)``.
        :raises ValueError: (from ``np.random.choice``) when fewer than
            ``num_slots_in_task`` slot ids are eligible.
        """
        tgt_slots = np.random.choice(self.eligible_slots, size=num_slots_in_task, replace=False)
        query_parts, support_parts = [], []
        for i in tgt_slots:
            size = self.class_sizes[i]
            # Never request a negative number of queries: a bucket with at most
            # num_supports examples contributes supports only.
            n_queries = max(size - num_supports, 0)
            if num_queries is not None:
                n_queries = min(n_queries, num_queries)

            idx = torch.randperm(size)[:num_supports + n_queries]
            queries_idx, supports_idx = idx[:n_queries], idx[n_queries:]
            query_parts.append([tensor[queries_idx] for tensor in self._data[i]])
            support_parts.append([tensor[supports_idx] for tensor in self._data[i]])

        queries = [torch.cat(column, 0) for column in zip(*query_parts)]
        supports = [torch.cat(column, 0) for column in zip(*support_parts)]

        # Labels kept as-is: the sampled target slots plus ids 0 and 1; every other
        # label becomes 1. The id tensor is created on the data's device so the
        # comparison also works when the buckets live on a GPU.
        keep_ids = torch.as_tensor(np.append(tgt_slots, [0, 1]), dtype=torch.long,
                                   device=queries[6].device).view(-1, 1, 1)
        for part in (queries, supports):
            keep_mask = (part[6] == keep_ids).sum(0)
            part[6] = part[6] * keep_mask + torch.ones_like(part[6]) * (1 - keep_mask)

        return DataSet(tgt_slots, tuple(supports), tuple(queries), batch_size)

    def _save_alphabets(self, alphabet_path):
        """Save every enabled alphabet into ``alphabet_path``."""
        for alphabet in self._active_alphabets():
            alphabet.save(alphabet_path)


class DataSet(object):
    """A sampled few-shot task holding a support set and a query set.

    ``supports`` and ``queries`` are sequences of seven tensors
    ``(tokens, lengths, mask, sub_word_ids, sub_word_valid_lengths, intents, slots)``
    that share their first (example) dimension. Every accessor returns
    ``((tokens, lengths, mask, sub_word_ids, sub_word_valid_lengths), intents, slots)``
    with all tensors moved to ``device``.
    """

    def __init__(self, tgt_slots, supports, queries, batch_size):
        """
        :param tgt_slots: slot ids sampled for this task.
        :param supports: the seven support-set tensors.
        :param queries: the seven query-set tensors.
        :param batch_size: maximum number of examples per returned batch.
        """
        self.tgt_slots = tgt_slots
        self.supports = supports
        self.queries = queries
        self.batch_size = batch_size

    @staticmethod
    def _pack(tensors, device):
        """Move the seven task tensors to ``device`` and group them as (inputs, intents, slots)."""
        tokens, lengths, mask, sub_word_ids, sub_word_valid_lengths, intents, slots = \
            (tensor.to(device) for tensor in tensors[:7])
        return (tokens, lengths, mask, sub_word_ids, sub_word_valid_lengths), intents, slots

    def _random_batch(self, tensors, device):
        """Return up to ``batch_size`` randomly chosen rows (without replacement), packed."""
        num_rows = tensors[0].size(0)
        idx = torch.randperm(num_rows)[:min(num_rows, self.batch_size)]
        return self._pack([tensor[idx] for tensor in tensors[:7]], device)

    def batch_of_queries(self, device):
        """Return a random batch of query examples."""
        return self._random_batch(self.queries, device)

    def batch_of_supports(self, device):
        """Return a random batch of support examples."""
        return self._random_batch(self.supports, device)

    def batch(self, device):
        """Return a random batch drawn from queries and supports together."""
        combined = [torch.cat(pair) for pair in zip(self.queries, self.supports)]
        return self._random_batch(combined, device)

    def get_supports(self, device):
        """Return the whole support set."""
        return self._pack(self.supports, device)

    def iter_queries(self, device):
        """Yield the query set in order, in consecutive batches of at most ``batch_size``."""
        num_queries = self.queries[0].size(0)
        batch_size = min(num_queries, self.batch_size)
        if batch_size <= 0:
            # No queries (or a non-positive batch size): nothing to yield; range()
            # would otherwise raise on a zero step.
            return
        for start in range(0, num_queries, batch_size):
            excerpt = slice(start, start + batch_size)
            yield self._pack([tensor[excerpt] for tensor in self.queries], device)
