import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "0" 

import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from tqdm import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM

# Models to evaluate. The short Flan-T5 names are loaded from the "google/" hub
# namespace as seq2seq models; every other entry is a full Hugging Face repo id
# loaded as a causal LM. Each name is also used as the output column header in
# the spreadsheets.
model_names = ['flan-t5-small', 'flan-t5-base', 'flan-t5-large', 'flan-t5-xl', 'flan-t5-xxl', 'mistralai/Mistral-7B-Instruct-v0.1', 'HuggingFaceH4/zephyr-7b-alpha']

# Paths of the prompt spreadsheets WITHOUT the ".xlsx" extension (it is appended
# when reading and writing). This must be a list: iterating a plain string would
# treat every character as a separate path.
files = []  # TODO: fill in the list of prompt file paths

def _is_t5(model_name):
    """Return True for the short Flan-T5 names, which are seq2seq models under "google/"."""
    return model_name.startswith('flan-t5')


def _load_model(model_name):
    """Load the tokenizer and fp16 model for ``model_name``; return them with the model on CUDA."""
    if _is_t5(model_name):
        tokenizer = T5Tokenizer.from_pretrained("google/" + model_name)
        model = T5ForConditionalGeneration.from_pretrained("google/" + model_name, torch_dtype=torch.float16)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
        # Decoder-only models: reuse EOS as the pad token, and pad on the LEFT so that
        # every prompt in a batch ends exactly where generation begins. With right
        # padding, shorter prompts would be continued from pad tokens.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = 'left'
    return tokenizer, model.to('cuda')


def _annotate(tokenizer, model, texts, is_t5, batch_size=128):
    """Generate a short answer (at most 5 new tokens) for each prompt in ``texts``.

    Args:
        tokenizer: tokenizer matching ``model``.
        model: generation model already on CUDA.
        texts: list of prompt strings; non-string entries raise instead of being skipped.
        is_t5: True for Flan-T5 (seq2seq) models, False for causal LMs.
        batch_size: number of prompts per ``generate`` call.

    Returns:
        List of decoded answers, one per prompt, in input order.
    """
    answers = []
    for start in tqdm(range(0, len(texts), batch_size)):
        # Slicing handles the shorter final batch; no exception-driven loop exit needed.
        batch = texts[start:start + batch_size]
        if not is_t5:
            batch = [text + " Answer in one word only." for text in batch]  # only for mistral & zephyr

        inputs = tokenizer(batch, padding=True, return_tensors="pt").to('cuda')
        with torch.no_grad():
            generated = model.generate(**inputs, max_new_tokens=5, pad_token_id=tokenizer.pad_token_id)

        if not is_t5:
            # Causal LMs return prompt + continuation. With left padding all prompts share
            # the same token length, so drop those tokens rather than cutting the decoded
            # string by character count (detokenization need not reproduce the prompt).
            generated = generated[:, inputs['input_ids'].shape[1]:]

        answers += tokenizer.batch_decode(generated, skip_special_tokens=True)
        del inputs, generated
        torch.cuda.empty_cache()
    return answers


for m in model_names:
    is_t5 = _is_t5(m)
    tokenizer, model = _load_model(m)

    for f in files:
        path = f + '.xlsx'
        df = pd.read_excel(path)
        # Keep the two input columns plus answers already written by earlier models.
        # Keeping only the first two columns would throw those answers away each time
        # the file is re-read for the next model.
        keep = list(df.columns[:2]) + [c for c in df.columns[2:] if c in model_names]
        df = df[keep].copy()

        df[m] = _annotate(tokenizer, model, df.text.tolist(), is_t5)
        df.to_excel(path, index=False)

    del model
    torch.cuda.empty_cache()