import torch
from torch.utils.data import Dataset, IterableDataset


class Word2VecTrainDataset(Dataset):
    """
    Map-style dataset of (center word, context word) training pairs.

    All pairs are generated eagerly at construction time by the module-level
    ``make_word_pairs`` helper (called with its default ``bidirectional``
    setting) and stored in ``self.samples``; indexing simply returns the
    stored dict for that position.
    """

    def __init__(self, sentence_list, window_size):
        self.sentence_list = sentence_list
        self.window_size = window_size
        # Precompute every pair up front; __getitem__ is then a plain lookup.
        self.samples = make_word_pairs(sentence_list, window_size)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def make_word_pairs(sentence_list, window_size, bidirectional=True):
    """
    Build (center word, context word) pairs from a list of sentences.

    Every word of every sentence is paired with each neighbour at most
    ``window_size`` positions away. When ``bidirectional`` is true the
    neighbours come from both sides of the center word; otherwise only from
    its right. A word is never paired with its own position.

    Returns a list of ``{"center_word": ..., "target": ...}`` dicts ordered by
    sentence, then center position, then offset from left to right.
    """
    lo = -window_size if bidirectional else 0
    pairs = []
    for tokens in sentence_list:
        for pos, word in enumerate(tokens):
            pairs.extend(
                {"center_word": word, "target": tokens[pos + off]}
                for off in range(lo, window_size + 1)
                if off != 0 and 0 <= pos + off < len(tokens)
            )
    return pairs


class Word2VecTrainIterableDataset(IterableDataset):
    """
    Iterable-style dataset that streams (center_word, context word) pairs from
    a collection of tokenised sentences.

    Each sentence is converted to a 1-D ``torch.long`` tensor of vocabulary
    indices. Before pairs are generated, tokens are randomly dropped with
    probability ``1 - sqrt(subsample_thresh / unigram_prob)`` (frequent-word
    subsampling); dropped positions are used neither as center words nor as
    context words. Context words are taken within ``window_size`` positions
    of the center word.

    When iterated inside a multi-worker ``DataLoader``, each worker streams
    only its own disjoint slice of the sentences, so samples are not
    duplicated across workers.
    """

    def __init__(
        self, data, window_size, vocab=None, shuffle=True, subsample_thresh=1e-3
    ):
        """
        Args:
            data: indexable, sized collection of sentences; ``data[i]`` must be
                convertible to a 1-D integer tensor of vocabulary indices.
            window_size: maximum distance between a center word and a context
                word.
            vocab: object exposing ``freqs`` (mapping token -> count) and
                ``itos`` (sequence of tokens; position = vocabulary index).
                Required despite the ``None`` default.
            shuffle: visit sentences in a fresh random order on every pass.
            subsample_thresh: threshold ``t`` of the subsampling formula;
                larger values drop fewer tokens.

        Raises:
            ValueError: if ``vocab`` is None.
        """
        super().__init__()
        if vocab is None:
            # The subsampling table below is derived from the vocabulary's
            # counts, so there is no meaningful fallback without one.
            raise ValueError(
                "Word2VecTrainIterableDataset requires a vocab exposing "
                "`freqs` and `itos`"
            )
        self.data = data
        self.window_size = window_size
        self.shuffle = shuffle
        self.index = None  # sentence order used by the most recent __iter__
        self.vocab = vocab

        total_words = sum(self.vocab.freqs.values())
        # Relative frequency of every vocabulary index; tokens absent from
        # `freqs` get probability 0.
        unigram_prob = (
            torch.tensor([self.vocab.freqs.get(key, 0) for key in self.vocab.itos])
            / total_words
        )
        # Per-index probability of dropping a token. It is negative for tokens
        # rarer than `subsample_thresh`, which are therefore never dropped; the
        # 1e-19 term avoids dividing by zero for tokens with a zero count.
        self.subsampling_prob = 1.0 - torch.sqrt(
            subsample_thresh / (unigram_prob + 1e-19)
        )

    def __iter__(self):
        """Yield pair dicts for every sentence assigned to this worker."""
        from torch.utils.data import get_worker_info

        worker_info = get_worker_info()
        if worker_info is None:  # single-process loading
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = worker_info.id, worker_info.num_workers

        # Shard first, then shuffle within the shard: workers have different
        # RNG seeds, so shuffling before sharding would overlap/miss sentences.
        # Tensor slicing (unlike torch.arange with a start past the end) is
        # safe when there are more workers than sentences.
        indices = torch.arange(len(self.data))[worker_id::num_workers]
        if self.shuffle:
            indices = indices[torch.randperm(len(indices))]
        self.index = indices

        for idx in self.index:
            # torch.tensor always copies, so masking never mutates self.data;
            # an explicit long dtype keeps empty sentences usable as indices.
            sentence = torch.tensor(self.data[idx], dtype=torch.long)
            yield from self.make_word_pairs(sentence, self.window_size)

    def make_word_pairs(self, sentence, window_size, bidirectional=True):
        """
        Subsample ``sentence`` in place and yield its (center, context) pairs.

        Args:
            sentence: 1-D integer tensor of vocabulary indices; modified in
                place by ``mask_sentence``.
            window_size: maximum distance between center and context word.
            bidirectional: if False, only context words to the right of the
                center word are used.

        Yields:
            ``{"center_word": ..., "target": ...}`` dicts whose values are
            0-dim tensors taken from ``sentence``.
        """
        left_offset = -window_size if bidirectional else 0
        sentence = self.mask_sentence(sentence)
        length = len(sentence)
        # Python-side keep flags avoid a tensor comparison per inner-loop step.
        kept = (sentence != -1).tolist()
        for i, center_word in enumerate(sentence):
            if not kept[i]:
                continue
            for j in range(left_offset, window_size + 1):
                context_word_index = i + j
                if (
                    j != 0
                    and 0 <= context_word_index < length
                    and kept[context_word_index]
                ):
                    yield {
                        "center_word": center_word,
                        "target": sentence[context_word_index],
                    }

    def mask_sentence(self, sentence):
        """
        Replace subsampled tokens of ``sentence`` with ``-1`` in place.

        Token ``w`` is dropped when a uniform draw in [0, 1) falls below
        ``self.subsampling_prob[w]``; tokens with a non-positive drop
        probability are therefore always kept. Returns the same tensor.
        """
        draws = torch.rand(len(sentence))
        sentence[draws < self.subsampling_prob[sentence]] = -1
        return sentence
