from pathlib import Path
import torch
from transformers import (AutoModelForCausalLM,
                          AutoTokenizer,
                          BitsAndBytesConfig,
                          HfArgumentParser,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          EarlyStoppingCallback,
                          AutoModelForSequenceClassification,
                          pipeline,
                          logging,
                          set_seed)
from peft import AutoPeftModelForSequenceClassification

class ModelLoader:
    """Load sequence-classification models together with their tokenizers.

    Three entry points are offered: a PEFT fine-tuned checkpoint that is merged
    into its base model, a base model at an explicit path, and a base model
    located from a root directory, model name and version. Base models are
    loaded with the 4-bit quantization config built in ``__init__``.
    """

    # Sliding-window size written to the config of Mistral models.
    MISTRAL_SLIDING_WINDOW = 4096
    # Device a fine-tuned model is moved to when CUDA is available.
    FINETUNED_DEVICE = 'cuda:0'

    def __init__(self) -> None:
        """Build the 4-bit quantization config shared by the base-model loaders."""
        print('Loading the model...')
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,  # Activate 4-bit precision base model loading
            bnb_4bit_use_double_quant=True,  # Activate nested quantization for 4-bit base models (double quantization)
            bnb_4bit_quant_type="nf4",  # Quantization type (fp4 or nf4)
            bnb_4bit_compute_dtype=torch.bfloat16,  # Compute data type for 4-bit base models
        )

    @staticmethod
    def _is_mistral(name) -> bool:
        """Return True when *name* mentions Mistral, ignoring case."""
        return 'mistral' in str(name).lower()

    @staticmethod
    def _build_model_path(model_root_path, model_name, model_version) -> str:
        """Return ``<root>/models--<name>/<version>`` as a string."""
        return f"{model_root_path}/models--{model_name}/{model_version}"

    @staticmethod
    def _load_tokenizer(path):
        """Load the tokenizer at *path*; pad with the EOS token on the right."""
        tokenizer = AutoTokenizer.from_pretrained(path)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = 'right'
        return tokenizer

    @classmethod
    def _move_to_gpu(cls, model) -> None:
        """Move *model* to ``FINETUNED_DEVICE`` if CUDA is available, else leave it."""
        if torch.cuda.is_available():
            model.to(cls.FINETUNED_DEVICE)

    @classmethod
    def _apply_mistral_window(cls, model, name_hint) -> None:
        """Set the sliding window on *model* when *name_hint* names a Mistral model."""
        if cls._is_mistral(name_hint):
            model.config.sliding_window = cls.MISTRAL_SLIDING_WINDOW

    def _load_quantized_classifier(self, model_path, labels, label2id, id2label, device_map, name_hint):
        """Load a 4-bit quantized classifier and tokenizer from *model_path*.

        *name_hint* is the string inspected to decide whether the Mistral
        sliding-window setting applies.
        """
        tokenizer = self._load_tokenizer(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            quantization_config=self.bnb_config,
            device_map=device_map,
            num_labels=len(labels),
        )
        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.id2label = id2label
        model.config.label2id = label2id
        self._apply_mistral_window(model, name_hint)
        return model, tokenizer

    def load_finetuned_model(self, finetuned_model_dir:str, labels, label2id, id2label, device_map:str="auto"):
        """Load a PEFT fine-tuned classifier, merge its adapter and return it.

        Args:
            finetuned_model_dir: Directory holding the fine-tuned checkpoint and
                its tokenizer (converted with ``str()`` before use).
            labels: Collection of labels; only its length is used (``num_labels``).
            label2id: Mapping stored on ``model.config.label2id``.
            id2label: Mapping stored on ``model.config.id2label``.
            device_map: Forwarded to ``from_pretrained`` for load-time placement.
                The model is still moved to ``cuda:0`` afterwards whenever CUDA
                is available.

        Returns:
            Tuple ``(model, tokenizer)`` where ``model`` is the result of
            ``merge_and_unload()``.
        """
        print('Loading finetuned model from...', finetuned_model_dir)
        model_dir = str(finetuned_model_dir)

        tokenizer = self._load_tokenizer(model_dir)
        model = AutoPeftModelForSequenceClassification.from_pretrained(
            model_dir + '/',
            torch_dtype=torch.float16,
            return_dict=False,
            low_cpu_mem_usage=True,
            device_map=device_map,
            num_labels=len(labels),
        )
        self._move_to_gpu(model)

        model.config.id2label = id2label
        model.config.label2id = label2id

        # Merge the adapter into the base model so a plain model is returned.
        model = model.merge_and_unload()
        self._move_to_gpu(model)
        model.config.pad_token_id = tokenizer.pad_token_id

        self._apply_mistral_window(model, model_dir)
        return model, tokenizer

    def load_model_from_path(self, model_path:str, labels, label2id, id2label, device_map:str="auto"):
        """Load a 4-bit quantized classifier and its tokenizer from *model_path*.

        Args:
            model_path: Location passed to both ``from_pretrained`` calls.
            labels: Collection of labels; only its length is used (``num_labels``).
            label2id: Mapping stored on ``model.config.label2id``.
            id2label: Mapping stored on ``model.config.id2label``.
            device_map: Forwarded to ``from_pretrained``.

        Returns:
            Tuple ``(model, tokenizer)``.
        """
        print('Loading model from...', model_path)
        return self._load_quantized_classifier(
            model_path, labels, label2id, id2label, device_map, name_hint=model_path
        )

    def load_model_from_path_name_version(self, model_root_path:str, model_name:str, model_version:str, labels, label2id, id2label, device_map:str="auto"):
        """Load a 4-bit quantized classifier from ``<root>/models--<name>/<version>``.

        Args:
            model_root_path: Directory containing the ``models--<name>`` folders.
            model_name: Name appended after ``models--``; also the string checked
                for the Mistral sliding-window setting.
            model_version: Sub-directory of the model folder to load.
            labels: Collection of labels; only its length is used (``num_labels``).
            label2id: Mapping stored on ``model.config.label2id``.
            id2label: Mapping stored on ``model.config.id2label``.
            device_map: Forwarded to ``from_pretrained``.

        Returns:
            Tuple ``(model, tokenizer)``.
        """
        model_path = self._build_model_path(model_root_path, model_name, model_version)
        print('Loading model from...', model_path)
        return self._load_quantized_classifier(
            model_path, labels, label2id, id2label, device_map, name_hint=model_name
        )


        

def main() -> None:
    """Script entry point: instantiate ``ModelLoader`` and discard the instance."""
    ModelLoader()

# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()