from model.custom_module.backbone import GPT2LMAndValueHeadModel
from transformers import AutoModelForCausalLM, AutoConfig
from typing import Union
from torch.distributed import get_rank
import torch.distributed as dist
import torch
import os


def create_reg_module(args):
    """Build the regularization module selected by ``args``.

    Args:
        args: Mapping read for ``'reg_type'`` and, only when that equals
            ``'kl'``, for ``'model_type'``. When both match it is forwarded
            unchanged to ``OriginModel``.

    Returns:
        An ``OriginModel`` when ``reg_type`` is ``'kl'`` and ``model_type`` is
        ``'origin'``; ``None`` for every other combination.
    """
    # Short-circuit keeps the original lookup order: 'model_type' is only
    # read once 'reg_type' has matched 'kl'.
    wants_origin = args['reg_type'] == 'kl' and args['model_type'] == 'origin'
    return OriginModel(args) if wants_origin else None

class OriginModel(torch.nn.Module):
    """Pretrained ``GPT2LMAndValueHeadModel`` wrapper with frozen parameters.

    Every parameter of the wrapped model has ``requires_grad`` switched off
    at construction time.
    """

    def __init__(self, args):
        """Load the backbone and freeze it.

        Args:
            args: Mapping providing ``'path_or_name'`` (checkpoint passed to
                ``from_pretrained``), ``'model_kwargs'`` (extra keyword
                arguments for ``from_pretrained``), ``'device'`` (CUDA index
                used to build ``self.device``) and ``'keep_on_device'``.
        """
        super().__init__()
        checkpoint = args['path_or_name']
        extra_kwargs = args['model_kwargs']
        self.model = GPT2LMAndValueHeadModel.from_pretrained(
            checkpoint, **extra_kwargs)
        # The target device and the keep-on-device flag are only recorded
        # here; this class does not move the model itself.
        cuda_name = f"cuda:{args['device']}"
        self.device = torch.device(cuda_name)
        self.keep_on_device = args['keep_on_device']
        # Freeze all weights of the wrapped model.
        for _unused_name, weight in self.model.named_parameters():
            weight.requires_grad = False

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Run the wrapped model and return its outputs unchanged.

        Args:
            input_ids: Passed positionally to the wrapped model.
            attention_mask: Forwarded as the ``attention_mask`` keyword.
            labels: Forwarded as the ``labels`` keyword.

        Returns:
            Whatever the wrapped model returns.
        """
        return self.model(
            input_ids, attention_mask=attention_mask, labels=labels)
        

class OtherModel(torch.nn.Module):
    """Frozen float16 causal LM loaded through ``AutoModelForCausalLM``.

    The whole model is placed on a single CUDA device via ``device_map`` and
    every parameter has ``requires_grad`` switched off.
    """

    def __init__(self, args):
        """Load the model onto the configured device and freeze it.

        Args:
            args: Mapping providing ``'path_or_name'`` (checkpoint passed to
                ``from_pretrained``), ``'device'`` (CUDA index) and
                ``'keep_on_device'``.
        """
        super(OtherModel, self).__init__()
        # Resolve the device before loading: the original code referenced an
        # undefined name ``device`` in ``device_map`` and raised NameError.
        device = torch.device(f"cuda:{args['device']}")
        config = AutoConfig.from_pretrained(args['path_or_name'])
        self.model = AutoModelForCausalLM.from_pretrained(
            args['path_or_name'],
            config=config,
            # The empty-string key maps the entire model to one device.
            device_map={"": device},
            torch_dtype=torch.float16)
        self.device = device
        self.keep_on_device = args['keep_on_device']
        # Freeze all weights of the wrapped model.
        for parameter in self.model.parameters():
            parameter.requires_grad = False

    def forward(self, inputs):
        """Run the wrapped model on a mapping of keyword inputs.

        Args:
            inputs: Mapping unpacked as keyword arguments to the model.

        Returns:
            The model outputs, requested with ``return_dict=True`` and
            ``use_cache=False``.
        """
        outputs = self.model(**inputs, return_dict=True, use_cache=False)
        return outputs
        


