import tensorflow as tf

from BERT.BERT_wordpieces import BERTwordpieces
from BERT.DefBERT_CLS import DefBERTCLS
from BERT.DefBERT import DefBERT
from baselines.Additive import Additive
from baselines.FastTextModel import FastTextModel
from baselines.HeadModel import HeadModel
from baselines.W2VModel import W2VModel
from defiNNet.DefiNNet import DefiNNet


def model_by(coordinates, entity_on_error=True):
    """Build the model identified by ``coordinates``.

    Args:
        coordinates: A ``(model_name, pretrained_model_path)`` pair.
            ``model_name`` selects which model class to construct.
            ``pretrained_model_path`` is forwarded to the Additive, HeadModel,
            W2VModel, FastTextModel and DefiNNet constructors. The BERT-based
            models ignore it and always use ``'bert-base-uncased'``.
        entity_on_error: Forwarded to ``W2VModel`` only; unused otherwise.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``model_name`` is not a supported model name. Before
            this change the function silently returned ``None``, which only
            failed later and far from the cause.
    """
    model_name, pretrained_model_path = coordinates

    if model_name == 'additive_model':
        return Additive(pretrained_model_path, binary=True)

    if model_name == 'head_model':
        return HeadModel(pretrained_model_path, binary=True)

    if model_name == 'w2v':
        return W2VModel(pretrained_model_path, binary=True, entity_on_error=entity_on_error)

    if model_name == 'fasttext':
        return FastTextModel(pretrained_model_path)

    if model_name == 'definnet':
        denn = tf.keras.models.load_model('denn.h5')
        return DefiNNet(denn, pretrained_model_path)

    if model_name.startswith('ITERATIVE_denn'):
        # The iteration index is the last '_'-separated token with anything
        # after the first '.' stripped, e.g. 'ITERATIVE_denn_3.h5' -> '3'.
        iteration = model_name.split('_')[-1].split('.')[0]
        denn = tf.keras.models.load_model('data/models/denn_' + iteration + '.h5')
        return DefiNNet(denn, pretrained_model_path)

    if model_name == 'defBERT':
        return DefBERT('bert-base-uncased')

    if model_name == 'BERT_wordpieces':
        return BERTwordpieces('bert-base-uncased')

    if model_name == 'defBERT_CLS':
        return DefBERTCLS('bert-base-uncased')

    raise ValueError('Unknown model name: {!r}'.format(model_name))


