import time
import sys
from six.moves import cPickle
import os
import torch
import numpy as np
from nlptext.sentence import Sentence
from pprint import pprint

try:
    import tensorflow as tf
except ImportError:
    print("Tensorflow not installed; No tensorboard logging.")
    tf = None


# 1. From FieldSequence_Para to fldseq_dirname.
# 2. Save the corresponding hyper-information, such as the raw input and raw target.

def add_summary_value(writer, key, value, iteration):
    """Log a single scalar value to a TensorBoard event writer.

    Args:
        writer: Object exposing ``add_summary(summary, step)`` (presumably a
            TF1-style ``FileWriter`` -- confirm with callers).
        key: Tag under which the scalar is recorded.
        value: Scalar value to record.
        iteration: Step number passed to ``writer.add_summary``.

    When TensorFlow could not be imported (``tf is None``), this is a no-op
    instead of failing with an AttributeError on ``None``.
    """
    if tf is None:
        return
    # ``tf.Summary`` exists only in TF1; TF2 keeps it under ``tf.compat.v1``.
    summary_cls = getattr(tf, "Summary", None)
    if summary_cls is None:
        summary_cls = tf.compat.v1.Summary
    summary = summary_cls(value=[summary_cls.Value(tag=key, simple_value=value)])
    writer.add_summary(summary, iteration)


def lr_decay(optimizer, epoch, lr_decay_rate, init_lr):
    """Set an inverse-time-decayed learning rate on every parameter group.

    The rate written is ``init_lr / (1 + lr_decay_rate * epoch)``; it
    overwrites the ``'lr'`` entry of each dict in ``optimizer.param_groups``.

    Returns:
        The same ``optimizer`` object, modified in place.
    """
    denominator = 1 + lr_decay_rate * epoch
    new_lr = init_lr / denominator
    for group in optimizer.param_groups:
        group.update(lr=new_lr)
    return optimizer


def lr_decay_scale(optimizer, lr_decay_rate, current_lr):
    """Shrink the learning rate multiplicatively and apply it to the optimizer.

    The new rate is ``current_lr * (1 - lr_decay_rate)``. It overwrites the
    ``'lr'`` entry of each dict in ``optimizer.param_groups``.

    Returns:
        Tuple ``(optimizer, new_lr)``: the same optimizer object (modified in
        place) and the rate that was written to it.
    """
    keep_fraction = 1 - lr_decay_rate
    new_lr = current_lr * keep_fraction
    for group in optimizer.param_groups:
        group.update(lr=new_lr)
    return optimizer, new_lr


def build_optimizer(model, TRAIN, lr=None):
    """Create a torch optimizer over the trainable parameters of ``model``.

    Args:
        model: Module whose parameters with ``requires_grad`` set are optimized;
            frozen parameters are skipped.
        TRAIN: Training-config mapping. Always reads ``'optim_method'`` (one of
            sgd / adagrad / adadelta / rmsprop / adam, case-insensitive),
            ``'l2'`` (passed as ``weight_decay``) and ``'lr_decay_rate'``;
            ``'momentum'`` is read only for SGD. When ``lr`` is None it also
            reads ``'lr_warm_up'``, ``'peak_lr'`` and, if warm-up is enabled,
            ``'lr_warm_up_steps'``.
        lr: Initial learning rate. If None, it is ``peak_lr`` or, with warm-up
            enabled, ``peak_lr / lr_warm_up_steps`` (presumably the rate for
            the first step of a linear warm-up).

    Returns:
        Tuple ``(optim_method, lr, lr_decay_rate, optimizer)``, where
        ``optim_method`` is returned exactly as spelled in ``TRAIN``.

    Raises:
        ValueError: If ``optim_method`` is not a supported optimizer name.
    """
    if lr is None:
        if TRAIN['lr_warm_up']:
            lr = TRAIN['peak_lr'] / TRAIN['lr_warm_up_steps']
        else:
            lr = TRAIN['peak_lr']

    optim_method = TRAIN['optim_method']
    l2 = TRAIN['l2']
    lr_decay_rate = TRAIN['lr_decay_rate']

    # Frozen parameters (requires_grad=False) are excluded from optimization.
    params = [p for p in model.parameters() if p.requires_grad]

    # Optimizers that only need lr and weight decay; SGD also takes momentum.
    simple_optimizers = {
        "adagrad": torch.optim.Adagrad,
        "adadelta": torch.optim.Adadelta,
        "rmsprop": torch.optim.RMSprop,
        "adam": torch.optim.Adam,
    }

    method = optim_method.lower()
    if method == "sgd":
        optimizer = torch.optim.SGD(params, lr=lr, momentum=TRAIN['momentum'], weight_decay=l2)
    elif method in simple_optimizers:
        optimizer = simple_optimizers[method](params, lr=lr, weight_decay=l2)
    else:
        # Previously this only printed and then failed with UnboundLocalError.
        raise ValueError("Optimizer illegal: %s" % (optim_method))
    return optim_method, lr, lr_decay_rate, optimizer
