import torch

def get_module_by_name(modulep, names):
    """Follow a chain of attribute names starting from ``modulep``.

    Args:
        modulep: Object at which the lookup starts.
        names: Iterable of attribute names, resolved one after another with
            ``getattr`` (each name is looked up on the result of the
            previous one).

    Returns:
        The object reached after resolving every name. When ``names`` is
        empty, ``modulep`` itself is returned.

    Raises:
        AttributeError: If a name along the path does not exist.
    """
    current = modulep
    for name in names:
        current = getattr(current, name)
    # Returning the cursor (not a loop-only local) keeps an empty path valid.
    return current

def add_noise(w1, w2, sigma_method, sigma_theta):
    """Add the perturbation ``w2`` to ``w1``, optionally clamping sign flips.

    Args:
        w1: Original tensor.
        w2: Perturbation added element-wise to ``w1``.
        sigma_method: Selects how the sum is post-processed.
            ``0`` returns ``w1 + w2`` unchanged.
            ``1`` overwrites every element whose sign flipped relative to
            ``w1``: positive-to-negative elements become ``sigma_theta`` and
            negative-to-positive elements become ``-sigma_theta``.
        sigma_theta: Value written into sign-flipped elements when
            ``sigma_method`` is ``1``; unused otherwise.

    Returns:
        The perturbed tensor.

    Raises:
        ValueError: If ``sigma_method`` is neither ``0`` nor ``1``.
    """
    # Reject unknown methods explicitly instead of silently returning None.
    if sigma_method not in (0, 1):
        raise ValueError(
            "Unsupported sigma_method {!r}; expected 0 or 1".format(sigma_method)
        )

    perturbed = w1 + w2
    if sigma_method == 0:
        return perturbed

    # Elements that were positive in w1 but became negative after the noise.
    flipped_to_negative = (w1 > 0) & (perturbed < 0)
    perturbed = perturbed.masked_fill(flipped_to_negative, sigma_theta)
    # Elements that were negative in w1 but became positive after the noise.
    flipped_to_positive = (w1 < 0) & (perturbed > 0)
    return perturbed.masked_fill(flipped_to_positive, -sigma_theta)

def add_gauss_noise_to_model(model, model_name, sigma, sigma_method, sigma_theta):
    """Perturb selected weights of a GPT-2 or RoBERTa model with scaled noise.

    For every layer of the model, a tensor drawn with ``torch.randn`` is
    multiplied by ``sigma`` and combined with the weight of each targeted
    sub-module through ``add_noise``. The weights are overwritten in place.

    Targeted sub-modules per layer:
        * ``"gpt2"`` in ``model_name``: ``attn.c_attn``, ``attn.c_proj``,
          ``mlp.c_fc`` and ``mlp.c_proj`` of ``model.transformer.h``.
        * ``"roberta"`` in ``model_name``: ``attention.self.query`` and
          ``attention.self.value`` of ``model.roberta.encoder.layer``.

    Args:
        model: Model whose weights are modified.
        model_name: Name used to pick the architecture; must contain
            ``"gpt2"`` or ``"roberta"`` (``"gpt2"`` is checked first).
        sigma: Multiplier applied to the random noise.
        sigma_method: Forwarded to ``add_noise``.
        sigma_theta: Forwarded to ``add_noise``.

    Returns:
        The same ``model`` object, after its weights have been updated.

    Raises:
        ValueError: If ``model_name`` matches neither supported architecture.
    """
    if "gpt2" in model_name:
        module_names = [
            ['attn', 'c_attn'],
            ['attn', 'c_proj'],
            ['mlp', 'c_fc'],
            ['mlp', 'c_proj'],
        ]
        layers = model.transformer.h
    elif "roberta" in model_name:
        module_names = [
            ['attention', 'self', 'query'],
            ['attention', 'self', 'value'],
        ]
        layers = model.roberta.encoder.layer
    else:
        raise ValueError(
            "Unsupported model_name {!r}; expected it to contain "
            "'gpt2' or 'roberta'".format(model_name)
        )

    # Walk the layer container itself rather than a hard-coded layer count,
    # so every layer is covered whatever the model size.
    for layer in layers:
        for module_name in module_names:
            module = get_module_by_name(layer, module_name)
            weights = module.weight.data
            # Draw the noise on the weight's device and in its dtype so the
            # sum is not promoted to a different dtype.
            noise = torch.randn(
                weights.shape, device=weights.device, dtype=weights.dtype
            )
            module.weight.data = add_noise(
                weights, noise * sigma, sigma_method, sigma_theta
            )
    return model




    