import torch.nn as nn
from transformers import AutoModel
import config

class DialogueSummarizationModel(nn.Module):
    """Wrap the model object from ``config.modelCheckpoint`` for partial fine-tuning.

    On construction every parameter of the wrapped model is frozen, and then a
    fixed subset is re-enabled for training:

    * ``model.model.layers[30]`` and ``model.model.layers[31]``
    * ``model.model.norm``
    * every parameter whose position in ``model.parameters()`` is greater
      than 280

    These indices are hard-coded for the checkpoint this was written against;
    the original author's note describes them as the last two layers, the norm
    layer and the last linear layer.
    NOTE(review): confirm the indices still match before pointing
    ``config.modelCheckpoint`` at a different architecture.
    """

    # Indices into ``self.model.model.layers`` whose parameters stay trainable.
    _TRAINABLE_LAYER_INDICES = (31, 30)
    # Parameters positioned after this index in ``self.model.parameters()``
    # are re-enabled for training.
    _LAST_FROZEN_PARAM_INDEX = 280

    def __init__(self):
        """Take the model from ``config`` and apply the freeze/unfreeze policy."""
        super().__init__()
        self.model = config.modelCheckpoint
        # Debug printing of forward outputs is opt-in through ``config.LOG``;
        # a missing flag means "off" so forward passes stay quiet by default.
        self._log_outputs = bool(getattr(config, "LOG", False))

        # Freeze everything first, then selectively re-enable gradients.
        for param in self.model.parameters():
            param.requires_grad = False

        inner = self.model.model
        for layer_idx in self._TRAINABLE_LAYER_INDICES:
            for param in inner.layers[layer_idx].parameters():
                param.requires_grad = True
        for param in inner.norm.parameters():
            param.requires_grad = True

        # Positional rule: re-enable every parameter past the cut-off index.
        for idx, param in enumerate(self.model.parameters()):
            if idx > self._LAST_FROZEN_PARAM_INDEX:
                param.requires_grad = True

    def forward(self, input_ids, attention_mask, labels):
        """Call the wrapped model and return its result unchanged.

        Args:
            input_ids: Passed to the wrapped model as ``input_ids``.
            attention_mask: Passed to the wrapped model as ``attention_mask``.
            labels: Passed to the wrapped model as ``labels``.

        Returns:
            Whatever the wrapped model returns for these keyword arguments.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        # ``getattr`` keeps instances that never ran ``__init__`` (e.g. built
        # via ``__new__``) working with logging disabled.
        if getattr(self, "_log_outputs", False):
            print(outputs)
        return outputs



