import numpy as np
import torch
from collections import defaultdict
from torch.utils.data import Dataset
from classification.util.preprocessing import draw_matrix


class IntentDataset(Dataset):
    """Dataset of (utterance, intent description) pairs for NLI-style intent scoring.

    Every item pairs one or more utterances with one or more intent
    descriptions, encodes each pair as a single token sequence and returns
    binary targets saying whether the pair matches on intent, concept and
    action.

    Args:
        df: Object exposing ``df.text.values`` (utterance strings) and
            ``df.intents.values`` (gold intent name per utterance).
        tokenizer: Callable invoked as
            ``tokenizer(text, padding='max_length', truncation=True, max_length=n)``
            that returns a mapping with ``'input_ids'`` and ``'attention_mask'``
            lists. It must also expose ``sep_token_id``.
        intents: Sequence of intent names; the position in this sequence is
            the integer label.
        descriptions: Mapping intent name -> description text.
        concepts: Mapping intent name -> concept.
        actions: Mapping intent name -> action.
        description_first: If true, the description precedes the utterance
            in the encoded pair; otherwise the utterance comes first.
        uttr_len: ``max_length`` used when tokenizing utterances.
        desc_len: ``max_length`` used when tokenizing descriptions.
        similarity_matrix: For ``sampling_strategy='intents'`` a square
            intent-by-intent weight matrix (defaults to the identity); it is
            smoothed into per-row sampling probabilities. For
            ``sampling_strategy='utterances'`` an indexable that maps an
            utterance index to candidate negative utterance indices; ``None``
            or an empty entry falls back to all utterances with another label.
        sampling_strategy: ``'intents'`` or ``'utterances'``.
        neg_k: Number of negatives to sample per item. Falsy disables
            sampling, in which case every intent is paired with the utterance.
        infer_intents: Optional intent names; when given, every item is
            paired with exactly these intents regardless of sampling.

    Raises:
        ValueError: If ``sampling_strategy`` is not one of the two supported
            values, or if an intent name is missing from ``intents``.
    """

    def __init__(self, df, tokenizer, intents, descriptions, concepts, actions, description_first,
                 uttr_len, desc_len, similarity_matrix=None, sampling_strategy='intents', neg_k=None,
                 infer_intents=None):
        self.intents = intents
        self.uttrs = [tokenizer(text, padding='max_length', truncation=True, max_length=uttr_len)
                      for text in df.text.values]
        self.descriptions = [tokenizer(descriptions[intent], padding='max_length', truncation=True, max_length=desc_len)
                             for intent in intents]
        self.labels = [intents.index(intent) for intent in df.intents.values]

        # Per-intent and per-utterance concept/action lookups used to build
        # the auxiliary targets in __getitem__.
        self.intent_concept = [concepts[intent] for intent in intents]
        self.intent_action = [actions[intent] for intent in intents]
        self.label_concepts = [concepts[intent] for intent in df.intents.values]
        self.label_actions = [actions[intent] for intent in df.intents.values]
        self.num_intents = len(intents)
        self.num_uttrs = len(self.uttrs)
        self.tokenizer = tokenizer
        self.neg_k = neg_k
        self.infer_intents = infer_intents
        # Always define the attribute so it can be inspected safely.
        self.infer_intents_ids = (
            [intents.index(intent) for intent in infer_intents] if infer_intents else None
        )
        self.description_first = description_first
        self.sampling_strategy = sampling_strategy
        if sampling_strategy == 'intents':
            if similarity_matrix is None:
                similarity_matrix = np.eye(self.num_intents)
            self.intent_similarity = self.additive_smoothing(similarity_matrix, self.num_intents)
            # draw_matrix(self.intent_similarity, intents)
        elif sampling_strategy == 'utterances':
            self.similar_uttrs = similarity_matrix
        else:
            raise ValueError('Unexpected sampling strategy')

    def __len__(self):
        """Return the number of utterances."""
        return self.num_uttrs

    def encode_nli_pair(self, uttr_idx, intent_idx):
        """Encode one utterance/description pair as a single sequence.

        The two pre-tokenized sequences are joined as
        ``first + [sep_token_id] + second[1:]``; the first token of the second
        sequence is dropped (presumably its leading special token - confirm
        against the tokenizer in use).

        Returns:
            Tuple ``(input_ids, attention_mask)`` of ``torch.LongTensor``.
        """
        uttr = self.uttrs[uttr_idx]
        desc = self.descriptions[intent_idx]
        first, second = (desc, uttr) if self.description_first else (uttr, desc)
        input_ids = first['input_ids'] + [self.tokenizer.sep_token_id] + second['input_ids'][1:]
        attention_mask = first['attention_mask'] + [1] + second['attention_mask'][1:]
        return torch.LongTensor(input_ids), torch.LongTensor(attention_mask)

    def _negative_uttr_candidates(self, idx):
        """Return candidate negative utterance indices for utterance ``idx``."""
        candidates = None if self.similar_uttrs is None else self.similar_uttrs[idx]
        # Explicit emptiness test: ``or`` is ambiguous for numpy arrays.
        if candidates is None or len(candidates) == 0:
            own_label = self.labels[idx]
            candidates = [i for i in range(self.num_uttrs) if self.labels[i] != own_label]
        return candidates

    def __getitem__(self, idx):
        """Build the pair tensors and targets for utterance ``idx``.

        Returns:
            Dict with ``'pair_ids'`` and ``'pair_attention_mask'`` (stacked
            encodings, one row per utterance/intent pair, utterance-major),
            the integer ``'label'``, the raw ``'label_concept'`` and
            ``'label_action'`` of utterance ``idx``, and the float 0/1
            tensors ``'label_enc'``, ``'label_concept_enc'`` and
            ``'label_action_enc'`` aligned with the pair rows.
        """
        chosen_uttrs = [idx]
        chosen_intents = np.arange(self.num_intents)
        if self.neg_k:
            if self.sampling_strategy == 'intents':
                # Draw neg_k + 1 distinct intents, weighted by the smoothed
                # similarity row of the gold intent. The gold intent is not
                # forced into the draw; it is only as likely as its weight.
                chosen_intents = np.random.choice(
                    np.arange(self.num_intents), p=self.intent_similarity[self.labels[idx]],
                    size=self.neg_k + 1, replace=False
                )
            elif self.sampling_strategy == 'utterances':
                # Keep the gold intent fixed and add neg_k sampled utterances
                # (drawn with replacement).
                candidates = self._negative_uttr_candidates(idx)
                chosen_uttrs += np.random.choice(candidates, size=self.neg_k).tolist()
                chosen_intents = [self.labels[idx]]
        if self.infer_intents:
            # An explicit inference intent set overrides any sampling above.
            chosen_intents = self.infer_intents_ids

        pairs = [(u, i) for u in chosen_uttrs for i in chosen_intents]
        pair_ids, pair_attention_mask = zip(*[self.encode_nli_pair(u, i) for u, i in pairs])
        sample = {
            'pair_ids': torch.stack(pair_ids),
            'pair_attention_mask': torch.stack(pair_attention_mask),
            'label': self.labels[idx],
            'label_enc': torch.tensor([float(self.labels[u] == i) for u, i in pairs]),
            'label_concept': self.label_concepts[idx],
            'label_concept_enc': torch.tensor([float(self.label_concepts[u] == self.intent_concept[i])
                                               for u, i in pairs]),
            'label_action': self.label_actions[idx],
            'label_action_enc': torch.tensor([float(self.label_actions[u] == self.intent_action[i])
                                              for u, i in pairs])
        }
        return sample

    @staticmethod
    def additive_smoothing(matrix, n_categories, alpha=0.001):
        """Return a row-wise additively smoothed copy of a 2-D matrix.

        Each row ``r`` becomes ``(r + alpha) / (sum(r) + alpha * n_categories)``.
        When ``n_categories`` equals the number of columns, every row of the
        result sums to 1.

        The input is never modified and is converted to float first, so an
        integer matrix is not truncated back to integers.
        """
        smoothed = np.array(matrix, dtype=float)
        row_sums = smoothed.sum(axis=1, keepdims=True)
        return (smoothed + alpha) / (row_sums + alpha * n_categories)


# class IntentDatasetForContrastiveLearning(IntentDataset):
#     def __init__(self, df, tokenizer, intents, descriptions, concepts, actions, description_first,
#                  uttr_len, desc_len, intent_similarity=None, neg_k=None):
#         super().__init__(df, tokenizer, intents, descriptions, concepts, actions, description_first,
#                          uttr_len, desc_len, intent_similarity=intent_similarity, neg_k=neg_k)
#         self.uttrs_by_intent = defaultdict(list)
#         for i in range(len(self.labels)):
#             self.uttrs_by_intent[self.labels[i]].append(i)
#
#     def __getitem__(self, idx):
#
#         if self.neg_k:
#             assert self.neg_k is not None
#             sampled_intents = np.random.choice(
#                 np.arange(self.num_intents), p=self.intent_similarity[self.labels[idx]], size=self.neg_k + 1,
#                 replace=False
#             )
#             chosen_uttrs = []
#             chosen_intents = []
#             for intent in sampled_intents:
#                 uttr = idx
#                 if intent != self.labels[idx]:
#                     if len(self.uttrs_by_intent[intent]) > 0:
#                         uttr = np.random.choice(self.uttrs_by_intent[intent])
#                 chosen_uttrs.append(uttr)
#                 chosen_intents.append(intent)
#             return {
#                 'pairs': torch.stack([
#                    torch.stack([
#                         self.encode_nli_pair(uttr_idx, intent_idx) for uttr_idx in chosen_uttrs
#                     ]) for intent_idx in chosen_intents
#                 ]),
#                 'labels': torch.arange(len(chosen_intents)),
#                 'label': self.labels[idx]
#             }
#         else:
#             intents = np.arange(self.num_intents)
#             return {
#                 'uttr': torch.stack([self.encode_nli_pair(idx, i) for i in np.arange(self.num_intents)]),
#                 'label': self.labels[idx],
#                 'label_enc': torch.tensor([float(self.labels[idx] == i) for i in intents]),
#                 'label_concept': self.label_concepts[idx],
#                 'label_concept_enc': torch.tensor([float(self.label_concepts[idx] == self.intent_concept[i])
#                                                    for i in intents]),
#                 'label_action': self.label_actions[idx],
#                 'label_action_enc': torch.tensor([float(self.label_actions[idx] == self.intent_action[i])
#                                                   for i in intents])
#             }


def get_dataset(cfg, df, tokenizer, intents, descriptions,
                concepts, actions, similarity_matrix=None, sampling_strategy='intents', k_neg=None, infer_intents=None) -> "IntentDataset":
    """Build the dataset that matches the configured model type.

    Reads ``cfg.model.model_type``, ``cfg.experiment.intent_desc_first``,
    ``cfg.dataset.uttr_len`` and ``cfg.dataset.desc_len``; all remaining
    arguments are forwarded to ``IntentDataset`` (``k_neg`` is passed as its
    ``neg_k`` argument).

    Returns:
        An ``IntentDataset`` for model types ``'nli_ca'`` and ``'nli_strict'``.

    Raises:
        ValueError: If ``cfg.model.model_type`` is any other value.
    """
    if cfg.model.model_type not in ('nli_ca', 'nli_strict'):
        # The exception used to be constructed but never raised, which made
        # this function silently return None for unknown model types.
        raise ValueError("Unknown model type")
    return IntentDataset(
        df, tokenizer, intents, descriptions, concepts, actions,
        cfg.experiment.intent_desc_first, cfg.dataset.uttr_len, cfg.dataset.desc_len,
        similarity_matrix, sampling_strategy, k_neg, infer_intents
    )
