from torch.optim import Adam, SGD


def create_optimizer(base_net, args):
    """Build an Adam optimizer over the trainable parameters of ``base_net``.

    One parameter group is created per submodule, in this order:
    ``encoder`` (learning rate ``args.encoder_learning_rate``), then
    ``emission_net`` and ``crf_decoder`` (both ``args.learning_rate``).
    Parameters whose ``requires_grad`` flag is false are left out.
    ``args.gamma`` is passed as ``weight_decay`` and AMSGrad is enabled.

    Returns:
        The configured ``Adam`` instance.
    """
    # Pair each submodule with the learning rate its group should use.
    module_rates = (
        (base_net.encoder, args.encoder_learning_rate),
        (base_net.emission_net, args.learning_rate),
        (base_net.crf_decoder, args.learning_rate),
    )
    param_groups = [
        {'params': [w for w in module.parameters() if w.requires_grad], 'lr': rate}
        for module, rate in module_rates
    ]
    return Adam(
        param_groups,
        lr=args.learning_rate,
        weight_decay=args.gamma,
        amsgrad=True,
    )


def create_optimizer_protonet(proto_net, args):
    """Build an Adam optimizer over the trainable parameters of ``proto_net``.

    One parameter group is created per submodule, in this order:
    ``encoder`` (learning rate ``args.encoder_learning_rate``), then
    ``protonet_encoder`` and ``trans_nn`` (both ``args.learning_rate``).
    Parameters whose ``requires_grad`` flag is false are left out.
    ``args.gamma`` is passed as ``weight_decay`` and AMSGrad is enabled.

    Returns:
        The configured ``Adam`` instance.
    """
    # Pair each submodule with the learning rate its group should use.
    module_rates = (
        (proto_net.encoder, args.encoder_learning_rate),
        (proto_net.protonet_encoder, args.learning_rate),
        (proto_net.trans_nn, args.learning_rate),
    )
    param_groups = [
        {'params': [w for w in module.parameters() if w.requires_grad], 'lr': rate}
        for module, rate in module_rates
    ]
    return Adam(
        param_groups,
        lr=args.learning_rate,
        weight_decay=args.gamma,
        amsgrad=True,
    )


def create_meta_optimizer_protonet(proto_net, args):
    """Build an Adam optimizer for ``proto_net`` from the ``meta_*`` settings.

    One parameter group is created per submodule, in this order:
    ``encoder`` (learning rate ``args.meta_encoder_learning_rate``), then
    ``protonet_encoder`` and ``trans_nn`` (both ``args.meta_learning_rate``).
    Parameters whose ``requires_grad`` flag is false are left out.
    ``args.meta_gamma`` is passed as ``weight_decay`` and AMSGrad is enabled.

    Returns:
        The configured ``Adam`` instance.
    """
    # Pair each submodule with the learning rate its group should use.
    module_rates = (
        (proto_net.encoder, args.meta_encoder_learning_rate),
        (proto_net.protonet_encoder, args.meta_learning_rate),
        (proto_net.trans_nn, args.meta_learning_rate),
    )
    param_groups = [
        {'params': [w for w in module.parameters() if w.requires_grad], 'lr': rate}
        for module, rate in module_rates
    ]
    return Adam(
        param_groups,
        lr=args.meta_learning_rate,
        weight_decay=args.meta_gamma,
        amsgrad=True,
    )


def create_outer_optimizer(base_net, args):
    """Build an SGD optimizer for ``base_net`` from the ``meta_*`` settings.

    One parameter group is created per submodule, in this order:
    ``encoder`` (learning rate ``args.meta_encoder_learning_rate``), then
    ``emission_net`` and ``crf_decoder`` (both ``args.meta_learning_rate``).
    Parameters whose ``requires_grad`` flag is false are left out.
    ``args.meta_gamma`` is passed as ``weight_decay``.

    Returns:
        The configured ``SGD`` instance.
    """
    # Pair each submodule with the learning rate its group should use.
    module_rates = (
        (base_net.encoder, args.meta_encoder_learning_rate),
        (base_net.emission_net, args.meta_learning_rate),
        (base_net.crf_decoder, args.meta_learning_rate),
    )
    param_groups = [
        {'params': [w for w in module.parameters() if w.requires_grad], 'lr': rate}
        for module, rate in module_rates
    ]
    return SGD(
        param_groups,
        lr=args.meta_learning_rate,
        weight_decay=args.meta_gamma,
    )


def create_outer_optimizer_protonet(base_net, args):
    """Build an SGD optimizer for a protonet model from the ``meta_*`` settings.

    One parameter group is created per submodule of ``base_net``, in this
    order: ``encoder`` (learning rate ``args.meta_encoder_learning_rate``),
    then ``protonet_encoder`` and ``trans_nn`` (both
    ``args.meta_learning_rate``). Parameters whose ``requires_grad`` flag is
    false are left out. ``args.meta_gamma`` is passed as ``weight_decay``.

    Returns:
        The configured ``SGD`` instance.
    """
    # Pair each submodule with the learning rate its group should use.
    module_rates = (
        (base_net.encoder, args.meta_encoder_learning_rate),
        (base_net.protonet_encoder, args.meta_learning_rate),
        (base_net.trans_nn, args.meta_learning_rate),
    )
    param_groups = [
        {'params': [w for w in module.parameters() if w.requires_grad], 'lr': rate}
        for module, rate in module_rates
    ]
    return SGD(
        param_groups,
        lr=args.meta_learning_rate,
        weight_decay=args.meta_gamma,
    )
