import torch
import config
import pickle
from transformers import AutoTokenizer, BartForConditionalGeneration
from torch.utils.data import DataLoader, Dataset
from rouge_score import rouge_scorer


# model = pickle.load(open(f"{config.MODEL_SAVE_PATH}" + f"{config.MODEL}".split('/')[1], 'rb'))
        
class SummarizationEvaluator:
    """Score a seq2seq summarization model with ROUGE-1/2/L F-measures.

    The model and tokenizer objects are taken from ``config.modelCheckpoint``
    and ``config.tokenizerCheckpoint``; weights are then loaded from
    ``saved_path`` on top of that model.
    """

    # Default keyword arguments forwarded to ``model.generate``; individual
    # entries can be overridden per call via ``evaluate(generation_kwargs=...)``.
    DEFAULT_GENERATION_KWARGS = {
        "max_length": 150,
        "num_beams": 4,
        "length_penalty": 2.0,
        "early_stopping": True,
    }

    def __init__(self, model_name, saved_path, dataloader, device):
        """Load saved weights into the configured model and prepare a scorer.

        Args:
            model_name: Currently unused; kept for interface compatibility.
            saved_path: Path to a state dict readable by ``torch.load``.
            dataloader: Iterable yielding ``(input_ids, attention_mask, labels)``
                batches.
            device: Passed to ``torch.load`` as ``map_location``.
        """
        self.device = device
        self.model = config.modelCheckpoint
        self.tokenizer = config.tokenizerCheckpoint
        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints. strict=False silently ignores missing/unexpected keys,
        # so a mismatched checkpoint evaluates (partially) untrained weights.
        self.model.load_state_dict(
            torch.load(saved_path, map_location=self.device), strict=False
        )
        self.dataloader = dataloader
        # The model is not moved to self.device: it is assumed to be created
        # with device_map='auto' -- TODO confirm.
        self.scorer = rouge_scorer.RougeScorer(
            ["rouge1", "rouge2", "rougeL"], use_stemmer=True
        )

    def evaluate(self, log=None, generation_kwargs=None):
        """Generate a summary for every sample and return mean ROUGE F-measures.

        Args:
            log: Print the averages when true. ``None`` (default) falls back
                to ``config.LOG``, treated as ``False`` when config does not
                define it.
            generation_kwargs: Optional dict merged over
                ``DEFAULT_GENERATION_KWARGS`` and passed to ``model.generate``.

        Returns:
            ``(rouge1, rouge2, rougeL)`` F-measures averaged over all samples.

        Raises:
            ValueError: If the dataloader yields no samples, or a batch
                produces a different number of predictions than references.
        """
        self.model.eval()
        gen_kwargs = {**self.DEFAULT_GENERATION_KWARGS, **(generation_kwargs or {})}
        pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        totals = {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}
        total_samples = 0

        with torch.no_grad():
            for input_ids, attention_mask, labels in self.dataloader:
                # Inputs are not moved to self.device: the model is assumed to
                # use device_map='auto' -- TODO confirm.
                summary_ids = self.model.generate(
                    input_ids, attention_mask=attention_mask, **gen_kwargs
                )
                # generate returns one row per sample, so decode row by row.
                predictions = self.tokenizer.batch_decode(
                    summary_ids, skip_special_tokens=True
                )

                if pad_token_id is not None and torch.is_tensor(labels):
                    # -100 is the conventional loss-ignore id for label padding;
                    # it is not a real token, so map it to pad before decoding.
                    labels = labels.masked_fill(labels == -100, pad_token_id)
                references = self.tokenizer.batch_decode(
                    labels, skip_special_tokens=True
                )

                if len(predictions) != len(references):
                    raise ValueError(
                        f"got {len(predictions)} predictions for "
                        f"{len(references)} references in one batch"
                    )

                # RougeScorer.score compares a single (target, prediction) pair.
                for reference, prediction in zip(references, predictions):
                    scores = self.scorer.score(reference, prediction)
                    for key in totals:
                        totals[key] += scores[key].fmeasure
                    total_samples += 1

        if total_samples == 0:
            raise ValueError("dataloader yielded no samples to evaluate")

        average_rouge1 = totals["rouge1"] / total_samples
        average_rouge2 = totals["rouge2"] / total_samples
        average_rougeL = totals["rougeL"] / total_samples

        if log is None:
            log = getattr(config, "LOG", False)
        if log:
            print(f"Average Rouge-1: {average_rouge1}")
            print(f"Average Rouge-2: {average_rouge2}")
            print(f"Average Rouge-L: {average_rougeL}")

        return average_rouge1, average_rouge2, average_rougeL





