from time import time
import numpy as np
from models.text_cnn import TextR, TextO, u_TextO, TextD
from models.text_rnn import TextRNNR, TextRNNO, u_TextRNNO, TextRNND
from models.fudan import Fudan

def timing(f):
    """Decorator that prints how long each call to ``f`` takes.

    The elapsed time is measured with ``time.time()`` around the call and
    printed as ``# Time <function name>: <seconds>``. The wrapped function's
    return value is passed through unchanged. If ``f`` raises, the exception
    propagates and nothing is printed.

    Args:
        f: The callable to time.

    Returns:
        A wrapper with the same call signature and metadata as ``f``.
    """
    # Imported locally so this block does not depend on a module-level import.
    from functools import wraps

    @wraps(f)  # keep f.__name__ / f.__doc__ visible on the decorated function
    def wrapper(*args, **kwargs):
        start = time()
        result = f(*args, **kwargs)
        print(f'# Time {f.__name__}: {time() - start}')
        return result
    return wrapper


def get_model(cfg, word_embeddings):
    """Build the model (or dictionary of sub-models) selected by ``cfg``.

    Args:
        cfg: Configuration object. This function reads ``cfg.task``,
            ``cfg.model.name``, ``cfg.data['tasks']`` and ``cfg.exp``.
        word_embeddings: Passed through to the ``Fudan`` model or to the
            shared representation network.

    Returns:
        * A ``Fudan`` instance when ``cfg.task == 'fudan'``.
        * Otherwise a ``dict`` with a ``'rep'`` entry (shared representation
          network), one entry per task in ``cfg.data['tasks']`` (output
          head), and a ``'dis'`` entry when ``cfg.exp['dis']`` is present
          and truthy.

    Raises:
        ValueError: If ``cfg.task`` is not ``'fudan'`` and ``cfg.model.name``
            is neither ``'textcnn'`` nor ``'lstm'``.
    """
    if cfg.task == 'fudan':
        return Fudan(cfg, word_embeddings)

    # The 'uncertain' task uses the u_* variants of the output heads.
    uncertain = cfg.task == 'uncertain'
    name = cfg.model.name
    if name == 'textcnn':
        rep_cls, dis_cls = TextR, TextD
        head_cls = u_TextO if uncertain else TextO
    elif name == 'lstm':
        rep_cls, dis_cls = TextRNNR, TextRNND
        head_cls = u_TextRNNO if uncertain else TextRNNO
    else:
        # Previously this fell through to `return model` with `model` unbound.
        raise ValueError(f'Unsupported model name: {name!r}')

    model = {'rep': rep_cls(cfg, word_embeddings)}
    for t in cfg.data['tasks']:
        model[t] = head_cls(cfg)
    if 'dis' in cfg.exp and cfg.exp['dis']:
        model['dis'] = dis_cls(cfg)
    return model


def gradient_normalizers(grads, losses, normalization_type):
    """Compute one normalization factor per task.

    Args:
        grads: Mapping from task key to an iterable of gradient objects.
            Each object must support ``.pow(2).sum().item()`` (presumably
            torch tensors; confirm against the caller).
        losses: Mapping from task key to that task's loss value. Only read
            for the ``'loss'`` and ``'loss+'`` modes.
        normalization_type: One of:
            * ``'l2'``    - L2 norm over all of the task's gradients.
            * ``'loss'``  - the task's loss.
            * ``'loss+'`` - the task's loss times the L2 gradient norm.
            * ``'none'``  - the constant ``1.0``.

    Returns:
        A dict mapping every key of ``grads`` to its normalization factor.

    Raises:
        ValueError: If ``normalization_type`` is not one of the values above.
    """
    def l2_norm(task):
        # Square root of the summed squared entries of all the task's gradients.
        return np.sqrt(np.sum([gr.pow(2).sum().item() for gr in grads[task]]))

    if normalization_type == 'l2':
        return {t: l2_norm(t) for t in grads}
    if normalization_type == 'loss':
        return {t: losses[t] for t in grads}
    if normalization_type == 'loss+':
        return {t: losses[t] * l2_norm(t) for t in grads}
    if normalization_type == 'none':
        return {t: 1.0 for t in grads}
    # Fail loudly instead of printing and returning an empty dict, which
    # would only surface later as a KeyError in the caller.
    raise ValueError(f'Invalid Normalization Type: {normalization_type!r}')

