import argparse
import shutil
import glob
import os
import random
import datasets
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
from transformers import DataCollatorForSeq2Seq
from transformers import get_scheduler

import tensorflow as tf

from spell_check_evaluation.evaluate_v2.evaluate import evaluate


def eval_model(ds, model, tokenizer, args, name_test, dir_path):
    """Run the model over every row of *ds*, dump the texts, and score them.

    For each row the ``"source"`` text is tokenized, passed through
    ``model.generate`` with the Russian language id forced as the first
    generated token, and the first decoded hypothesis is kept as the answer.
    Sources, reference corrections and answers are written line by line to
    three text files inside *dir_path* and then handed to ``evaluate``.

    Args:
        ds: Table-like object exposing ``iterrows()`` and ``len()`` whose rows
            have string fields ``"source"`` and ``"correction"``
            (presumably a pandas ``DataFrame`` -- confirm against the caller).
        model: Generation model providing ``generate(**encodings, ...)``.
        tokenizer: Tokenizer matching *model*; must provide ``get_lang_id``
            and ``batch_decode``.
        args: Namespace with a ``device`` attribute that the input tensors
            are moved to.
        name_test: Suffix inserted into the output file names.
        dir_path: Directory for the output files; created if missing.

    Returns:
        Whatever ``evaluate(sources, corrections, answers)`` returns.

    Note:
        The model's train/eval mode is not changed here; the caller is
        responsible for calling ``model.eval()`` beforehand if needed.
    """
    os.makedirs(dir_path, exist_ok=True)

    sources = []
    corrections = []
    answers = []

    # The forced BOS token is the same for every row, so look it up once.
    ru_lang_id = tokenizer.get_lang_id("ru")

    # iterrows() is a generator; pass the length so tqdm can show progress/ETA.
    for _, row in tqdm(ds.iterrows(), total=len(ds)):
        # Drop byte-order marks that may surround the text.
        source_text = row["source"].strip("\ufeff")
        reference_text = row["correction"].strip("\ufeff")

        encodings = tokenizer(source_text, return_tensors='pt')
        model_inputs = {
            key: tensor.to(args.device) for key, tensor in encodings.items()
        }
        generated_tokens = model.generate(
            **model_inputs,
            forced_bos_token_id=ru_lang_id,
            no_repeat_ngram_size=None,
        )
        decoded = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )

        # One input sentence per call, so only the first hypothesis is used.
        answers.append(decoded[0])
        sources.append(source_text)
        corrections.append(reference_text)

    ans_path = os.path.join(dir_path, f"ans_spell{name_test}.txt")
    src_path = os.path.join(dir_path, f"src_spell{name_test}.txt")
    corr_path = os.path.join(dir_path, f"corr_spell{name_test}.txt")

    # Explicit UTF-8: the texts are Russian, and the platform default encoding
    # (e.g. cp1252 or ASCII under the C locale) may be unable to encode them.
    with open(ans_path, "w", encoding="utf-8") as ans_file, \
            open(src_path, "w", encoding="utf-8") as src_file, \
            open(corr_path, "w", encoding="utf-8") as corr_file:
        for source_text, reference_text, answer in zip(
            sources, corrections, answers
        ):
            ans_file.write(answer + "\n")
            src_file.write(source_text + "\n")
            corr_file.write(reference_text + "\n")

    metrics = evaluate(sources, corrections, answers)

    return metrics
