import torch
import os
import pickle
import random

import torch.nn.functional as F

from tqdm import tqdm
from torch.utils.data import Dataset

from sentence_transformers import SentenceTransformer, util

class SelectionDataset(Dataset):
    """Response-selection dataset built from multi-turn conversations.

    Every conversation is cut into ``(context, response)`` pairs with a
    sliding window. Each pair is then extended with randomly drawn negative
    responses so that one sample holds ``num_candidate`` candidates, the
    first of which is the ground-truth response (hence every label is ``0``).

    Both the textual samples and the tokenized tensors can be cached on disk
    with ``pickle``. A cache file that already exists is loaded as-is; it is
    not checked against the current ``num_candidate`` / ``max_seq_len``.
    """

    def __init__(
        self,
        raw_dataset,
        tokenizer,
        setname: str,
        max_seq_len: int = 128,
        num_candidate: int = 11,
        uttr_token: str = "[UTTR]",
        txt_save_fname: str = None,
        tensor_save_fname: str = None,
    ):
        """Build (or load from cache) the tensorized selection dataset.

        :param raw_dataset: list of conversations, each a list of utterance
            strings.
        :param tokenizer: object providing ``tokenize(str)`` and a call
            interface accepting ``text_pair``/``max_length``/``padding``/
            ``truncation``/``return_tensors`` (presumably a HuggingFace
            tokenizer -- confirm against the caller).
        :param setname: one of ``"train"``, ``"dev"``, ``"test"``.
        :param max_seq_len: padded/truncated length of every encoded pair;
            also used to drop pairs that would not fit.
        :param num_candidate: candidates per sample (1 positive plus
            ``num_candidate - 1`` negatives).
        :param uttr_token: separator inserted between context utterances.
        :param txt_save_fname: format string (``{}`` receives ``setname``)
            for the textual cache, or ``None`` to disable that cache.
        :param tensor_save_fname: format string (``{}`` receives ``setname``)
            for the tensor cache, or ``None`` to disable that cache.
        """
        assert setname in ["train", "dev", "test"]
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.uttr_token = uttr_token
        self.setname = setname

        # The defaults are None; calling ``.format`` on them would raise.
        if txt_save_fname is not None:
            txt_save_fname = txt_save_fname.format(setname)
        if tensor_save_fname is not None:
            tensor_save_fname = tensor_save_fname.format(setname)

        selection_dataset = self._get_selection_dataset(
            raw_dataset,
            num_candidate,
            txt_save_fname,
        )
        self.feature = self._tensorize_selection_dataset(
            selection_dataset, tensor_save_fname, num_candidate
        )

    def __len__(self):
        """Return the number of samples (length of the first feature)."""
        return len(self.feature[0])

    def __getitem__(self, idx):
        """Return the ``idx``-th entry of every feature as a tuple."""
        return tuple(el[idx] for el in self.feature)

    @staticmethod
    def _ensure_parent_dir(fname):
        """Create the parent directory of ``fname`` if it names one."""
        dirname = os.path.dirname(fname)
        # ``os.makedirs("")`` raises, so bare filenames must be skipped.
        if dirname:
            os.makedirs(dirname, exist_ok=True)

    @staticmethod
    def _sample_negatives(all_responses, positive, num_candidate):
        """Draw ``num_candidate - 1`` responses that differ from ``positive``.

        The first draw mirrors a plain ``random.sample`` of ``num_candidate``
        items. Only when that draw contains the positive text more than once
        (or the corpus is smaller than ``num_candidate``) is a second draw
        made from the responses that differ from ``positive``.

        :raises ValueError: if fewer than ``num_candidate - 1`` responses
            differ from ``positive``.
        """
        num_negative = num_candidate - 1
        if len(all_responses) >= num_candidate:
            drawn = random.sample(all_responses, num_candidate)
            # Drop *every* copy of the positive, not only the first one.
            negatives = [resp for resp in drawn if resp != positive]
            negatives = negatives[:num_negative]
            if len(negatives) == num_negative:
                return negatives

        pool = [resp for resp in all_responses if resp != positive]
        if len(pool) < num_negative:
            raise ValueError(
                "Not enough negative responses: need {}, have {}".format(
                    num_negative, len(pool)
                )
            )
        return random.sample(pool, num_negative)

    def _tensorize_selection_dataset(
        self, selection_dataset, tensor_save_fname, num_candidate
    ):
        """Tokenize every sample and stack the results per candidate slot.

        :return: list of ``2 * num_candidate + 1`` tensors: the input ids of
            each candidate slot, then the attention masks of each candidate
            slot, then the labels (all zeros).
        """
        if tensor_save_fname is not None and os.path.exists(
            tensor_save_fname
        ):
            print(f"{tensor_save_fname} exist!")
            # NOTE: pickle executes arbitrary code on load; only point this
            # at cache files produced by this class.
            with open(tensor_save_fname, "rb") as f:
                return pickle.load(f)
        print("make {}".format(tensor_save_fname))
        ids_list = [[] for _ in range(num_candidate)]
        masks_list = [[] for _ in range(num_candidate)]
        labels = []
        print("data_idx_len: ", len(selection_dataset))
        print("Tensorize...")
        for sample in tqdm(selection_dataset):
            assert len(sample) == 1 + num_candidate and all(
                isinstance(el, str) for el in sample
            )
            context, candidates = sample[0], sample[1:]
            assert len(candidates) == num_candidate

            # Pair the same context with every candidate response.
            encoded = self.tokenizer(
                [context] * num_candidate,
                text_pair=candidates,
                max_length=self.max_seq_len,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            encoded_ids, encoded_mask = (
                encoded["input_ids"],
                encoded["attention_mask"],
            )
            assert len(encoded_ids) == len(encoded_mask) == num_candidate
            for candi_idx in range(num_candidate):
                ids_list[candi_idx].append(encoded_ids[candi_idx])
                masks_list[candi_idx].append(encoded_mask[candi_idx])
            # The positive response is always the first candidate.
            labels.append(0)

        assert len({len(el) for el in ids_list}) == 1
        assert len({len(el) for el in masks_list}) == 1
        ids_list = [torch.stack(el) for el in ids_list]
        masks_list = [torch.stack(el) for el in masks_list]
        labels = torch.tensor(labels)

        data = ids_list + masks_list + [labels]
        assert len(data) == 1 + 2 * num_candidate

        if tensor_save_fname is not None:
            self._ensure_parent_dir(tensor_save_fname)
            with open(tensor_save_fname, "wb") as f:
                pickle.dump(data, f)
        return data

    def _get_selection_dataset(
        self,
        raw_dataset,
        num_candidate,
        txt_save_fname,
    ):
        """Load the textual selection dataset from cache or build it.

        When ``txt_save_fname`` is ``None`` nothing is read or written.
        """
        print("Selection filename: {}".format(txt_save_fname))
        if txt_save_fname is not None and os.path.exists(txt_save_fname):
            print(f"{txt_save_fname} exist!")
            # NOTE: pickle executes arbitrary code on load; only point this
            # at cache files produced by this class.
            with open(txt_save_fname, "rb") as f:
                return pickle.load(f)

        selection_dataset = self._make_selection_dataset(
            raw_dataset, num_candidate
        )

        if txt_save_fname is not None:
            self._ensure_parent_dir(txt_save_fname)
            with open(txt_save_fname, "wb") as f:
                pickle.dump(selection_dataset, f)
        return selection_dataset

    def _make_selection_dataset(self, raw_dataset, num_candidate):
        """Build the textual selection dataset from raw conversations.

        :return: dataset: List of [context(str), positive_response(str),
            negative_response_1(str), (...),
            negative_response_(num_candidate-1)(str)]
        """
        assert isinstance(raw_dataset, list) and all(
            isinstance(el, list) for el in raw_dataset
        )
        print("Serialized selection not exist. Make new file...")
        dataset = []
        all_responses = []
        for conv in tqdm(raw_dataset):
            slided_conversation = self._slide_conversation(conv)
            # Check the max sequence length
            for single_conv in slided_conversation:
                assert len(single_conv) == 2 and all(
                    isinstance(el, str) for el in single_conv
                )
                concat_single_conv = " ".join(single_conv)
                # +3 leaves room for the special tokens of a sentence pair
                # (presumably [CLS] and two [SEP] -- depends on tokenizer).
                if (
                    len(self.tokenizer.tokenize(concat_single_conv)) + 3
                    <= self.max_seq_len
                ):
                    dataset.append(single_conv)
            # Responses of over-long pairs still serve as negatives.
            all_responses.extend([el[1] for el in slided_conversation])

        for sample in dataset:
            sample.extend(
                self._sample_negatives(all_responses, sample[1], num_candidate)
            )
            assert len(sample) == 1 + num_candidate
            assert all(isinstance(txt, str) for txt in sample)
        return dataset

    def _slide_conversation(self, conversation):
        """
        Description:
            Turn a single multi-turn conversation into multiple
            "context-response" pairs by using a sliding window. The context
            is every utterance before the response, joined by the utterance
            token.
        """
        assert isinstance(conversation, list) and all(
            isinstance(el, str) for el in conversation
        )
        pairs = []
        for idx in range(len(conversation) - 1):
            context, response = conversation[: idx + 1], conversation[idx + 1]
            pairs.append([self.uttr_token.join(context), response])
        return pairs
