import torch
from best.Dict import Dict
from sklearn.exceptions import NotFittedError

# Reserved vocabulary entries. Each special token pairs a fixed integer
# index with the string under which it is registered in the vocabulary.

# unknown-word token (passed as ``unkWord`` when converting to indices)
UNK = 0
UNK_WORD = '__UNK__'

# end-of-sentence token (passed as ``eosWord`` when boundary markers are on)
EOS = 1
EOS_WORD = '__EOS__'


class TokenEncoder(object):
    """Encode tokenized documents as tensors of vocabulary indices.

    The encoder follows a fit/transform protocol: ``fit`` builds a
    vocabulary from the words of a document collection, and ``transform``
    maps documents onto index tensors using that vocabulary.

    Parameters
    ----------
    lower : bool
        Forwarded to ``Dict(lower=...)`` when the vocabulary is built
        (presumably makes entries case-insensitive -- confirm in ``Dict``).
    max_size : int or None
        When not None, the vocabulary is pruned with
        ``Dict.prune(max_size)`` once all words have been added.
    boundary_marker : bool
        When true, ``EOS_WORD`` is passed as ``eosWord`` for every sentence
        converted by ``transform``; otherwise ``None`` is passed.
    """

    def __init__(self, lower=True, max_size=None, boundary_marker=False):
        self.lower = lower
        self.max_size = max_size
        self.boundary_marker = boundary_marker

    def _require_fitted(self):
        """Raise ``NotFittedError`` unless ``fit`` has set ``self.vocab``."""
        if not hasattr(self, 'vocab'):
            raise NotFittedError(
                'This TokenEncoder instance is not fitted yet; '
                'call fit() before using it.')

    def fit(self, docs):
        """Build the vocabulary from the words of ``docs``.

        Each document must expose ``doc.tokenized['sentences']``, an
        iterable of sentences whose ``'tokens'`` entries each carry a
        ``'word'`` key.

        Returns
        -------
        self
            To allow call chaining.
        """
        # Build into a local so that a failure part-way through does not
        # leave a half-built vocabulary on the instance.
        vocab = Dict(lower=self.lower)
        vocab.addSpecial(UNK_WORD, UNK)
        vocab.addSpecial(EOS_WORD, EOS)

        for doc in docs:
            for sent in doc.tokenized['sentences']:
                for token in sent['tokens']:
                    vocab.add(token['word'])

        if self.max_size is not None:
            # NOTE(review): OpenNMT-style Dict.prune returns a *new*
            # dictionary instead of pruning in place, in which case
            # discarding the result would silently ignore max_size. Keep
            # the returned dictionary when there is one; an in-place prune
            # returning None keeps working too. Confirm against best.Dict.
            pruned = vocab.prune(self.max_size)
            if pruned is not None:
                vocab = pruned

        self.vocab = vocab
        return self

    def transform(self, docs):
        """Convert every document into a single tensor of word indices.

        For each sentence the words are converted with
        ``vocab.convertToIdx`` (``UNK_WORD`` as the unknown word, and
        ``EOS_WORD`` as the end marker when ``boundary_marker`` is set);
        the per-sentence tensors of a document are then concatenated.

        Returns
        -------
        list
            One tensor per document, in the order of ``docs``. A document
            without sentences yields an empty tensor.

        Raises
        ------
        NotFittedError
            If ``fit`` has not been called.
        """
        self._require_fitted()
        eos_word = EOS_WORD if self.boundary_marker else None
        res = []
        for doc in docs:
            sents = []
            for sent in doc.tokenized['sentences']:
                toks = (tok['word'] for tok in sent['tokens'])
                sents.append(self.vocab.convertToIdx(toks,
                                                     unkWord=UNK_WORD,
                                                     eosWord=eos_word))
            if sents:
                res.append(torch.cat(sents))
            else:
                # torch.cat rejects an empty list, so a document with no
                # sentences gets an explicit empty tensor instead.
                # Assumes convertToIdx yields LongTensors -- TODO confirm.
                res.append(torch.LongTensor())
        return res

    def size(self):
        """Return the size of the fitted vocabulary (``vocab.size()``).

        Raises
        ------
        NotFittedError
            If ``fit`` has not been called.
        """
        self._require_fitted()
        return self.vocab.size()
