# -*- coding: utf-8 -*-

import itertools
from collections import Counter
from typing import Dict, List

import numpy as np

from ..common.logger import LOGGER, smart_open_file

# Marker that starts the header line of every serialized vocabulary; written
# by Vocabulary.to_file and checked by Vocabulary.from_file.
VOCAB_PREFIX = '##########'


class Tokens:
    """Special marker strings shared by the vocabularies and lookup helpers in this module."""

    # Word-level markers.
    UNK = '___UNK___'  # entry whose id is used as the fallback for unknown words
    PAD = '___PAD___'  # padding entry; part of the default initial tokens
    SOS = '___SOS___'  # placed before the words by lookup_words(sos_and_eos=True)
    EOS = '___EOS___'  # placed after the words by lookup_words(sos_and_eos=True)
    # NOTE(review): MARK has two leading underscores while every other marker
    # has three -- confirm whether that is intentional before changing it, the
    # literal may already be stored in saved vocabularies.
    MARK = '__MARK___'
    EMPTY = '___EMPTY___'  # not referenced in this part of the file
    # Character-level markers: single control characters (code points 2..5).
    CHAR_SOS = '\02'
    CHAR_SOSx3 = CHAR_SOS * 3  # pseudo-word standing in for start-of-sentence in lookup_characters
    CHAR_EOS = '\03'
    CHAR_EOSx3 = CHAR_EOS * 3  # pseudo-word standing in for end-of-sentence in lookup_characters
    SOW = '\04'  # prepended to every word by lookup_characters(sow_and_eow=True)
    EOW = '\05'  # appended to every word by lookup_characters(sow_and_eow=True)
    SOT = '___SOT___'  # presumably "start of token", mirroring EOT -- not referenced here
    EOT = '___EOT___'  # end of token


class Vocabulary:
    """Ordered word <-> id mapping with optional per-word occurrence counts.

    Ids are assigned in insertion order.  The first ``_offset`` entries are
    the special tokens passed as ``initial`` to the constructor.
    """

    def __init__(self, name='anonymous', initial=(Tokens.PAD, Tokens.UNK)):
        """Create a vocabulary called ``name`` seeded with the ``initial`` tokens.

        ``initial`` must be a sized iterable of distinct words; they receive
        the ids ``0 .. len(initial) - 1``.
        """
        self.name = name

        self._words: List[str] = []
        self._word_to_id: Dict[str, int] = {}
        self._words_count = Counter()

        for token in initial:
            self._add(token)
        self._offset = len(initial)

    @property
    def special_tokens(self):
        """The initial (special) tokens, in id order."""
        return self._words[:self._offset]

    def __contains__(self, key):
        return key in self._word_to_id

    def __iter__(self):
        return iter(self._words)

    def __len__(self):
        return len(self._words)

    def __repr__(self):
        return f'Vocab{{{self.name} size={len(self)} {self.special_tokens}}}'

    def _add(self, word):
        """Register a word that is not in the vocabulary yet and return its id."""
        assert word not in self._word_to_id

        self._word_to_id[word] = index = len(self._words)
        self._words.append(word)

        return index

    def _extend_with_counts(self, items):
        """Record ``(word, count)`` pairs, registering each word at most once.

        Words that are already present (for instance special tokens that also
        carry a count) keep their existing id instead of being appended again.
        """
        for word, count in items:
            self._words_count[word] = count
            if word not in self._word_to_id:
                self._add(word)

    def total_count(self):
        """Sum of all recorded occurrence counts."""
        return sum(self._words_count.values())

    def count(self, word):
        """Recorded occurrence count of ``word`` (0 when it has none)."""
        return self._words_count.get(word, 0)

    def iter_counts(self):
        """Yield ``(word, count)`` for every word, in id order."""
        for word in self._words:
            yield word, self._words_count[word]

    def update(self, iterable):
        """Add every item of ``iterable``.

        A ``str`` item is added with the default count; any other item is
        unpacked into the arguments of :meth:`add`, e.g. ``(word, count)``.
        """
        for item in iterable:
            if isinstance(item, str):
                self.add(item)
            else:
                self.add(*item)

    def add(self, word, count=1):
        """Add ``word`` if it is new, bump its count by ``count`` and return its id.

        Pass ``count=None`` to register the word without touching the counts.
        """
        index = self._word_to_id.get(word)
        if index is None:
            index = self._add(word)

        if count is not None:
            self._words_count[word] += count

        return index

    @property
    def lookup(self):
        """The internal word -> id dict (the live object, not a copy)."""
        return self._word_to_id

    @property
    def pad_id(self):
        return self._word_to_id[Tokens.PAD]

    @property
    def eos_id(self):
        return self._word_to_id[Tokens.EOS]

    @property
    def sos_id(self):
        return self._word_to_id[Tokens.SOS]

    @property
    def unk_id(self):
        return self._word_to_id[Tokens.UNK]

    def vocab_or_dict(self, word, default_id=None):
        """Return the id of ``word``, or ``default_id`` when the word is unknown.

        Raises:
            KeyError: if the word is unknown and ``default_id`` is None.
        """
        index = self._word_to_id.get(word, default_id)
        if index is None:
            raise KeyError(f'"{word}" not found in {self}')
        return index

    def id_to_word(self, word_id):
        """Return the word stored at ``word_id``."""
        return self._words[word_id]

    def ids_to_words(self, word_ids, truncate_to_eos=False):
        """Map a sequence of ids back to words.

        With ``truncate_to_eos`` the sequence is cut right before the first
        EOS id; the vocabulary must then contain ``Tokens.EOS``.
        """
        if truncate_to_eos:
            eos = self._word_to_id.get(Tokens.EOS)
            assert eos is not None
            k = 0
            while k < len(word_ids) and word_ids[k] != eos:
                k += 1
            word_ids = word_ids[:k]
        return [self.id_to_word(i) for i in word_ids]

    def word_to_id(self, word, default=None):
        """Return the id of ``word``, or ``default`` when it is unknown."""
        return self._word_to_id.get(word, default)

    def words_to_ids(self, words, unk_id=None):
        """Map words to ids, using ``unk_id`` for unknown words.

        When ``unk_id`` is None the id of ``Tokens.UNK`` is used, so that
        token must be in the vocabulary.
        """
        if unk_id is None:
            unk_id = self._word_to_id[Tokens.UNK]

        return [self._word_to_id.get(word, unk_id) for word in words]

    def state_dict(self):
        """Return the state as a plain dict.

        The values are the live internal containers, not copies.
        """
        return {
            'name': self.name,
            '_offset': self._offset,
            '_words': self._words,
            '_word_to_index': self._word_to_id,
            '_words_count': self._words_count
        }

    def load_state_dict(self, state):
        """Adopt the state produced by :meth:`state_dict` and return ``self``."""
        self.name = state['name']
        self._offset = state['_offset']
        self._words = state['_words']
        self._word_to_id = state['_word_to_index']
        self._words_count = state['_words_count']
        return self

    def _normalize_initial(self, initial):
        """
        If `initial' is True, new vocabulary will use original initial tokens,
        If `initial' is None or False, new vocabulary will have no initial tokens,
        If `initial' is a list, then it will be passed to the constructor of Vocabulary.
        """
        if initial is None or initial is False:
            return ()

        if initial is True:
            return self.special_tokens

        assert isinstance(initial, (list, tuple))
        return initial

    def copy_top_k(self, k, name=None, initial=True):
        """Return a new vocabulary holding only the ``k`` most frequent words.

        Words follow the special tokens in decreasing count order, as given by
        ``Counter.most_common``.  See :meth:`_normalize_initial` for the
        meaning of ``initial``.
        """
        if name is None:
            name = f'{self.name}-top-{k}'

        vocab = Vocabulary(name, initial=self._normalize_initial(initial))
        # Only the selected words are copied; copying every word of `self`
        # would leave the "top-k" vocabulary as large as the original.
        vocab._extend_with_counts(self._words_count.most_common(k))
        return vocab

    def copy_without_low_frequency(self, threshold=2, name=None, initial=True):
        """Return a new vocabulary with the words whose count is >= ``threshold``.

        See :meth:`_normalize_initial` for the meaning of ``initial``.
        """
        if name is None:
            name = f'{self.name}>={threshold}'

        vocab = Vocabulary(name, initial=self._normalize_initial(initial))
        vocab._extend_with_counts(
            (word, count) for word, count in self._words_count.items() if count >= threshold
        )
        return vocab

    def to_file(self, path_or_fp):
        """Write a header line followed by one ``word<TAB>count`` line per word."""
        with smart_open_file(path_or_fp, 'w') as fp:
            fp.write(f'{VOCAB_PREFIX} {self.name} {self._offset} {len(self)}\n')
            for word in self._words:
                fp.write(f'{word}\t{self._words_count[word]}\n')

    @classmethod
    def from_file(cls, path_or_fp):
        """Read one vocabulary in the format written by :meth:`to_file`.

        Raises:
            AssertionError: if the next line is not a vocabulary header (this
                includes end of file).  VocabularySet.from_file relies on this
                exception type to detect the end of its input, so it is raised
                explicitly rather than through an ``assert`` statement, which
                ``python -O`` would strip.
        """
        with smart_open_file(path_or_fp, 'r') as fp:
            header = fp.readline()
            if not header.startswith(VOCAB_PREFIX):
                raise AssertionError(f'expected a vocabulary header, got {header!r}')

            # Split from the right so that names containing spaces survive.
            prefix_and_name, offset, size = header.strip().rsplit(None, 2)
            name = prefix_and_name[len(VOCAB_PREFIX):].strip()
            offset, size = int(offset), int(size)

            items = []
            while len(items) < size:
                line = fp.readline().strip('\n')
                if not line:
                    break
                # The count is always the last field, so a tab inside the word
                # itself does not break the parse.
                word, count = line.rsplit('\t', 1)
                items.append((word, int(count)))

        vocab = cls(name, initial=[word for word, _ in items[:offset]])
        for word, count in items:
            vocab.add(word, count)

        return vocab


class VocabularySet:
    """A collection of vocabularies addressed by name."""

    def __init__(self):
        self._vocabs = {}

    def __repr__(self):
        body = ',\n '.join(map(str, self._vocabs.values()))
        return '{' + body + '}'

    def state_dict(self):
        """Return ``{name: vocabulary state}`` for every stored vocabulary."""
        states = {}
        for key, vocab in self._vocabs.items():
            states[key] = vocab.state_dict()
        return states

    def load_state_dict(self, states):
        """Replace the current content with vocabularies rebuilt from ``states``; return ``self``."""
        self._vocabs.clear()
        for key, vocab_state in states.items():
            self._vocabs[key] = Vocabulary().load_state_dict(vocab_state)
        return self

    def get(self, name):
        """Return the vocabulary stored under ``name``, or None."""
        return self._vocabs.get(name)

    def set(self, vocab):
        """Store ``vocab`` under its own ``name``, replacing any previous entry."""
        self._vocabs[vocab.name] = vocab

    def get_or_new(self, name, **kwargs):
        """Return the vocabulary called ``name``, creating it via :meth:`new` when absent."""
        existing = self._vocabs.get(name)
        return existing if existing is not None else self.new(name, **kwargs)

    def new(self, name, **kwargs):
        """Create, store and return a fresh ``Vocabulary(name, **kwargs)``."""
        created = Vocabulary(name, **kwargs)
        self._vocabs[name] = created
        return created

    def to_file(self, path):
        """Write every stored vocabulary, one after another, to ``path``."""
        with smart_open_file(path, 'w') as fp:
            for vocab in self._vocabs.values():
                vocab.to_file(fp)

    @classmethod
    def from_file(cls, path):
        """Read vocabularies from ``path`` until no further header is found."""
        result = cls()
        with smart_open_file(path, 'r') as fp:
            while True:
                # Vocabulary.from_file signals "no header here" (which includes
                # end of file) with an AssertionError.
                try:
                    loaded = Vocabulary.from_file(fp)
                except AssertionError:
                    break
                result.set(loaded)
        return result


def smart_remove_low_frequency_words(vocab: 'Vocabulary',
                                     min_coverage=0.95, high_threshold=50, remove_and_copy_fn=None,
                                     verbose=True):
    """Drop rare words while keeping at least ``min_coverage`` of all occurrences.

    Binary-searches the largest frequency threshold in ``[0, high_threshold]``
    for which the filtered vocabulary still accounts for at least
    ``min_coverage`` of ``vocab.total_count()``.

    Args:
        vocab: the vocabulary to filter; it is not modified.
        min_coverage: minimal fraction of the total count to retain.
        high_threshold: upper bound of the search (negative values act as 0).
        remove_and_copy_fn: ``fn(vocab, threshold=...)`` returning the filtered
            copy; defaults to ``Vocabulary.copy_without_low_frequency``.
        verbose: log the search progress through ``LOGGER``.

    Returns:
        The filtered copy, renamed to ``vocab.name``.
    """
    if remove_and_copy_fn is None:
        remove_and_copy_fn = Vocabulary.copy_without_low_frequency

    total_count = vocab.total_count()

    low, high = 0, max(high_threshold, 0)
    if total_count == 0:
        # Coverage is undefined without any recorded occurrence (it would be a
        # division by zero); skip the search and filter with threshold 0.
        high = low

    if verbose:
        LOGGER.info('Find best min_frequency in range %s for %s and coverage>=%.4f',
                    (low, high), vocab, min_coverage)

    best_vocab = None
    while low != high:
        # Round up so the loop always makes progress when `low` is accepted.
        mid = (low + high + 1) >> 1

        candidate = remove_and_copy_fn(vocab, threshold=mid)
        coverage = candidate.total_count() / total_count

        if verbose:
            LOGGER.info('Try min_frequency=%d => coverage=%.4f', mid, coverage)

        if coverage >= min_coverage:
            # Remember the copy built for the accepted threshold so that it
            # does not have to be rebuilt after the search.
            low, best_vocab = mid, candidate
        else:
            high = mid - 1

    if verbose:
        LOGGER.info('Use min_frequency=%d', low)

    if best_vocab is None:
        # No threshold above 0 was accepted (or the search was skipped).
        best_vocab = remove_and_copy_fn(vocab, threshold=low)
    best_vocab.name = vocab.name
    return best_vocab


def lookup_words(words, vocab_or_dict, default_id=None,
                 sos_and_eos=False,
                 tensor_fn=np.zeros,
                 dtype=np.int64):
    """Convert a sequence of words into a 1-D tensor of ids.

    Args:
        words: sized sequence of words.
        vocab_or_dict: a ``Vocabulary`` or a plain word -> id dict.
        default_id: id used for unknown words; defaults to the id of
            ``Tokens.UNK``, which must then be present in the mapping.
        sos_and_eos: surround the ids with the ids of ``Tokens.SOS`` and
            ``Tokens.EOS`` (the result is two elements longer).
        tensor_fn: ``fn(shape, dtype=...)`` allocating the result.
        dtype: dtype forwarded to ``tensor_fn``.

    Returns:
        The tensor created by ``tensor_fn``, filled with the ids.
    """
    if isinstance(vocab_or_dict, Vocabulary):
        vocab_or_dict = vocab_or_dict.lookup

    if default_id is None:
        default_id = vocab_or_dict[Tokens.UNK]

    first = 1 if sos_and_eos else 0
    word_ids = tensor_fn((len(words) + 2 * first,), dtype=dtype)

    for index, word in enumerate(words, first):
        word_ids[index] = vocab_or_dict.get(word, default_id)

    if sos_and_eos:
        word_ids[0] = vocab_or_dict[Tokens.SOS]
        # Address the last slot directly: deriving it from the loop index
        # points at slot 0 when `words` is empty and would overwrite SOS.
        word_ids[len(words) + 1] = vocab_or_dict[Tokens.EOS]

    return word_ids


def lookup_characters(words, vocab_or_dict, default_id=None,
                      max_word_length=20,
                      sow_and_eow=True,
                      sos_and_eos=False,
                      tensor_fn=np.zeros,
                      dtype=np.int64,
                      return_lengths=False):
    """Convert words into a ``(num_words, max_word_length)`` tensor of character ids.

    Args:
        words: sized sequence of words.
        vocab_or_dict: a ``Vocabulary`` or a plain character -> id dict.
        default_id: id used for unknown characters; defaults to the id of
            ``Tokens.UNK``, which must then be present in the mapping.
        max_word_length: number of columns of the result (at least 10).  Longer
            words are shortened to their head and tail joined by ``'..'``.
        sow_and_eow: wrap every word in ``Tokens.SOW`` / ``Tokens.EOW``; the
            two markers count towards ``max_word_length``.
        sos_and_eos: add the pseudo-words ``Tokens.CHAR_SOSx3`` and
            ``Tokens.CHAR_EOSx3`` before and after the words.
        tensor_fn: ``fn(shape, dtype=...)`` allocating the results.
        dtype: dtype forwarded to ``tensor_fn``.
        return_lengths: also return the per-word character counts.

    Returns:
        ``char_ids``, or ``(char_lengths, char_ids)`` when ``return_lengths``
        is set.
    """
    assert max_word_length >= 10

    if isinstance(vocab_or_dict, Vocabulary):
        vocab_or_dict = vocab_or_dict.lookup

    if default_id is None:
        default_id = vocab_or_dict[Tokens.UNK]

    length = len(words)
    if sos_and_eos:
        length += 2
        # Wrap the markers in lists: chaining the bare strings would yield
        # their characters one by one, i.e. six extra rows instead of two.
        words = itertools.chain([Tokens.CHAR_SOSx3], words, [Tokens.CHAR_EOSx3])

    char_ids = tensor_fn((length, max_word_length), dtype=dtype)
    char_lengths = tensor_fn((length, ), dtype=dtype)

    # Characters available for the word itself once the markers are reserved.
    max_chars = max_word_length - ((len(Tokens.SOW) + len(Tokens.EOW)) if sow_and_eow else 0)

    ellipsis = '..'
    split_point = (max_chars - len(ellipsis)) // 2
    for index, word in enumerate(words):
        if len(word) > max_chars:
            word = word[:split_point] + ellipsis + word[-split_point:]
        if sow_and_eow:
            word = Tokens.SOW + word + Tokens.EOW

        char_lengths[index] = len(word)
        for char_index, char in enumerate(word):
            # The mapping is a dict at this point, so it has to be queried
            # with .get(); calling it raises TypeError.
            char_ids[index, char_index] = vocab_or_dict.get(char, default_id)

    if return_lengths:
        return char_lengths, char_ids
    return char_ids
