import torch
from utils.util import *
import torch.nn as nn

from transformers import T5ForConditionalGeneration

class T5(nn.Module):
    """Wrapper around a Hugging Face ``T5ForConditionalGeneration`` model.

    Every entry point accepts ``input_ids`` / ``attention_mask`` as either
    2-D tensors or 3-D tensors. A 3-D tensor has all of its non-batch
    dimensions merged into one, giving shape ``(batch, d1 * d2)``, before it
    reaches the underlying model.

    NOTE(review): 3-D inputs are presumably ``(batch, n_chunks, seq_len)``
    that get concatenated into a single encoder sequence; confirm this with
    the data pipeline.

    Args:
        opt: Options object. ``opt.T5_model`` is passed to
            ``T5ForConditionalGeneration.from_pretrained``, and
            ``opt.max_length`` is used as ``max_length`` for generation.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.T5_model = T5ForConditionalGeneration.from_pretrained(self.opt.T5_model)#, device_map='auto')

    @staticmethod
    def _flatten_3d(tensor):
        """Merge the non-batch dims of a 3-D tensor; return anything else unchanged."""
        if tensor is not None and tensor.dim() == 3:
            # flatten(1) gives the same (batch, -1) layout as .view(B, -1). It also
            # copies when the input is non-contiguous and copes with an empty
            # batch; in both of those cases .view(B, -1) raises.
            return tensor.flatten(1)
        return tensor

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Run the wrapped T5 model after flattening any 3-D inputs.

        Args:
            input_ids: Token ids, 2-D or 3-D, or ``None``.
            attention_mask: Mask matching ``input_ids``, or ``None``.
            **kwargs: Forwarded unchanged to the wrapped model.

        Returns:
            Whatever the wrapped ``T5ForConditionalGeneration`` returns.
        """
        # Call the module instance rather than .forward() so that any
        # registered forward hooks run.
        return self.T5_model(
            input_ids=self._flatten_3d(input_ids),
            attention_mask=self._flatten_3d(attention_mask),
            **kwargs,
        )

    def generate(self, input_ids, attention_mask, do_sample=False):
        """Generate a single output sequence per input.

        Args:
            input_ids: Token ids, 2-D or 3-D.
            attention_mask: Mask matching ``input_ids``.
            do_sample: Forwarded to ``generate``; ``False`` means no sampling.

        Returns:
            The result of the wrapped model's ``generate`` call, with
            ``max_length`` taken from ``opt.max_length``.
        """
        return self.T5_model.generate(
            input_ids=self._flatten_3d(input_ids),
            attention_mask=self._flatten_3d(attention_mask),
            max_length=self.opt.max_length,
            do_sample=do_sample,
        )

    def mul_generate(self, input_ids, attention_mask, do_sample=False,
                     num_beams=10, num_return_sequences=5):
        """Generate several candidate sequences per input using beam search.

        Args:
            input_ids: Token ids, 2-D or 3-D.
            attention_mask: Mask matching ``input_ids``.
            do_sample: Forwarded to ``generate``.
            num_beams: Beam width passed to ``generate`` (default 10).
            num_return_sequences: Number of sequences returned per input
                (default 5). Hugging Face expects this to be no larger than
                ``num_beams`` when beam search is used.

        Returns:
            The result of the wrapped model's ``generate`` call, with
            ``max_length`` taken from ``opt.max_length``.
        """
        return self.T5_model.generate(
            input_ids=self._flatten_3d(input_ids),
            attention_mask=self._flatten_3d(attention_mask),
            max_length=self.opt.max_length,
            num_beams=num_beams,
            num_return_sequences=num_return_sequences,
            do_sample=do_sample,
        )