import os.path
import sys
import torch
import pytorch_lightning as pl
from knowledge_enhanced_extraction import Extractor, KnowEnhancer
from data_loader import MyDataModule
from util import split_label_words, SUBJECT_START, SUBJECT_END, OBJECT_START, OBJECT_END, CLASS_ID

import json
from transformers import AutoModelForMaskedLM, AutoModelForCausalLM, AutoTokenizer, top_k_top_p_filtering, set_seed

# Silence the HF tokenizers fork/parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Labels that denote the "no relation" class in a rel2id mapping.
NA_RELATION_NAMES = ('NA', 'no_relation', 'Other')


def parse_cli_args(argv):
    """Return ``(dataname, k_shot)`` parsed from the command line.

    Args:
        argv: Full argument vector, program name first (e.g. ``sys.argv``).
            ``argv[1]`` is the dataset name and ``argv[2]`` is the value
            substituted into the ``k-shot/{}-1`` sub-directory name.
            Any further arguments are ignored.

    Returns:
        A ``(dataname, k_shot)`` tuple, where ``k_shot`` is the relative
        directory ``'k-shot/<argv[2]>-1'``.

    Raises:
        SystemExit: With a usage message if fewer than two arguments are
            supplied (instead of an opaque ``IndexError``).
    """
    if len(argv) < 3:
        prog = argv[0] if argv else 'train'
        sys.exit(f'usage: {prog} <dataname> <k>')
    return argv[1], 'k-shot/{}-1'.format(argv[2])


def find_na_relation_id(rel2id):
    """Return the id of the first "no relation" label in *rel2id*.

    The mapping is scanned in insertion order for a key listed in
    ``NA_RELATION_NAMES``; ``0`` is returned when none is present.
    """
    return next(
        (rel_id for name, rel_id in rel2id.items() if name in NA_RELATION_NAMES),
        0,
    )


def main():
    """Train the small-LM extractor with LLM knowledge enhancement, then test.

    Reads the dataset name and k-shot setting from ``sys.argv``, builds the
    tokenizer, models and data module, fits for ``max_epochs`` epochs and
    evaluates the last checkpoint on the test set.
    """
    set_seed(42)

    # data parameters
    dataset_dir = 'dataset/'
    dataname, k_shot = parse_cli_args(sys.argv)
    data_dir = os.path.join(dataset_dir, dataname, k_shot)

    # slm parameters
    slm_dir = 'roberta-large'
    lr = 5e-5
    weight_decay = 0.01
    warmup_ratio = 0.1
    max_seq_length = 128
    bs = 8
    accumulate_grad_batches = 2 if 'sem' in dataname else 4
    reload_dataloaders_every_n_epochs = 3
    max_epochs = 9

    # llm parameters (the LLM is placed on the second GPU; the trainer below uses GPU 0)
    llm_device = torch.device('cuda:1')
    llm_dir = 'Meta-Llama-3-8B-Instruct'  # alternative checkpoint: 'llama2-13b-chat'
    dtype = torch.bfloat16

    # save parameters; c_time is the run tag used in both output paths
    c_time = '0504'
    context_save_dir = f'context_save/{c_time}/{dataname}/{k_shot}'
    os.makedirs(context_save_dir, exist_ok=True)

    model_save_dir = f'model_save/{c_time}/{dataname}/{k_shot}'
    os.makedirs(model_save_dir, exist_ok=True)

    # relation labels: rel2id.json sits one level above data_dir
    rel2id_f = os.path.join(os.path.dirname(data_dir), 'rel2id.json')
    with open(rel2id_f, "r", encoding="utf-8") as file:
        rel2id = json.load(file)

    na_rel = find_na_relation_id(rel2id)
    n_rels = len(rel2id)

    # tokenizer: add entity markers, then one class token per relation (1-based)
    slm_tokenizer = AutoTokenizer.from_pretrained(slm_dir)
    slm_tokenizer.add_special_tokens(
        {'additional_special_tokens': [SUBJECT_START, SUBJECT_END, OBJECT_START, OBJECT_END]})
    class_tokens = [CLASS_ID.format(idx) for idx in range(1, n_rels + 1)]
    # NOTE(review): this second call reuses the 'additional_special_tokens' key;
    # depending on the transformers version it may replace the tokenizer's
    # additional_special_tokens list rather than extend it -- confirm intended.
    slm_tokenizer.add_special_tokens({'additional_special_tokens': class_tokens})

    rel2tokens = split_label_words(slm_tokenizer, list(rel2id))

    # load llm
    know_enhancer = KnowEnhancer(llm_dir, dtype, llm_device)

    # load slm
    slm_model = AutoModelForMaskedLM.from_pretrained(slm_dir)
    extractor = Extractor(slm_model, slm_tokenizer, class_tokens, rel2tokens, na_rel,
                          lr=lr, weight_decay=weight_decay, warmup_ratio=warmup_ratio, aux_ratio=0.03)
    data_module = MyDataModule(data_dir, rel2id, slm_tokenizer, know_enhancer, context_save_dir,
                               max_seq_length=max_seq_length, bs=bs, nw=1, p_topk=2)
    # NOTE(review): the step count is derived from the number of training
    # examples and is not divided by the batch size `bs` -- confirm this is
    # what Extractor's schedule expects.
    extractor.num_training_steps = len(data_module.train_examples) * max_epochs // accumulate_grad_batches

    # slm trainer: checkpoint every epoch and keep a 'last' checkpoint for testing
    save_call_bk = pl.callbacks.ModelCheckpoint(dirpath=model_save_dir, filename='{epoch}', every_n_epochs=1,
                                                save_last=True)
    trainer = pl.Trainer(devices=[0], accelerator="gpu", max_epochs=max_epochs,
                         accumulate_grad_batches=accumulate_grad_batches,
                         num_sanity_val_steps=0, enable_progress_bar=True,
                         reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs,
                         callbacks=[save_call_bk])

    trainer.fit(extractor, datamodule=data_module, val_dataloaders=None)
    trainer.test(ckpt_path='last', datamodule=data_module)


if __name__ == '__main__':
    main()
