import torch
from transformers import AutoTokenizer, AutoModel, AdamW
from transformers import LlamaTokenizer, LlamaForCausalLM
# from transformers import AutoTokenizer, BartForConditionalGeneration
import json


# Load the two JSON resources used by the rest of the project. The paths are
# relative, so they resolve against the current working directory.
# The encoding is given explicitly: JSON text is UTF-8, while open() otherwise
# falls back to the platform locale, which is not UTF-8 on every system.
with open('dataset/data/map.json', encoding='utf-8') as data_file:
    dialogue_map = json.load(data_file)

with open('dataset/data/phq_lexicon.json', encoding='utf-8') as phq_file:
    lexicon = json.load(phq_file)

# Build `phq_lexicon` from `lexicon`: each entry is copied under the same key
# with every underscore replaced by a space. A key is only created when its
# first entry is appended, so keys with no entries are left out.
phq_lexicon = {}
for key, entries in lexicon.items():
    for each in entries:
        phq_lexicon.setdefault(key, []).append(each.replace("_", " "))


# ---------------------------------------------------------------------------
# Training hyperparameters
# ---------------------------------------------------------------------------
NUM_EPOCHS = 30
BATCH_SIZE = 1
LEARNING_RATE = 1e-5

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Checkpoint identifier handed to `from_pretrained`.
# Other options: "Tianlin668/MentalBART", "Tianlin668/MentalT5", "gpt2"
MODEL = "klyang/MentaLLaMA-chat-7B"
DATA_DIR = "./dataset/"
MODEL_SAVE_PATH = "checkpoints/"
# NOTE(review): presumably the checkpoint to resume from, with None meaning
# "start fresh" - confirm against the code that reads it.
LAST_SAVED = None
# Dataloader checkpoint locations (currently disabled):
# TRAIN_DATALOADER_SAVE_PATH = "./dataset/dataloader_checkpoints/MentalBART/train_dataloader.pkl"
# VAL_DATALOADER_SAVE_PATH = "./dataset/dataloader_checkpoints/MentalBART/val_dataloader.pkl"
# TEST_DATALOADER_SAVE_PATH = "./dataset/dataloader_checkpoints/MentalBART/test_dataloader.pkl"

# Compute related
# The second GPU ("cuda:1") is preferred, but it is only selected when it
# actually exists: on a single-GPU host "cuda:1" is an invalid device ordinal
# and fails as soon as something is moved onto it. In that case fall back to
# "cuda:0", and to the CPU when CUDA is not available at all.
if not torch.cuda.is_available():
    DEVICE = torch.device("cpu")
elif torch.cuda.device_count() > 1:
    DEVICE = torch.device("cuda:1")
else:
    DEVICE = torch.device("cuda:0")
# DEVICE = torch.device("cpu")  # uncomment to force CPU execution

# NOTE(review): presumably switches for logging and Weights & Biases tracking -
# confirm against the code that reads them.
LOG = False
WAND = True

# Model 1: MentaLLaMA
# Tokenizer and causal-LM weights are both loaded, at import time, from the
# checkpoint named by MODEL.
tokenizerCheckpoint = LlamaTokenizer.from_pretrained(MODEL)
# `output_hidden_states=True` is forwarded to the model configuration so the
# model also returns its hidden states.
# Options left disabled: use_flash_attention_2=True, torch_dtype=torch.float16
modelCheckpoint = LlamaForCausalLM.from_pretrained(
    MODEL,
    output_hidden_states=True,
)

# Model 2: MentalBART
# tokenizerCheckpoint = AutoTokenizer.from_pretrained('Tianlin668/MentalBART')
# modelCheckpoint = BartForConditionalGeneration.from_pretrained('Tianlin668/MentalBART')

# Model 3: GPT2 or any general model (replace the checkpoint name below accordingly, e.g. "gpt2")
# tokenizer = AutoTokenizer.from_pretrained('klyang/MentaLLaMA-chat-7B')
# model = AutoModel.from_pretrained('klyang/MentaLLaMA-chat-7B')
# tokenizer.add_special_tokens({'pad_token': '[PAD]'})  # add pad token for GPT2 model 
# model.resize_token_embeddings(len(tokenizer))         # resize model embedding to include new tokens