# -*- coding: utf-8 -*-

from ..data.vocab import Vocabulary


def convert_words_to_ids(source_words, target_words, common_vocab: Vocabulary):
    """Map source/target words to ids, extending the vocabulary with source OOVs.

    Every source word that ``common_vocab`` maps to its UNK id is registered in
    a per-example OOV vocabulary and given the extended id
    ``len(common_vocab) + oov_words.add(word)``. A target word that maps to UNK
    receives the same extended id when it is also one of the source OOVs;
    otherwise it keeps the UNK id.

    Args:
        source_words: sequence of source tokens.
        target_words: sequence of target tokens, or ``None`` to skip target
            conversion.
        common_vocab: the shared vocabulary.

    Returns:
        ``((source_ids, source_ids_ext), (target_ids, target_ids_ext), oov_words)``
        where ``source_ids`` / ``target_ids`` are the values returned by
        ``common_vocab.words_to_ids``, ``*_ids_ext`` are lists in which UNK ids
        of source OOV words are replaced by extended ids, and ``oov_words`` is
        the per-example OOV ``Vocabulary``. Both target entries are ``None``
        when ``target_words`` is ``None``.
    """
    # NOTE(review): assumes Vocabulary(initial=()) starts with no entries, so
    # extended ids begin right after the common vocabulary -- TODO confirm.
    oov_words = Vocabulary(initial=())
    unk_id = common_vocab.unk_id

    source_ids = common_vocab.words_to_ids(source_words)
    source_ids_ext = []
    for word, word_id in zip(source_words, source_ids):
        if word_id == unk_id:
            # presumably add() returns the word's index in oov_words (the
            # existing one for repeats) -- verify against Vocabulary.add
            word_id = len(common_vocab) + oov_words.add(word)
        source_ids_ext.append(word_id)

    if target_words is None:
        return (source_ids, source_ids_ext), (None, None), oov_words

    target_ids = common_vocab.words_to_ids(target_words)
    target_ids_ext = []
    for word, word_id in zip(target_words, target_ids):
        if word_id == unk_id:
            oov_id = oov_words.word_to_id(word, -1)
            if oov_id != -1:  # the word is one of this example's source OOVs
                word_id = len(common_vocab) + oov_id
            # otherwise the target word is unknown to both vocabularies and
            # keeps the UNK id
        target_ids_ext.append(word_id)

    return (source_ids, source_ids_ext), (target_ids, target_ids_ext), oov_words


def convert_ids_to_words(target_ids_ext, common_vocab: Vocabulary, oov_words: Vocabulary,
                         source_words=None,
                         truncate_to_eos=False):
    """Map (possibly extended) ids back to words.

    Ids below ``len(common_vocab)`` are looked up in ``common_vocab``. Larger
    ids are treated as extended OOV ids and looked up in ``oov_words`` at index
    ``id - len(common_vocab)``; this mirrors how ``convert_words_to_ids``
    builds extended ids.

    Args:
        target_ids_ext: iterable of ids, possibly including extended OOV ids.
        common_vocab: the shared vocabulary.
        oov_words: the per-example OOV vocabulary.
        source_words: unused; kept so existing callers keep working.
        truncate_to_eos: if true, stop at the first in-vocabulary id equal to
            ``common_vocab.eos_id``; the EOS word itself is not included.

    Returns:
        list of words, in the order of ``target_ids_ext``.

    Raises:
        ValueError: when ``oov_words.id_to_word`` raises ``ValueError`` for an
            extended id, i.e. the id points past this example's source OOVs.
    """
    vocab_size = len(common_vocab)
    # eos_id is only read when truncation is requested.
    eos_id = common_vocab.eos_id if truncate_to_eos else None

    words = []
    for target_id in target_ids_ext:
        if target_id >= vocab_size:
            oov_id = target_id - vocab_size
            try:
                word = oov_words.id_to_word(oov_id)
            except ValueError as err:
                raise ValueError(f'Error: generated word corresponds to source OOV {oov_id}'
                                 f' but this example only has {len(oov_words)} source OOVs') from err
        else:
            if truncate_to_eos and target_id == eos_id:
                break
            word = common_vocab.id_to_word(target_id)
        words.append(word)

    return words
