import os
import torch
import numpy as np
import scipy.special
import logging

from torchfly.flylogger import FlyLogger
from torchfly.flyconfig import FlyConfig
from torchfly.training import Trainer
import torchfly.distributed as distributed
from torchfly.utilities import set_random_seed
from omegaconf import OmegaConf

from memformerE3_base_flymodel import MemformerFlyModel
from memformers.recurrent_training.recurrent_trainer import RecurrentTrainer
from persuasion_dataloader import DataLoaderHelper

logger = logging.getLogger(__name__)

def main():
    """Run one Memformer training job on the persuasion dataset.

    Steps, in order:
      1. Load the experiment configuration from a fixed YAML path and seed
         the random number generators from ``config.training.random_seed``.
      2. Build the train/validation dataloaders, the model and the
         recurrent trainer.
      3. Attach per-time-step loss weights to the trainer.
      4. Inside the ``FlyLogger`` context: log the weights, dump the config
         to ``config.yaml`` (relative to the current working directory at
         that point), log the model/config, and start training.

    The function takes no arguments and returns nothing; every setting
    comes from the YAML file.
    """
    config = FlyConfig.load("configs/persuasion_config/memformerE3_base_10%.yaml")
    set_random_seed(config.training.random_seed)

    data_helper = DataLoaderHelper(config)
    train_dataloader = data_helper.train_loader_fn()
    valid_dataloader = data_helper.valid_loader_fn()

    model = MemformerFlyModel(config)
    model.configure_metrics()

    trainer = RecurrentTrainer(config.training, model)

    # Per-step loss weights: powers of two for steps 1..time_horizon,
    # L1-normalised so they sum to 1, then passed through a softmax with a
    # temperature of 1e5 and rescaled by time_horizon. Because the softmax
    # inputs are at most 1e-5 apart, the resulting weights are all very
    # close to 1 (they sum to time_horizon).
    #
    # The base is the float 2.0 on purpose: with an integer base NumPy
    # computes in the platform integer dtype, which silently wraps around
    # once 2**k no longer fits (k >= 63 on int64, k >= 31 on int32).
    time_horizon = trainer.config.time_horizon
    loss_weights = np.power(2.0, np.arange(1, time_horizon + 1))
    loss_weights = loss_weights / np.linalg.norm(loss_weights, 1)
    trainer.loss_weights = scipy.special.softmax(loss_weights / 1e5) * time_horizon

    with FlyLogger(config.flylogger) as flylogger:
        # The message needs a %s placeholder: logging formats lazily with
        # ``msg % args``, and an argument without a placeholder raises a
        # TypeError inside the handler, so the record is never emitted.
        logger.info("loss weights %s", trainer.loss_weights)

        # Save the configuration used for this run next to its outputs.
        with open("config.yaml", "w") as f:
            OmegaConf.save(config, f)

        logger.info(model.model.recurrent_training_cell.model_config)
        logger.info(config)

        trainer.train(config.training, train_dataloader, valid_dataloader)


# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
