import time

from datasets import load_dataset
import transformers
from transformers import Trainer, TrainingArguments, AutoTokenizer

from utils import get_compute_metrics, preprocess_logits_for_metrics
from base_model import CustomGPT2Config
from train import MainConfig, get_dataset_and_tokenizer, load_model


def main() -> None:
    """Evaluate the configured checkpoint on the validation split and print the metrics.

    Loads the tokenized dataset and tokenizer via ``get_dataset_and_tokenizer``,
    builds a ``CustomGPT2Config`` from ``MainConfig.load_checkpoint``, creates the
    model with ``load_model``, and runs a single ``Trainer.evaluate()`` pass.
    No training step is performed.
    """
    # Dataset
    print("Loading dataset...")
    tokenized_dataset, tokenizer = get_dataset_and_tokenizer()
    print(tokenized_dataset)

    # Model
    print("Loading model...")
    model_config = CustomGPT2Config.from_pretrained(MainConfig.load_checkpoint)
    # NOTE(review): presumably load_model also restores the weights stored at
    # MainConfig.load_checkpoint; confirm in train.load_model, otherwise this
    # evaluates a freshly initialised model.
    model = load_model(model_config)

    # Evaluation only: this script never calls trainer.train().
    print("Evaluating...")

    training_args = TrainingArguments(
        per_device_eval_batch_size=64,
        output_dir="./outputs",
        seed=MainConfig.seed,
        # Must be the string "none". In transformers 4.x, report_to=None is
        # normalised to "all", which enables every installed logging
        # integration (wandb, tensorboard, ...).
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        # evaluate() with no argument uses eval_dataset; train_dataset is passed
        # only so the Trainer is configured the same way as during training.
        train_dataset=tokenized_dataset["train"],
        # assumes the dataset provides a "validation" split; verify in
        # train.get_dataset_and_tokenizer
        eval_dataset=tokenized_dataset["validation"],
        compute_metrics=get_compute_metrics(tokenizer),
        # Reduces the logits at each eval step before they are cached for
        # compute_metrics, keeping evaluation memory bounded.
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )

    results = trainer.evaluate()
    print(results)

    # Optional smoke test: generate a completion for a code prompt.
    # text = "def dijkstra(graph):"
    # print("Generating code...")
    # generated = model.generate(tokenizer(text, return_tensors="pt").input_ids, max_length=100)
    # print(tokenizer.decode(generated[0], skip_special_tokens=True))


if __name__ == "__main__":
    main()
