import torch
from transformers import T5Tokenizer, T5Config

from T5_IP import CustomT5Model

# Initialize the tokenizer and the model.
tokenizer = T5Tokenizer.from_pretrained('t5-small')
config = T5Config.from_pretrained('t5-small')
model = CustomT5Model(config)
# Load the checkpoint onto the CPU first: without map_location, a checkpoint
# whose tensors were saved on a GPU fails to deserialize on a CPU-only machine,
# which would defeat the CUDA/CPU fallback below. The weights are moved to the
# selected device afterwards via model.to(device).
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted
# checkpoints (consider weights_only=True on torch versions that support it).
model.load_state_dict(torch.load('path_to_your_model.pth', map_location='cpu'))
model.eval()


# Prefer a GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def generate_text(input_text, max_length=512):
    """Run the module-level model on one prompt and return the decoded text.

    Args:
        input_text: Prompt string to encode with the module-level tokenizer.
        max_length: Value forwarded as ``max_length`` to ``model.generate``.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # Re-assert inference mode in case the global model was switched to train.
    model.eval()
    encoded = tokenizer.encode(input_text, return_tensors='pt')
    encoded = encoded.to(device)
    # Gradients are never needed for generation.
    with torch.no_grad():
        generated = model.generate(encoded, max_length=max_length)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)


# Sample prompts used as a quick smoke test of the loaded model.
# NOTE(review): stock T5 checkpoints were trained with lowercase task prefixes
# ("translate English to German: ", "summarize: "); confirm these capitalized
# prefixes match what CustomT5Model was fine-tuned on.
test_inputs = [
    "Translate English to German: How are you?",
    "Summarize: The quick brown fox jumps over the lazy dog.",
]

for input_text in test_inputs:
    output_text = generate_text(input_text)
    # A single print emitting the input line, the output line and a separator,
    # each on its own line.
    print(
        f"Input: {input_text}",
        f"Output: {output_text}",
        "-" * 50,
        sep="\n",
    )