from pathlib import Path
import torch
import numpy as np
from tqdm import tqdm
import json
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from transformers import (AutoModelForCausalLM,
                          AutoTokenizer,
                          BitsAndBytesConfig,
                          HfArgumentParser,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          EarlyStoppingCallback,
                          AutoModelForSequenceClassification,
                          pipeline,
                          logging,
                          set_seed)
from peft import AutoPeftModelForCausalLM

class Evaluater:
    """Run generation-based classification with a fine-tuned causal LM and score the results.

    Predictions are appended row by row to ``<output_dir>/eval_pred.csv``. Scoring reads that
    file back and writes ``eval_report.json`` (and ``inference_time.json`` when inference ran)
    into the same directory.
    """

    def __init__(self) -> None:
        """Create an evaluator. It takes no configuration and only announces itself."""
        print('Evaluating the model...')

    def merge_model(self, finetuned_model_dir: Path, labels, label2id, id2label):
        """Load a PEFT adapter checkpoint, merge it into its base model, and return both pieces.

        Args:
            finetuned_model_dir: Directory holding the adapter checkpoint and tokenizer files.
            labels, label2id, id2label: Not used here; accepted for interface compatibility.

        Returns:
            ``(model, tokenizer)``: the merged float16 model moved to ``cuda:0``, and its tokenizer.
        """
        tokenizer = AutoTokenizer.from_pretrained(str(finetuned_model_dir))
        # Reuse EOS as the padding token (presumably the base tokenizer defines none).
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = 'right'

        model = AutoPeftModelForCausalLM.from_pretrained(
            str(finetuned_model_dir) + '/',
            torch_dtype=torch.float16,
            return_dict=False,
            low_cpu_mem_usage=True,
            device_map='auto',
        )
        model.to('cuda:0')
        # Fold the adapter weights into the base model so inference runs on a plain model.
        model = model.merge_and_unload()
        model.to('cuda:0')
        model.config.pad_token_id = tokenizer.pad_token_id

        # NOTE(review): hard-coded sliding window for Mistral checkpoints; confirm the intent.
        if 'mistral' in str(finetuned_model_dir):
            model.config.sliding_window = 4096

        print('merged_model', model)
        return model, tokenizer

    @staticmethod
    def _extract_answer(generated, prompt, response_key):
        """Return the part of ``generated`` that is the model's answer."""
        if response_key:
            # Everything after the last occurrence of the marker (the whole text if absent).
            return generated.split(response_key)[-1]
        # str.split('') raises ValueError, so without a marker strip the echoed prompt instead.
        if generated.startswith(prompt):
            return generated[len(prompt):]
        return generated

    @staticmethod
    def _match_label(answer, labels):
        """Return the first label contained (case-insensitively) in ``answer``, else ``"none"``."""
        answer_lower = answer.lower()
        return next((label for label in labels if label.lower() in answer_lower), "none")

    def predict(self, test, model, tokenizer, labels, output_dir, response_key):
        """Generate an answer for every example and append the predictions to ``eval_pred.csv``.

        Args:
            test: Indexable sequence of examples with ``"text"`` (prompt) and ``"label"`` fields.
            model, tokenizer: Passed to a ``text-generation`` pipeline (max 10 new tokens).
            labels: Candidate label strings. The first one found case-insensitively in the
                answer becomes the prediction; if none is found, the prediction is ``"none"``.
            output_dir: Directory (``Path`` or ``str``) for ``eval_pred.csv``. An existing
                file is deleted first.
            response_key: Marker that precedes the answer in the generated text. When empty,
                the answer is the generated text with the echoed prompt removed.
        """
        eval_file = Path(output_dir) / "eval_pred.csv"
        print('eval_file', eval_file)
        if eval_file.exists():
            eval_file.unlink()

        # The pipeline does not depend on the example, so build it once instead of per row.
        pipe = pipeline(task="text-generation",
                        model=model,
                        tokenizer=tokenizer,
                        max_new_tokens=10,
                        )
        for i in tqdm(range(len(test))):
            example = test[i]
            prompt = example["text"]
            generated = pipe(prompt)[0]['generated_text']
            answer = self._extract_answer(generated, prompt, response_key)
            pred = self._match_label(answer, labels)
            row = pd.DataFrame({"true": [example['label']], "pred": [pred],
                                "answer": [answer], "prompt": [prompt]})
            # Appending per row keeps partial results on disk if the run is interrupted;
            # the header is written only when the file is first created.
            row.to_csv(eval_file, mode="a", index=False, header=not eval_file.exists())

    def evaluate(self, test, labels, label2id, id2label, model_dir, output_dir=None, do_predict = True, model=None,
                 tokenizer=None, response_key=''):
        """Optionally run inference, then score the predictions stored in ``eval_pred.csv``.

        Args:
            test: Examples forwarded to :meth:`predict` when ``do_predict`` is true.
            labels: Label strings, in the order used for the report's per-class entries.
            label2id: Mapping from label to integer id; it is not modified. Unmatched
                predictions ("none") get the extra id ``len(label2id)`` unless 'none' is
                already a key.
            id2label: Forwarded to :meth:`merge_model`, which does not use it.
            model_dir: Fine-tuned checkpoint directory; also the default output directory.
            output_dir: Where predictions and reports are read and written (``Path`` or ``str``).
            do_predict: If false, skip inference and score an existing ``eval_pred.csv``.
            model, tokenizer: Preloaded model and tokenizer. If either is missing, the
                checkpoint in ``model_dir`` is merged and loaded.
            response_key: Marker that precedes the model's answer in the generated text.

        Writes ``inference_time.json`` (whole seconds, only when inference ran) and
        ``eval_report.json`` (sklearn classification report plus ``none_nr``, the number of
        unmatched predictions) into ``output_dir``. Returns ``None``.
        """
        output_dir = Path(model_dir) if output_dir is None else Path(output_dir)

        if do_predict:
            if model is None or tokenizer is None:
                model, tokenizer = self.merge_model(model_dir, labels, label2id, id2label)
            start_time = pd.Timestamp.now()
            self.predict(test, model, tokenizer, labels, output_dir, response_key)
            inference_time = (pd.Timestamp.now() - start_time).total_seconds()
            with open(output_dir / "inference_time.json", 'w') as f:
                json.dump({'inference_time': int(inference_time)}, f, indent=4)

        pred_file = output_dir / "eval_pred.csv"
        if not pred_file.exists():
            print('No predictions found.')
            return

        # Labels are strings (predict lower-cases them). Read them verbatim so values like
        # "NA"/"None" or "1" are not turned into NaN or ints and still match label2id.
        df = pd.read_csv(pred_file, dtype=str, keep_default_na=False)
        none_nr = len(df[df['pred'] == 'none'])

        # Work on a copy so the caller's label2id is left untouched.
        label_ids = dict(label2id)
        if 'none' not in label_ids:
            label_ids['none'] = len(label_ids)
        y_true = [label_ids[label] for label in df["true"]]
        y_pred = [label_ids[label] for label in df["pred"]]

        accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)
        print(f'Accuracy: {accuracy:.3f}')

        target_names = list(labels)
        if none_nr > 0 and 'none' not in target_names:
            target_names.append('none')
        # Pass the ids explicitly. Otherwise sklearn infers the classes from the data, raises
        # when a label never occurs, and may pair names with the wrong ids.
        class_report = classification_report(
            y_true=y_true,
            y_pred=y_pred,
            labels=[label_ids[name] for name in target_names],
            target_names=target_names,
            output_dict=True,
            zero_division=0,
        )
        print('\nClassification Report:')
        class_report['none_nr'] = none_nr
        print(class_report)

        with open(output_dir / "eval_report.json", 'w') as f:
            json.dump(class_report, f, indent=4)
        
        



def main():
    """Script entry point; currently a placeholder that performs no work."""
    return None


if __name__ == '__main__':
    main()