import argparse
import shutil
import glob
import os
import random
import datasets
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import T5ForConditionalGeneration, GPT2Tokenizer
from transformers import DataCollatorForSeq2Seq
from transformers import get_scheduler

import tensorflow as tf
from spell_check_evaluation.evaluate_v2.evaluate import evaluate


def eval_model(ds, task_prefix, model, tokenizer, args, name_test, dir_path,
               length_margin=1):
    """Generate a correction for every row of ``ds``, dump results to files and score them.

    For each row, the ``"source"`` text (with ``\\ufeff`` BOM characters stripped
    from its ends) is prefixed with ``task_prefix`` and fed to ``model.generate``.
    The first decoded output is collected alongside the source and the reference
    ``"correction"``. Sources, references and answers are then written one per
    line to three UTF-8 text files in ``dir_path`` and passed to ``evaluate``.

    The model is put in eval mode for generation (so dropout is disabled) and its
    previous train/eval mode is restored afterwards, even if an error occurs.

    Args:
        ds: Table exposing ``iterrows()`` and ``len()`` whose rows have string
            ``"source"`` and ``"correction"`` fields (presumably a pandas
            DataFrame -- confirm against callers).
        task_prefix: String prepended to each source before generation.
        model: Seq2seq model with ``generate`` (a torch ``nn.Module``).
        tokenizer: Tokenizer used both to encode inputs and to decode outputs.
        args: Namespace providing ``device``, where input tensors are moved.
        name_test: Suffix inserted into the three output file names.
        dir_path: Output directory; created if it does not exist.
        length_margin: Tokens allowed beyond the un-prefixed source's token
            count when bounding generation via ``max_length``. Defaults to 1,
            which matches the historical behaviour.

    Returns:
        Whatever ``evaluate(sources, corrections, answers)`` returns.

    Note:
        The output files are line-aligned; an entry that itself contains a
        newline would break that alignment.
    """
    os.makedirs(dir_path, exist_ok=True)
    answers = []
    sources = []
    corrections = []

    # Generation must not run with dropout active; remember the caller's mode
    # so evaluating in the middle of training does not silently change it.
    was_training = model.training
    model.eval()
    try:
        for _, row in tqdm(ds.iterrows(), total=len(ds)):
            srs = row["source"].strip("\ufeff")
            corr = row["correction"].strip("\ufeff")

            # Length budget is derived from the source alone (without the task
            # prefix), since a correction is expected to be about as long as it.
            src_len = len(tokenizer(srs)["input_ids"])

            encodings = tokenizer(task_prefix + srs, return_tensors="pt").to(args.device)
            generated_tokens = model.generate(
                **encodings, max_length=src_len + length_margin
            ).cpu()
            ans = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

            answers.append(ans[0])
            sources.append(srs)
            corrections.append(corr)
    finally:
        if was_training:
            model.train()

    # Explicit UTF-8: the texts are Unicode and the platform default encoding
    # is not guaranteed to be able to represent them.
    ans_path = os.path.join(dir_path, f"ans_spell{name_test}.txt")
    src_path = os.path.join(dir_path, f"src_spell{name_test}.txt")
    corr_path = os.path.join(dir_path, f"corr_spell{name_test}.txt")
    with open(ans_path, "w", encoding="utf-8") as ans_file, \
            open(src_path, "w", encoding="utf-8") as src_file, \
            open(corr_path, "w", encoding="utf-8") as corr_file:
        for s, c, a in zip(sources, corrections, answers):
            ans_file.write(a + "\n")
            src_file.write(s + "\n")
            corr_file.write(c + "\n")

    metrics = evaluate(sources, corrections, answers)
    return metrics
