import os
import hydra
from omegaconf import OmegaConf

import json
import logging
import os
from pathlib import Path
import time

from torch import save
from torch.profiler import profile, record_function, ProfilerActivity

from utils.general import get_time_dict_path_full_data, log_config
from run_scripts.main_decorator import main_decorator
from datasets import concatenate_datasets


# ``getLogger()`` without a name returns the root logger.
log = logging.getLogger()


def _to_string(x):
    """Return *x* with every "/" and "-" replaced by "_"."""
    without_slashes = x.replace("/", "_")
    return without_slashes.replace("-", "_")


def _get_patience_value(dev_size):
    """Return 1000 when *dev_size* equals 0, otherwise 5."""
    if dev_size == 0:
        return 1000
    return 5


# Custom resolvers usable from config files as ``${to_string:...}`` and
# ``${get_patience_value:...}``.
OmegaConf.register_new_resolver("to_string", _to_string)
OmegaConf.register_new_resolver("get_patience_value", _get_patience_value)


@main_decorator
def run_full_data(config, work_dir: "Path | str"):
    """Measure how long the model takes to train on fractions of the data.

    Loads the data described by ``config.data``, builds a model with
    ``construct_model`` and, for each percentage in ``percentages``, fits the
    model on the first ``percent``% of the training instances while timing
    the ``fit`` call. Every measurement is written to
    ``<work_dir>/training_time_<percent>_percent.json`` and all measurements
    together to ``<work_dir>/all_training_times.json``.

    Args:
        config: Experiment configuration. ``config.data``, ``config.model``
            and ``config.framework.name`` are read here; the whole object is
            also forwarded to ``construct_model``.
        work_dir: Directory that receives the JSON timing files. A plain
            string is accepted and converted to ``Path``.
    """
    # Imports inside function to set environment variables before imports
    from construct import construct_model
    from utils.data.load_data import load_data

    # The original annotation ``Path or str`` evaluated to ``Path`` only and
    # the body used the ``/`` operator, so a ``str`` would raise TypeError.
    work_dir = Path(work_dir)

    # Log config so that it is visible from the console
    log_config(log, config)
    log.info("Loading data...")
    train_instances, dev_instances, test_instances, labels_or_id2label = load_data(
        config.data, config.model.type, config.framework.name
    )
    log.info(
        f"Training size: {len(train_instances)}, test size: {len(test_instances)}, dev size: {len(dev_instances)}"
    )

    # Initialize time dict
    time_dict_path = get_time_dict_path_full_data(config)

    log.info("Fitting the model...")
    model = construct_model(
        config,
        config.model,
        dev_instances,
        config.framework.name,
        labels_or_id2label,
        "model",
        time_dict_path,
        embeddings=None,
        word2idx=None,
    )

    # Percentages of the training data to time. Only 2% is currently enabled;
    # append further values (e.g. 6, 12) to benchmark larger fractions.
    # NOTE(review): the same ``model`` object is re-fit for every percentage,
    # so with several values later timings may start from an already trained
    # model -- confirm this is intended before enabling more percentages.
    percentages = [2]
    total_instances = len(train_instances)

    training_times = {}
    for percent in percentages:
        num_instances = int((percent / 100.0) * total_instances)
        if num_instances == 0:
            # Timing a fit on an empty subset would record a meaningless value.
            log.warning(
                f"Skipping {percent}% of data: it amounts to 0 of "
                f"{total_instances} training instances."
            )
            continue
        subset = train_instances.select(range(num_instances))

        log.info(f"Training on {percent}% of data ({num_instances} instances).")

        # perf_counter is monotonic, so the measurement is not distorted by
        # system clock adjustments the way time.time() differences can be.
        start_time = time.perf_counter()
        model.fit(subset, None)
        elapsed_time = time.perf_counter() - start_time
        training_times[percent] = elapsed_time

        log.info(f"Time taken to train on {percent}% of data: {elapsed_time:.2f} seconds.")

        time_file_path = work_dir / f"training_time_{percent}_percent.json"
        with open(time_file_path, "w", encoding="utf-8") as f:
            json.dump({f"{percent}_percent": elapsed_time}, f)

        log.info(f"Training time for {percent}% of data saved to {time_file_path}")

    # Also store every measurement in a single file. ``json.dump`` turns the
    # integer percentage keys into strings ("2", ...).
    all_times_file_path = work_dir / "all_training_times.json"
    with open(all_times_file_path, "w", encoding="utf-8") as f:
        json.dump(training_times, f)

    log.info(f"All training times saved to {all_times_file_path}")


# Both environment variables are looked up when this module is imported (the
# decorator arguments are evaluated at definition time), so a missing
# HYDRA_CONFIG_PATH or HYDRA_CONFIG_NAME raises KeyError before Hydra starts.
@hydra.main(
    config_path=os.environ["HYDRA_CONFIG_PATH"],
    config_name=os.environ["HYDRA_CONFIG_NAME"],
)
def main(config):
    """Hydra entry point: forward the composed config to ``run_full_data``.

    Args:
        config: Configuration object that Hydra passes to the decorated
            function.
    """
    # Only ``config`` is passed here; the ``work_dir`` argument is presumably
    # supplied by ``main_decorator`` -- TODO confirm against its definition.
    run_full_data(config)

# Run the Hydra entry point only when the file is executed as a script.
if __name__ == "__main__":
    main()
