import torch
import os
import pickle
import random

import torch.nn.functional as F

from tqdm import tqdm
from typing import List
from torch.utils.data import Dataset

from sentence_transformers import SentenceTransformer, util

class SelectionDataset(Dataset):
    """Tokenized response-selection examples built from raw conversations.

    Each example is one context string followed by ``num_candidates``
    candidate responses; the first candidate is always
    ``conv["positive_responses"][0]``, so every label is ``0``.

    ``self.feature`` is a list of ``2 * num_candidates + 1`` tensors:
    ``num_candidates`` input-id tensors, ``num_candidates`` attention-mask
    tensors and finally the label tensor. Indexing the dataset returns the
    ``idx``-th row of each of them as a tuple.

    Each conversation in ``raw_dataset`` is indexed with the keys
    ``"context"``, ``"positive_responses"``, ``"random_negative_responses"``,
    ``"random_sampled"`` and (unless ``target_object == "semi_hard"``) the
    key named by ``target_object``.
    """

    VALID_SETNAMES = ("train", "dev", "test")

    # Semi-hard negatives are the pool responses whose similarity to the
    # context is closest to ``positive_similarity - margin``.
    _SEMI_HARD_MARGIN = 0.07
    # Number of conversations sampled to form the semi-hard candidate pool.
    _SEMI_HARD_POOL_CONVS = 31

    def __init__(
        self,
        raw_dataset: List,
        tokenizer,
        setname: str,
        target_object: str,
        max_seq_len: int = 128,
        num_candidates: int = 11,
        num_hard_negs: int = 5,
        is_curriculum: bool = False,
        uttr_token: str = "[UTTR]",
        txt_save_fname: str = None,
        tensor_save_fname: str = None,
    ):
        """Build (or load from cache) the tokenized selection dataset.

        Args:
            raw_dataset: List of conversation dicts (see class docstring).
            tokenizer: Callable invoked as
                ``tokenizer(texts, text_pair=..., max_length=...,
                padding="max_length", truncation=True, return_tensors="pt")``
                and expected to return ``"input_ids"`` and
                ``"attention_mask"``.
            setname: One of ``"train"``, ``"dev"`` or ``"test"``.
            target_object: Key holding the hard negatives in each
                conversation, or ``"semi_hard"`` to mine them with a
                sentence-embedding model.
            max_seq_len: Length every encoded pair is padded/truncated to.
            num_candidates: Number of candidate responses per example.
            num_hard_negs: Number of hard negatives per example.
            is_curriculum: Shuffle the data and ramp the number of hard
                negatives up over the dataset.
            uttr_token: Separator placed between context utterances.
            txt_save_fname: Cache path for the text examples; formatted with
                ``setname``. ``None`` disables this cache.
            tensor_save_fname: Cache path for the tensors; formatted with
                ``setname``. ``None`` disables this cache.

        Raises:
            ValueError: If ``setname`` is not a supported split name.
        """
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if setname not in self.VALID_SETNAMES:
            raise ValueError(
                f"setname must be one of {self.VALID_SETNAMES}, got {setname!r}"
            )

        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.uttr_token = uttr_token
        self.setname = setname

        # The declared default is None; formatting it would raise
        # AttributeError, so None is treated as "do not cache".
        if txt_save_fname is not None:
            txt_save_fname = txt_save_fname.format(setname)
        if tensor_save_fname is not None:
            tensor_save_fname = tensor_save_fname.format(setname)

        selection_dataset = self._get_selection_dataset(
            setname=setname,
            raw_dataset=raw_dataset,
            num_candidates=num_candidates,
            num_hard_negs=num_hard_negs,
            is_curriculum=is_curriculum,
            txt_save_fname=txt_save_fname,
            target_object=target_object,
        )

        self.feature = self.featurize(
            selection_dataset=selection_dataset,
            tensor_save_fname=tensor_save_fname,
            num_candidates=num_candidates,
        )

    def __len__(self):
        """Return the number of examples (rows of the first feature tensor)."""
        return len(self.feature[0])

    def __getitem__(self, idx):
        """Return the ``idx``-th row of every feature tensor as a tuple."""
        return tuple(el[idx] for el in self.feature)

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the directory that will contain ``path`` if it has one."""
        parent = os.path.dirname(path)
        # A bare file name has an empty dirname, and os.makedirs("") raises.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def featurize(
        self,
        selection_dataset: List,
        tensor_save_fname: str,
        num_candidates: int
    ):
        """Tokenize every (context, candidate) pair into stacked tensors.

        Args:
            selection_dataset: List of examples, each a list of
                ``1 + num_candidates`` strings (context first).
            tensor_save_fname: Pickle cache path, or ``None`` for no cache.
                If the file exists it is loaded and returned as-is.
            num_candidates: Number of candidates in every example.

        Returns:
            A list ``ids + masks + [labels]`` of ``2 * num_candidates + 1``
            tensors.
        """
        if tensor_save_fname is not None and os.path.exists(tensor_save_fname):
            print(f"{tensor_save_fname} exist!")
            # NOTE: pickle.load can execute arbitrary code; only point this
            # at cache files produced by this class.
            with open(tensor_save_fname, "rb") as f:
                return pickle.load(f)

        ids_list = [[] for _ in range(num_candidates)]
        masks_list = [[] for _ in range(num_candidates)]

        for data in tqdm(selection_dataset):
            assert len(data) == 1 + num_candidates and all(isinstance(el, str) for el in data)
            context, candidates = data[0], data[1:]
            assert len(candidates) == num_candidates

            # One encoded row per candidate, each paired with the same context.
            encoded = self.tokenizer(
                [context] * num_candidates,
                text_pair=candidates,
                max_length=self.max_seq_len,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            encoded_ids = encoded["input_ids"]
            encoded_mask = encoded["attention_mask"]

            assert len(encoded_ids) == len(encoded_mask) == num_candidates

            for candi_idx in range(num_candidates):
                ids_list[candi_idx].append(encoded_ids[candi_idx])
                masks_list[candi_idx].append(encoded_mask[candi_idx])

        assert len({len(el) for el in ids_list}) == 1
        assert len({len(el) for el in masks_list}) == 1

        def stack_rows(rows):
            # torch.stack rejects an empty list, so an empty dataset gets
            # explicit zero-row tensors instead of an error.
            if rows:
                return torch.stack(rows)
            return torch.empty((0, self.max_seq_len), dtype=torch.long)

        ids_list = [stack_rows(el) for el in ids_list]
        masks_list = [stack_rows(el) for el in masks_list]
        # The positive response is always candidate 0, so every label is 0.
        labels = torch.zeros(len(selection_dataset), dtype=torch.long)

        sample = ids_list + masks_list + [labels]
        assert len(sample) == 1 + 2 * num_candidates

        if tensor_save_fname is not None:
            self._ensure_parent_dir(tensor_save_fname)
            with open(tensor_save_fname, "wb") as f:
                pickle.dump(sample, f)

        return sample

    def _get_selection_dataset(
        self,
        setname: str,
        raw_dataset: List,
        num_candidates: int,
        num_hard_negs: int,
        is_curriculum: bool,
        txt_save_fname: str,
        target_object: str,
    ):
        """Load the text examples from cache, or build and cache them.

        Args:
            setname: Split name, forwarded to the example builder.
            raw_dataset: List of conversation dicts.
            num_candidates: Number of candidates per example.
            num_hard_negs: Number of hard negatives per example.
            is_curriculum: Forwarded to ``_make_selection_dataset``.
            txt_save_fname: Pickle cache path, or ``None`` for no cache.
            target_object: Hard-negative key, or ``"semi_hard"``.

        Returns:
            List of examples, each a list of ``1 + num_candidates`` strings.
        """
        print("Selection filename: {}".format(txt_save_fname))
        if txt_save_fname is not None and os.path.exists(txt_save_fname):
            print(f"{txt_save_fname} exist!")
            # NOTE: pickle.load can execute arbitrary code; only point this
            # at cache files produced by this class.
            with open(txt_save_fname, "rb") as f:
                return pickle.load(f)

        if target_object == "semi_hard":
            selection_dataset = self._make_selection_dataset_semihard(
                raw_dataset, num_candidates, num_hard_negs
            )
        else:
            selection_dataset = self._make_selection_dataset(
                setname, raw_dataset, target_object, num_candidates, num_hard_negs, is_curriculum
            )

        if txt_save_fname is not None:
            self._ensure_parent_dir(txt_save_fname)
            with open(txt_save_fname, "wb") as f:
                pickle.dump(selection_dataset, f)

        return selection_dataset

    def _build_example(
        self,
        conv,
        setname: str,
        target_object: str,
        num_candidates: int,
        num_hard_negs: int,
    ):
        """Assemble ``[context, positive, negatives...]`` for one conversation.

        For ``"test"`` the negatives are ``num_hard_negs`` entries of
        ``conv[target_object]`` topped up from
        ``conv["random_negative_responses"]``. For the other splits they are
        the hard negatives, all of ``conv["random_negative_responses"]`` and,
        if still short, entries of ``conv["random_sampled"]``.

        Raises:
            AssertionError: If the example does not end up with exactly
                ``num_candidates + 1`` strings.
        """
        example = [
            self.uttr_token.join(conv["context"]),
            conv["positive_responses"][0],
        ]
        example += conv[target_object][:num_hard_negs]

        if setname == "test":
            if num_hard_negs + 1 < num_candidates:
                example += conv["random_negative_responses"][: num_candidates - num_hard_negs - 1]
        else:
            example += conv["random_negative_responses"]
            # The 6 presumably stands for 1 positive + 5 random negatives
            # (TODO confirm); the assertion below catches other data shapes.
            if num_hard_negs + 6 < num_candidates:
                example += conv["random_sampled"][: num_candidates - num_hard_negs - 6]

        assert len(example) == num_candidates + 1
        return example

    def _make_selection_dataset(
        self,
        setname: str,
        raw_dataset: List,
        target_object: str,
        num_candidates: int,
        num_hard_negs: int,
        is_curriculum: bool,
    ):
        """Build text examples using the pre-computed hard negatives.

        With ``is_curriculum`` the conversations are shuffled and the number
        of hard negatives is overridden per example, growing from 0 to 5 in
        six equal stages over the dataset.

        Returns:
            List of examples, each a list of ``1 + num_candidates`` strings.
        """
        conversations = raw_dataset
        if is_curriculum:
            # Shuffle a copy so the caller's list is not reordered in place;
            # the random stream consumed is the same as shuffling in place.
            conversations = list(raw_dataset)
            random.shuffle(conversations)

        total = len(conversations)
        selection_dataset = []

        for idx, conv in enumerate(tqdm(conversations)):
            if is_curriculum:
                num_hard_negs = min(5, idx * 6 // total)

            selection_dataset.append(
                self._build_example(conv, setname, target_object, num_candidates, num_hard_negs)
            )

        return selection_dataset

    @staticmethod
    def _pick_semi_hard(pool, pos_score, neg_scores, num_hard_negs: int, margin: float):
        """Return the ``num_hard_negs`` pool entries nearest the target score.

        Entries are ranked by ``|pos_score - neg_score - margin|`` (smallest
        first) and the winners are returned in their original pool order.
        """
        distance = torch.abs(pos_score - neg_scores - margin)
        chosen = torch.argsort(distance, descending=False)[:num_hard_negs]
        # Plain sorted indices avoid one tensor comparison per pool element.
        return [pool[i] for i in sorted(chosen.tolist())]

    def _make_selection_dataset_semihard(
        self,
        raw_dataset: List,
        num_candidates: int,
        num_hard_negs: int
    ):
        """Build text examples whose hard negatives are mined by similarity.

        For every conversation a pool of responses is drawn from up to 31
        randomly sampled conversations, scored against the context with the
        ``all-MiniLM-L6-v2`` sentence-embedding model, and the
        ``num_hard_negs`` responses nearest ``positive_score - 0.07`` are
        kept. The conversation's own ``"random_negative_responses"`` fill the
        remaining slots.

        Returns:
            List of examples, each a list of ``1 + num_candidates`` strings.

        Raises:
            AssertionError: If the mined or random negatives do not add up to
                the expected counts.
        """
        model = SentenceTransformer("all-MiniLM-L6-v2")
        # Fall back to CPU so mining also works on hosts without a GPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)

        # random.sample raises if asked for more items than exist.
        pool_convs = min(self._SEMI_HARD_POOL_CONVS, len(raw_dataset))

        selection_dataset = []

        for conv in tqdm(raw_dataset):
            positive = conv["positive_responses"][0]
            # Same separator as the non-mined builder.
            selection_data = [self.uttr_token.join(conv["context"]), positive]

            pool = [
                response
                for other in random.sample(raw_dataset, pool_convs)
                for response in other["random_negative_responses"] + other["random_sampled"][:num_hard_negs]
            ]
            assert all(isinstance(e, str) for e in pool)

            # The embedding model sees the context joined with spaces.
            context = " ".join(conv["context"])

            with torch.no_grad():
                embeddings1 = model.encode(context, convert_to_tensor=True)
                embeddings2 = model.encode([positive] + pool, convert_to_tensor=True)

                # Compute cosine similarities
                cosine_scores = util.pytorch_cos_sim(embeddings1, embeddings2).cpu().detach()

            # One score for the positive plus one per pool response.
            assert cosine_scores.shape == (1, 1 + len(pool))

            cosine_scores = cosine_scores[0]
            selected = self._pick_semi_hard(
                pool,
                cosine_scores[0],
                cosine_scores[1:],
                num_hard_negs,
                self._SEMI_HARD_MARGIN,
            )
            assert len(selected) == num_hard_negs

            selection_data += selected
            selection_data += conv["random_negative_responses"]
            assert len(conv["random_negative_responses"]) == num_hard_negs
            assert len(selection_data) == num_candidates + 1
            selection_dataset.append(selection_data)

        return selection_dataset
