
# Controlled by model_args.freeze and the config's freeze_* parameters.
def freeze(model, config):
    '''
    Set ``requires_grad = False`` on parameters of a GPT2E model according to config.

    Two mutually exclusive modes:

    * ``config.freeze_percentage`` truthy: freeze the first
      ``round(freeze_percentage * N / 100)`` parameter *tensors* (N = number of
      tensors yielded by ``model.parameters()``), in iteration order. Values
      above 100 freeze every tensor; negative values freeze nothing.
    * otherwise: each parameter is routed, by substring match on its lower-cased
      name, to one ``config.freeze_X`` setting (first match wins):
      ``'ln'`` -> freeze_ln, ``'wte'`` -> freeze_emb, ``'wpe'`` -> freeze_pos,
      ``'mlp'`` -> freeze_ff, ``'.attn'`` -> freeze_attn, ``'gate'`` -> freeze_gate,
      ``'entities'`` -> freeze_entities. Parameters matching none of these
      (biases and other layers) are always frozen. Parameters of
      ``model.lm_head`` are then processed with ``config.freeze_lm``.

    Each ``config.freeze_X`` can be:

    * a falsy value (e.g. False): leave the matching parameters untouched;
    * True: freeze all matching parameters;
    * an iterable x of layer numbers: freeze every matching parameter EXCEPT
      those whose name contains ``'.<n>.'`` for some n in x (i.e. freeze all
      layers except these). Names without a layer index are therefore frozen.

    Parameters that are already frozen are never unfrozen.

    :param model: GPT2E model; must expose ``parameters()``,
        ``named_parameters()`` and an ``lm_head`` submodule
    :param config: object with ``freeze_percentage`` and the ``freeze_X``
        attributes listed above
    :return: the same GPT2E model object, modified in place
    '''
    if config.freeze_percentage:
        print(f'Freezing {config.freeze_percentage} % of the parameters')
        model_params = list(model.parameters())
        freezing_index = round(config.freeze_percentage * len(model_params) / 100)
        # Clamp so that >100 % freezes everything instead of raising IndexError,
        # and a negative percentage freezes nothing.
        freezing_index = min(max(freezing_index, 0), len(model_params))
        for p in model_params[:freezing_index]:
            p.requires_grad = False

        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f'New trainable param number {trainable_params}')
        return model

    print('Freezing model based on GPT2EConfig parameters')
    # (name substring, config attribute) pairs. Order matters: the first
    # substring found in the parameter name decides which setting applies.
    rules = (
        ('ln', 'freeze_ln'),              # layer norm
        ('wte', 'freeze_emb'),            # word embeddings
        ('wpe', 'freeze_pos'),            # word position embeddings
        ('mlp', 'freeze_ff'),             # ffnn
        ('.attn', 'freeze_attn'),         # attention
        ('gate', 'freeze_gate'),          # EGBlock
        ('entities', 'freeze_entities'),
    )
    for name, p in model.named_parameters():
        name = name.lower()
        attr = next((a for key, a in rules if key in name), None)
        if attr is None:
            # biases and other layers not covered by any rule
            p.requires_grad = False
        else:
            _apply_freeze_setting(p, name, getattr(config, attr))

    # Match freeze_lm against lm_head's own parameter names (e.g.
    # 'lm_head.weight'), not against whatever name the loop above ended on.
    # NOTE(review): if lm_head.weight is tied to the word embeddings, this loop
    # sees the same tensor again, so freeze_lm can override freeze_emb for it
    # (unchanged from before) — confirm that is intended.
    for name, p in model.lm_head.named_parameters(prefix='lm_head'):
        _apply_freeze_setting(p, name.lower(), config.freeze_lm)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'New trainable param number: {trainable_params}.')
    return model


def _apply_freeze_setting(p, name, setting):
    '''Apply one ``config.freeze_X`` setting to parameter ``p`` named ``name``.

    Does nothing if ``setting`` is falsy or ``p`` is already frozen. ``True``
    freezes ``p``; an iterable of layer numbers freezes ``p`` unless ``name``
    contains ``'.<n>.'`` for one of those numbers.
    '''
    if not setting or not p.requires_grad:
        return
    # A bool here can only be True (falsy settings returned above).
    if isinstance(setting, bool) or not any('.' + str(n) + '.' in name for n in setting):
        p.requires_grad = False