from pathlib import Path
import torch
from transformers import (AutoModelForCausalLM,
                          AutoTokenizer,
                          BitsAndBytesConfig,
                          HfArgumentParser,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          EarlyStoppingCallback,
                          AutoModelForSequenceClassification,
                          pipeline,
                          logging,
                          set_seed)

class ModelLoader:
    """Load 4-bit quantized causal language models and tokenizers from local directories.

    The BitsAndBytes quantization settings are created once in ``__init__`` and
    reused for every model loaded through the same instance.
    """

    def __init__(self) -> None:
        """Build the 4-bit quantization config; no model is loaded at this point."""
        print('Loading the model...')
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,  # Activate 4-bit precision base model loading
            bnb_4bit_use_double_quant=True,  # Activate nested (double) quantization
            bnb_4bit_quant_type="nf4",  # Quantization type
            bnb_4bit_compute_dtype=torch.bfloat16,  # Compute data type for the 4-bit model
        )

    @staticmethod
    def _resolve_model_path(model_root_path: str, model_name: str, model_version: str) -> str:
        """Return ``<model_root_path>/models--<model_name>/<model_version>`` as a string.

        The path is assembled with ``pathlib`` so redundant separators (for
        example a trailing slash on the root) are normalised, and it is rendered
        with forward slashes.
        """
        # A leading "/" would make pathlib treat the version as an absolute path
        # and silently discard the root, so strip it before joining.
        version_part = model_version.lstrip("/")
        model_dir = Path(model_root_path) / f"models--{model_name}" / version_part
        return model_dir.as_posix()

    def load_model_from_path_name_version(self, model_root_path: str, model_name: str, model_version: str, device_map: str = "auto") -> tuple:
        """Load a quantized causal LM and its tokenizer from a local directory.

        The directory is ``<model_root_path>/models--<model_name>/<model_version>``.
        NOTE(review): this resembles the Hugging Face hub cache layout, in which
        case ``model_name`` would be ``org--name`` and ``model_version`` something
        like ``snapshots/<revision>`` -- confirm with the callers.

        Args:
            model_root_path: Directory that contains the ``models--*`` folders.
            model_name: Name appended to the ``models--`` prefix. If it contains
                the substring ``'Mistral'`` (case-sensitive), the loaded model's
                ``config.sliding_window`` is set to 4096.
            model_version: Sub-path below the model folder.
            device_map: Forwarded unchanged to ``from_pretrained``.

        Returns:
            A ``(model, tokenizer)`` tuple.
        """
        model_path = self._resolve_model_path(model_root_path, model_name, model_version)
        print('Loading model from...', model_path)

        tokenizer = AutoTokenizer.from_pretrained(model_path)
        # The end-of-sequence token doubles as the padding token, and padding is
        # appended on the right-hand side of each sequence.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = 'right'

        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            quantization_config=self.bnb_config,
            device_map=device_map,
        )
        # Deliberately case-sensitive: a lower-cased check would also match every
        # repository published under the "mistralai" organisation.
        if 'Mistral' in model_name:
            model.config.sliding_window = 4096
        return model, tokenizer


        

def main() -> None:
    """Script entry point: construct a ``ModelLoader``.

    Construction prints a status line and builds the quantization config; no
    model is loaded here.
    """
    # The instance is not used afterwards, so it is not bound to a name.
    ModelLoader()

# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()