import torch
from transformers import AutoTokenizer, AutoModel, AdamW
from transformers import LlamaTokenizer, LlamaForCausalLM
# from transformers import AutoTokenizer, BartForConditionalGeneration

# Training hyperparameters (module-level constants)
NUM_EPOCHS = 30
BATCH_SIZE = 3
LEARNING_RATE = 1e-6  # same value as 0.000001, in scientific notation

# Paths and identifiers
# MODEL is the checkpoint id handed to `from_pretrained` further down.
# Other ids that have been tried here:
#   "Tianlin668/MentalBART", "Tianlin668/MentalT5", "gpt2"
MODEL = "klyang/MentaLLaMA-chat-7B"

DATA_DIR = "./dataset/"
MODEL_SAVE_PATH = "checkpoints/"

# No previously saved checkpoint is configured by default.
LAST_SAVED = None

# Disabled dataloader checkpoint locations (MentalBART, .pkl files):
# TRAIN_DATALOADER_SAVE_PATH = "./dataset/dataloader_checkpoints/MentalBART/train_dataloader.pkl"
# VAL_DATALOADER_SAVE_PATH = "./dataset/dataloader_checkpoints/MentalBART/val_dataloader.pkl"
# TEST_DATALOADER_SAVE_PATH = "./dataset/dataloader_checkpoints/MentalBART/test_dataloader.pkl"

# Compute related
# The preferred accelerator is the second GPU ("cuda:1"), but that ordinal
# only exists when at least two CUDA devices are visible. On a single-GPU
# host "cuda:1" is an invalid device and fails on first use, so fall back to
# "cuda:0" there, and to the CPU when CUDA is not available at all.
if torch.cuda.is_available():
    _gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    DEVICE = torch.device(f"cuda:{_gpu_index}")
else:
    DEVICE = torch.device("cpu")

# Boolean feature switches read by other modules.
# NOTE(review): LOG presumably enables local logging and WAND presumably
# enables Weights & Biases tracking - confirm against the training script.
LOG = False
WAND = True

# Model 1: MentaLLaMA (active configuration)
# Both objects are created at import time from the checkpoint id in MODEL.
tokenizerCheckpoint = LlamaTokenizer.from_pretrained(
    pretrained_model_name_or_path=MODEL,
)
# Optional keyword arguments left disabled here:
#   use_flash_attention_2=True, torch_dtype=torch.float16
modelCheckpoint = LlamaForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL,
)

# Model 2: MentalBART (disabled; also needs the commented-out
# BartForConditionalGeneration import at the top of the file)
# tokenizerCheckpoint = AutoTokenizer.from_pretrained('Tianlin668/MentalBART')
# modelCheckpoint = BartForConditionalGeneration.from_pretrained('Tianlin668/MentalBART')

# Model 3: GPT2 or any general model (disabled)
# NOTE: this snippet still names the MentaLLaMA checkpoint and binds
# `tokenizer` / `model` rather than `tokenizerCheckpoint` / `modelCheckpoint`;
# adjust both before enabling it.
# tokenizer = AutoTokenizer.from_pretrained('klyang/MentaLLaMA-chat-7B')
# model = AutoModel.from_pretrained('klyang/MentaLLaMA-chat-7B')
# tokenizer.add_special_tokens({'pad_token': '[PAD]'})  # add pad token for GPT2 model
# model.resize_token_embeddings(len(tokenizer))         # resize model embedding to include new tokens