import numpy

from transformers import PreTrainedTokenizer


class Dictionary_wordpiece(object):
    """Dictionary-style adapter around a subword tokenizer.

    Exposes ``len()``, ``in``, ``[]`` and ``idx2word`` on top of the wrapped
    tokenizer. Constructing the adapter registers ``'</s>'`` as the tokenizer's
    ``eos_token`` (this mutates the tokenizer that was passed in).
    """

    def __init__(self, tokenizer: PreTrainedTokenizer):
        """Wrap *tokenizer* and take a snapshot of its vocabulary.

        Args:
            tokenizer: tokenizer to wrap; it is modified in place by the
                ``add_special_tokens`` call below.
        """
        self.tokenizer = tokenizer
        self.tokenizer.add_special_tokens({'eos_token': '</s>'})
        # Snapshot taken after the special token was registered; it is
        # refreshed by add_word() so __len__/__contains__ never go stale.
        self.vocab = self.tokenizer.get_vocab()

    def add_word(self, token):
        """Add *token* to the tokenizer and keep the vocab snapshot in sync."""
        added = self.tokenizer.add_tokens([token])
        if added:
            # Without this refresh, len(self) and `token in self` would keep
            # reflecting the vocabulary as it was at construction time.
            self.vocab = self.tokenizer.get_vocab()

    def __len__(self):
        """Return the number of entries in the current vocabulary snapshot."""
        return len(self.vocab)

    def __contains__(self, key):
        """Return True if *key* is a token of the current vocabulary snapshot."""
        return key in self.vocab

    def __getitem__(self, item):
        """Encode *item* and return its token ids as a list.

        No special tokens are added. ``add_prefix_space=True`` is forwarded to
        the tokenizer's ``encode`` call unchanged.
        """
        ids = self.tokenizer.encode(item, add_special_tokens=False, add_prefix_space=True)
        # Normalise a bare integer result to a one-element list.
        if isinstance(ids, int):
            ids = [ids]
        return ids

    def idx2word(self, idx):
        """Decode *idx* with the tokenizer and return the result as a list.

        A non-list result from ``decode`` is wrapped in a one-element list.
        """
        tokens = self.tokenizer.decode(idx)
        if not isinstance(tokens, list):
            tokens = [tokens]
        return tokens


class Dictionary(object):
    """Word-level vocabulary with frequency counts and UNK-class fallback.

    Attributes:
        word2idx: dict mapping a word to its integer index.
        idx2word: list mapping an index back to its word.
        word2frq: dict mapping a word to the number of ``add_word`` calls.

    Index 0 is always the generic ``'<unk>'`` entry. Out-of-vocabulary words
    are mapped to a signature string (``'UNK-LC-ing'``, ``'UNK-CAPS'`` ...)
    built by ``unkify``; when that signature is itself not in the vocabulary,
    lookups fall back to ``'<unk>'``.
    """

    # Suffixes tested, in priority order, for unknown words of 5+ characters.
    _UNK_SUFFIXES = ('ed', 'ing', 'ion', 'er', 'est', 'ly', 'ity', 'y', 'al')

    def __init__(self):
        """Create a vocabulary that contains only ``'<unk>'`` at index 0."""
        self.word2idx = {'<unk>': 0}
        self.idx2word = ['<unk>']
        self.word2frq = {}

    def __len__(self):
        """Return the number of entries in the vocabulary."""
        return len(self.idx2word)

    def __getitem__(self, item):
        """Return the index of *item*; same as ``get_idx(item, -1)``."""
        return self.get_idx(item, -1)

    def get_idx(self, item, loc=-1):
        """Return the index of *item*, falling back to an UNK class.

        Args:
            item: the word to look up.
            loc: position of the word in its sentence (0 means
                sentence-initial); only used to build the UNK signature.

        Returns:
            The index of *item* if known, else the index of its UNK signature,
            else the index of ``'<unk>'`` when the signature is not in the
            vocabulary either.
        """
        idx = self.word2idx.get(item)
        if idx is not None:
            return idx
        unk = self.unkify(item, loc)
        # The signature may be missing (e.g. right after rebuild_by_freq);
        # use the generic '<unk>' entry instead of raising KeyError.
        return self.word2idx.get(unk, self.word2idx['<unk>'])

    def __contains__(self, key):
        """Return True if *key* is in the vocabulary."""
        return key in self.word2idx

    def add_word(self, word):
        """Register *word* (if new), bump its frequency and return its index."""
        if word not in self.word2idx:
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)
        self.word2frq[word] = self.word2frq.get(word, 0) + 1
        return self.word2idx[word]

    def add_unk(self, word, loc):
        """Add the UNK signature of *word* to the vocabulary.

        Nothing happens when *word* itself is already in the vocabulary or
        when its signature is already present. Frequencies are not touched.
        """
        if word in self.word2idx:
            return
        unk = self.unkify(word, loc)
        # word2idx and idx2word are always updated together, so the dict gives
        # the same membership answer as the list in O(1).
        if unk not in self.word2idx:
            self.word2idx[unk] = len(self.idx2word)
            self.idx2word.append(unk)

    def rebuild_by_freq(self, thd=2):
        """Rebuild the vocabulary from the words seen at least *thd* times.

        The vocabulary is reset to ``'<unk>'`` only and then refilled from
        ``word2frq``; previously added UNK signatures are dropped. Prints and
        returns the new vocabulary size.
        """
        self.word2idx = {'<unk>': 0}
        self.idx2word = ['<unk>']

        for word, freq in self.word2frq.items():
            if freq >= thd and word not in self.word2idx:
                self.word2idx[word] = len(self.idx2word)
                self.idx2word.append(word)

        print('Number of words:', len(self.idx2word))
        return len(self.idx2word)

    def unkify(self, word, loc):
        """Build the UNK signature string for an out-of-vocabulary *word*.

        The signature starts with ``'UNK'`` and is extended with a
        capitalisation tag (``-INITC``, ``-KNOWNLC``, ``-CAPS``, ``-LC``), a
        ``-NUM`` tag and a suffix tag (``-s``, ``-ed``, ``-ing`` ...).

        Args:
            word: the unknown word.
            loc: position of the word in its sentence; ``0`` enables the
                sentence-initial ``-INITC`` handling.

        Returns:
            The signature string; plain ``'UNK'`` for an empty *word*.
        """
        if not word:
            # Nothing to inspect; indexing word[0] below would raise.
            return 'UNK'

        sb = 'UNK'
        wlen = len(word)
        num_caps = 0
        has_digit = False
        has_lower = False
        for ch in word:
            if ch.isdigit():
                has_digit = True
            elif ch.isalpha():
                if ch.islower():
                    has_lower = True
                elif ch.istitle():
                    # NOTE(review): str.istitle() is True for any single
                    # uppercase letter, so ordinary capitals also set
                    # has_lower. Kept as-is so existing UNK classes do not
                    # change - confirm whether this is intended.
                    has_lower = True
                    num_caps += 1
                else:
                    num_caps += 1

        ch0 = word[0]
        lowered = word.lower()
        if ch0.isupper() or ch0.istitle():
            if loc == 0 and num_caps == 1:
                sb += '-INITC'
                if lowered in self.word2idx:
                    sb += '-KNOWNLC'
            else:
                sb += '-CAPS'
        elif (not ch0.isalpha()) and num_caps > 0:
            sb += '-CAPS'
        elif has_lower:
            sb += '-LC'

        if has_digit and (not has_lower):
            sb += '-NUM'

        # Dash detection is disabled, so no '-DASH' tag is ever produced.
        if lowered.endswith('s') and wlen >= 3:
            ch2 = lowered[-2]
            if ch2 not in ('s', 'i', 'u'):
                sb += '-s'
        elif wlen >= 5 and not (has_digit and num_caps > 0):
            # First matching suffix wins; order matters ('ity' before 'y').
            for suffix in self._UNK_SUFFIXES:
                if lowered.endswith(suffix):
                    sb += '-' + suffix
                    break

        return sb

    # def unkify_dyer(self, token, loc):
    #     if len(token.rstrip()) == 0:
    #         result = 'UNK'
    #     else:
    #         numCaps = 0
    #         hasDigit = False
    #         hasDash = False
    #         hasLower = False
    #         for char in token.rstrip():
    #             if char.isdigit():
    #                 hasDigit = True
    #             elif char == '-':
    #                 hasDash = True
    #             elif char.isalpha():
    #                 if char.islower():
    #                     hasLower = True
    #                 elif char.isupper():
    #                     numCaps += 1
    #         result = 'UNK'
    #         lower = token.rstrip().lower()
    #         ch0 = token.rstrip()[0]
    #         if ch0.isupper():
    #             if numCaps == 1:
    #                 result = result + '-INITC'
    #                 if lower in self.idx2word:
    #                     result = result + '-KNOWNLC'
    #             else:
    #                 result = result + '-CAPS'
    #         elif not (ch0.isalpha()) and numCaps > 0:
    #             result = result + '-CAPS'
    #         elif hasLower:
    #             result = result + '-LC'
    #         if hasDigit:
    #             result = result + '-NUM'
    #         if hasDash:
    #             result = result + '-DASH'
    #         if lower[-1] == 's' and len(lower) >= 3:
    #             ch2 = lower[-2]
    #             if not (ch2 == 's') and not (ch2 == 'i') and not (ch2 == 'u'):
    #                 result = result + '-s'
    #         elif len(lower) >= 5 and not (hasDash) and not (hasDigit and numCaps > 0):
    #             if lower[-2:] == 'ed':
    #                 result = result + '-ed'
    #             elif lower[-3:] == 'ing':
    #                 result = result + '-ing'
    #             elif lower[-3:] == 'ion':
    #                 result = result + '-ion'
    #             elif lower[-2:] == 'er':
    #                 result = result + '-er'
    #             elif lower[-3:] == 'est':
    #                 result = result + '-est'
    #             elif lower[-2:] == 'ly':
    #                 result = result + '-ly'
    #             elif lower[-3:] == 'ity':
    #                 result = result + '-ity'
    #             elif lower[-1] == 'y':
    #                 result = result + '-y'
    #             elif lower[-2:] == 'al':
    #                 result = result + '-al'
    #
    #     return result
