from argparse import ArgumentParser
from rl4lms.envs.text_generation.logging_utils import Tracker


import rl4lms.envs.text_generation.training_utils as training_utils_trainers

import yaml
import os

import jsonlines
import torch
import numpy as np
import sys

from collections import defaultdict
sys.path.append('..')
import llm_prompts
import llm_reward_fns

import flair 

def load_dict(dict_path):
    """Load a dict saved with ``np.save`` and return it as a ``defaultdict(list)``.

    Keys absent from the stored dict read as an empty list in the result.

    Note: ``allow_pickle=True`` unpickles the file contents, so only call this
    on files from a trusted source.
    """
    stored = np.load(dict_path, allow_pickle=True).item()
    result = defaultdict(list)
    result.update(stored)
    return result

def _apply_cli_overrides(config, args):
    """Override entries of the loaded YAML ``config`` in place from CLI ``args``.

    Args:
        config: Parsed YAML config (nested dicts); mutated in place.
        args: ``argparse.Namespace`` holding the command-line options.

    Raises:
        ValueError: for an RL config when the model name needed to build the
            checkpoint name to load was not supplied.
    """
    datapool_args = config['datapool']['args']
    train_eval = config['train_evaluation']

    if args.data_path is not None:
        datapool_args['path_to_save_dataset'] = args.data_path
    if args.test_data_path is not None:
        datapool_args['path_to_save_test_dataset'] = args.test_data_path
        print("Test dataset is at " + str(args.test_data_path))
    if args.path_to_save_twitter_dataset is not None:
        # The Twitter file replaces both the train and the test dataset paths.
        print("We are using the twitter data as the test " + str(args.path_to_save_twitter_dataset))
        datapool_args['path_to_save_dataset'] = args.path_to_save_twitter_dataset
        datapool_args['path_to_save_test_dataset'] = args.path_to_save_twitter_dataset
        datapool_args['eval_gold_comms'] = False

    # Always written (even when None); the same cap is applied to the test split.
    datapool_args['samples_to_use'] = args.samples_to_use
    datapool_args['test_samples_to_use'] = args.samples_to_use

    if args.samples_to_use is not None and int(args.samples_to_use) < 20:
        # Few samples: shorten the run.
        # NOTE(review): the message says "one iteration" but n_iters is set to
        # 2 -- confirm which of the two is intended.
        print("Only running one iteration")
        train_eval['n_iters'] = 2

    if args.eval_every is not None:
        train_eval['eval_every'] = int(args.eval_every)

    if "supervised" not in config["alg"]["id"]:
        # RL run: choose the checkpoint the policy is loaded from. Only the
        # name for the selected mode is built, so the unused option may be
        # omitted on the command line.
        if args.continue_rl_training:
            if args.model_name_rl is None:
                raise ValueError("--model_name_rl is required with --continue_rl_training")
            model_name_to_load = args.model_name_rl + '_' + str(args.rl_epoch_to_evaluate)
            print("The new model name to continue_rl_training is " + str(model_name_to_load))
        else:
            if args.model_name is None:
                raise ValueError("--model_name is required for RL runs unless --continue_rl_training is set")
            # Checkpoint name format: "<model_name>_<epoch>.0".
            model_name_to_load = args.model_name + '_' + str(args.epoch_to_evaluate) + '.0'
        config['alg']['policy']['args']['model_name'] = model_name_to_load

        metric_args = config['reward_fn']['args']['metric']
        metric_args['reward_fn'] = args.reward_fn_to_use
        metric_args['openai_url'] = args.openai_url
        print("We are using this as the reward_fn " + str(args.reward_fn_to_use))
        print("We are using this as the openai_url " + str(args.openai_url))

        if args.not_run_gpt_evaluation:
            metric_args['not_run_gpt_evaluation'] = True

        if args.run_smaller_batch_size:
            print("Batch size smaller")
            config['alg']['args']['batch_size'] = 3
            train_eval['eval_batch_size'] = 3

        # An explicit --batch_size takes precedence over --run_smaller_batch_size.
        if args.batch_size is not None:
            config['alg']['args']['batch_size'] = args.batch_size
            train_eval['eval_batch_size'] = args.batch_size
    else:
        if args.model_name is not None:
            config['alg']['model_name'] = args.model_name
        if args.epoch_to_evaluate is not None:
            config['alg']['model_epoch'] = int(args.epoch_to_evaluate)
        else:
            config['alg']['model_epoch'] = None


def main(
    args,
    config_path: str,
    project_name: str,
    experiment_name: str,
    base_path_to_store_results: str,
    entity_name: str,
    log_to_wandb: bool,
):
    """Load the YAML config, apply CLI overrides, build the trainer and run it.

    Args:
        args: Parsed command-line options; also forwarded to the trainer.
        config_path: Path to the YAML experiment config.
        project_name: WANDB project name (also a results sub-directory).
        experiment_name: WANDB experiment name (also a results sub-directory).
        base_path_to_store_results: Root directory for experiment outputs.
        entity_name: WANDB entity name.
        log_to_wandb: Whether the tracker logs to wandb.

    Raises:
        ValueError: for an RL config missing the model name to load.
    """
    with open(config_path, "r") as fp:
        config = yaml.safe_load(fp)
    _apply_cli_overrides(config, args)

    tracker = Tracker(
        base_path_to_store_results,
        config,
        project_name,
        experiment_name,
        entity_name,
        log_to_wandb,
    )

    # Point any community-analysis metric at this experiment's output directory.
    for metric in config["train_evaluation"]["metrics"]:
        if metric["id"] == "community_analysis":
            metric["args"]["save_path"] = os.path.join(
                base_path_to_store_results, project_name, experiment_name
            )

    print("Building samples from given_samples ")
    sys.stdout.flush()
    samples_dict_to_train = None

    # The option may not be registered on the parser that produced ``args``;
    # treat a missing attribute as False instead of raising AttributeError.
    if getattr(args, "train_rl_after_supervised", False):
        model_path = os.path.join(tracker._run_path, "model")
        config['alg']['policy']['args']['model_name'] = model_path
        print("Loading model to train_rl_after_supervised from path " + str(model_path))

    if "supervised" in config["alg"]["id"]:
        # NOTE(review): this unconditionally replaces any --model_name applied
        # by _apply_cli_overrides -- confirm that is intended.
        config['alg']['model_name'] = 't5-base'
        config['tokenizer']['model_name'] = 't5-base'

        trainer = training_utils_trainers.SupervisedTrainer(
            tokenizer_config=config["tokenizer"],
            datapool_config=config["datapool"],
            alg_config=config["alg"],
            train_eval_config=config["train_evaluation"],
            tracker=tracker,
            samples_dict_to_train=samples_dict_to_train,
            build_samples_from_given_samples=False,
            args=args,
        )
    else:
        print("Running the RL model")

        trainer = training_utils_trainers.OnPolicyTrainer(
            tokenizer_config=config["tokenizer"],
            datapool_config=config["datapool"],
            reward_config=config["reward_fn"],
            env_config=config["env"],
            on_policy_alg_config=config["alg"],
            train_eval_config=config["train_evaluation"],
            tracker=tracker,
        )

    trainer.train_and_eval(
        only_run_eval=args.should_run_training_or_just_eval_training,
        continue_rl_training=args.continue_rl_training,
        epoch_to_start_training=args.rl_epoch_to_evaluate,
        reward_fn=args.reward_fn_to_use,
        not_run_gpt_evaluation=args.not_run_gpt_evaluation,
        just_eval=args.just_eval,
        should_save=args.should_save,
        run_curriculum_learning=args.run_curriculum_learning,
        run_supervised_training_in_between=args.run_supervised_training_in_between,
        args=args,
        openai_url=args.openai_url,
    )


if __name__ == "__main__":
    # Command-line entry point: parse options and hand them to main().
    parser = ArgumentParser(description="Fine-tune LM to generate controlled text")
    # Required: main() opens this path unconditionally.
    parser.add_argument("--config_path", type=str, required=True, help="path to the config file")
    parser.add_argument(
        "--project_name", type=str, help="WANDB project name", default="rl4lm_exps"
    )
    parser.add_argument(
        "--experiment_name",
        type=str,
        help="WANDB experiment name",
        default="rl4lm_experiment",
    )
    parser.add_argument(
        "--entity_name", type=str, help="WANDB entity name", default=None
    )
    parser.add_argument(
        "--base_path_to_store_results",
        type=str,
        help="Base path to store experiment results",
        default=os.getcwd(),
    )
    parser.add_argument(
        "--log_to_wandb", action="store_true", help="Whether to use wandb logging"
    )
    parser.add_argument("--just_evaluate",  action='store_true', help="True if you want to just evaluate the model.")
    parser.add_argument("--eval_reddit",  action='store_true', help="True if you want to just evaluate on the reddit dataset.")
    parser.add_argument("--eval_no_focus_area",  action='store_true', help="True if you want to just evaluate when we do not provide any focus area.")
    parser.add_argument("--eval_gold_focus_area",  action='store_true', help="True if you want to just evaluate when we provide the gold focus area.")
    parser.add_argument("--evaluate_bias",  action='store_true', help="True if you want to evaluate the bias as well.")
    parser.add_argument("--should_run_training_or_just_eval_training",  action='store_true', help="True if you want to just eval the training models using the default metrics.")

    parser.add_argument("--model_name",  type=str, help="The model to load and evaluate.", default=None)
    parser.add_argument("--data_path", type=str, help="path to the dataset", default=None)
    parser.add_argument("--test_data_path", type=str, help="path to the test dataset", default=None)

    parser.add_argument("--path_to_save_twitter_dataset", type=str, help="path to the Twitter dataset", default=None)
    parser.add_argument("--save_predicted_comms", action='store_true', help="True if you want to save the predicted communities for Twitter")

    parser.add_argument("--epoch_to_evaluate", type=int, default=None, help="Which model epoch should be evaluated")
    parser.add_argument("--eval_every", type=int, default=None, help="How often to eval the model")
    parser.add_argument("--should_save", action='store_true', help="True if you want to save the model every time")

    parser.add_argument("--samples_to_use", type=int, default=None, help="How many dataset samples to use. Don't set it if you want to use everything.")

    # useful to load a RL model
    parser.add_argument("--continue_rl_training", action='store_true', help="True if you want to continue training the RL model.")
    parser.add_argument("--model_name_rl",  type=str, help="The path for the RL model to load and evaluate.", default=None)
    parser.add_argument("--rl_epoch_to_evaluate", type=int, default=None, help="Which RL model epoch should be evaluated")
    # main() reads this flag; without registering it every run hit AttributeError.
    parser.add_argument("--train_rl_after_supervised", action='store_true', help="True if you want to load the model saved under this run's directory as the RL policy.")

    parser.add_argument("--reward_fn_to_use", type=str, help="The RL reward function you want to use to train the model.", default=None)
    parser.add_argument("--openai_url", type=str, help="The openAI link you want to use to train the model.", default=None)

    parser.add_argument("--not_run_gpt_evaluation", action='store_true', help="True if you do not want to evaluate on GPT, it will probably just evaluate the ROUGE score and other automatic metrics then.")
    parser.add_argument("--just_eval", action='store_true', help="True if you just want to evaluate the model")

    parser.add_argument("--run_smaller_batch_size", action='store_true', help="True if you want to run a smaller batch size")
    parser.add_argument("--batch_size", type=int, default=None, help="Batch size to run the model with")

    parser.add_argument("--run_curriculum_learning", action='store_true', help="True if you want to train the model using curriculum learning")
    parser.add_argument("--run_supervised_training_in_between", action='store_true', help="True if you want to run supervised training in between on successful RL examples")
    parser.add_argument("--train_config_path", type=str, help="path to the training yaml config")

    args = parser.parse_args()

    main(
        args,
        args.config_path,
        args.project_name,
        args.experiment_name,
        args.base_path_to_store_results,
        args.entity_name,
        args.log_to_wandb,
    )