from functools import partial
from typing import Any, Dict, List
import numpy as np
import torch
import sys
import yaml
import random
from rl4lms.data_pools.text_generation_pool import Sample
from rl4lms.envs.text_generation.env import TextGenEnv
from rl4lms.envs.text_generation.evaluation_utils import evaluate_on_samples
from rl4lms.envs.text_generation.utils_supervised import evaluate_on_samples as evaluate_supervised
from rl4lms.envs.text_generation.evaluation_utils import evaluate_on_twitter
from rl4lms.envs.text_generation.logging_utils import Tracker
from rl4lms.envs.text_generation.registry import (DataPoolRegistry,
                                                   MetricRegistry,
                                                   RewardFunctionRegistry,
                                                   PolicyRegistry,
                                                   AlgorithmRegistry,
                                                   WrapperRegistry)
from rl4lms.envs.text_generation.reward import RewardFunction
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv
from transformers import (AutoTokenizer,
                          AutoModelForCausalLM,
                          AutoModelForSeq2SeqLM,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          DataCollatorForSeq2Seq)
from rl4lms.envs.text_generation.utils_supervised import (get_datasets_for_causal,
                                                           get_datasets_for_seq2seq,
                                                           tokenize_causal,
                                                           tokenize_seq2seq,
                                                           EvalCallack)
from rl4lms.envs.text_generation.warm_start import TrainerWarmStartMixin
from transformers import Trainer, TrainingArguments, TrainerCallback
import os 
import generate_complex_sentences_given_simple
from peft import LoraConfig, get_peft_model, TaskType


def build_tokenizer(tokenizer_config: Dict[str, Any]):
    """Load the pretrained tokenizer named in the config and apply padding options.

    Args:
        tokenizer_config: Must contain "model_name". Optional keys are
            "pad_token_as_eos_token" (default True), "padding_side" and
            "truncation_side" (both default to "left").

    Returns:
        The object returned by ``AutoTokenizer.from_pretrained`` with its
        ``pad_token``, ``padding_side`` and ``truncation_side`` configured.
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_config["model_name"])

    # Reuse the EOS token for padding only when the tokenizer has no pad
    # token of its own and the config has not opted out.
    if tok.pad_token is None:
        if tokenizer_config.get("pad_token_as_eos_token", True):
            tok.pad_token = tok.eos_token

    # Both side options share the same default.
    for side_option in ("padding_side", "truncation_side"):
        setattr(tok, side_option, tokenizer_config.get(side_option, "left"))
    return tok


def build_reward_fn(reward_config: Dict[str, Any]):
    """Look up the reward function identified by the config in the registry.

    Args:
        reward_config: Must contain "id"; the optional "args" entry
            (default ``{}``) is passed to the registry as the second argument.

    Returns:
        Whatever ``RewardFunctionRegistry.get`` returns for that id and args.
    """
    reward_id = reward_config["id"]
    reward_args = reward_config.get("args", {})
    return RewardFunctionRegistry.get(reward_id, reward_args)


def build_metrics(metric_configs: List[Dict[str, Any]]):
    """Resolve every metric config into a registry entry, preserving order.

    Args:
        metric_configs: Sequence of dicts, each with an "id" and an optional
            "args" entry (default ``{}``).

    Returns:
        A list with one ``MetricRegistry.get`` result per config; empty when
        no configs are given.
    """
    resolved = []
    for config in metric_configs:
        metric_id = config["id"]
        metric_args = config.get("args", {})
        resolved.append(MetricRegistry.get(metric_id, metric_args))
    return resolved


def build_datapool(datapool_config: Dict[str, Any]):
    """Build the train/val/test sample lists described by the datapool config.

    Args:
        datapool_config: Must contain "id"; the optional "args" entry holds
            keyword arguments for the datapool. It is not modified.

    Returns:
        A dict with keys "train", "val" and "test". "train" holds
        ``(sample, weight)`` tuples; "val" and "test" hold samples only.
    """

    def _get_datapool_by_split(split: str):
        # Build a fresh kwargs dict per split so the caller's config is not
        # mutated and the three registry calls do not share one dict.
        kwargs = {**datapool_config.get("args", {}), "split": split}
        return DataPoolRegistry.get(datapool_config["id"], kwargs)

    train_datapool = _get_datapool_by_split("train")
    val_datapool = _get_datapool_by_split("val")
    test_datapool = _get_datapool_by_split("test")

    samples_by_split = {
        # Iterating a datapool yields (sample, weight) pairs; only the
        # training split keeps the weights.
        "train": [(sample, weight) for sample, weight in train_datapool],
        "val": [sample for sample, _ in val_datapool],
        "test": [sample for sample, _ in test_datapool],
    }
    return samples_by_split


def build_datapool_given_samples(samples_dict):
    """Turn a dict of pre-computed records into a train-only datapool.

    Each value of ``samples_dict`` is an indexable record laid out as
    (curr_text_input, predicted_comm_1, curr_focus_area, purity_for_community,
    given_gold_comm_1, given_gold_comm_2, all_usernames, [],
    purity_for_community_all_match). Only positions 0, 2, 4 and 5 are read;
    the dict keys are ignored.

    Returns:
        A dict with keys "train", "val" and "test". "train" holds
        ``(sample, 1.0)`` tuples; "val" and "test" are empty lists because
        only a training set is needed here.
    """
    weighted_train = []
    for position, (_, record) in enumerate(samples_dict.items()):
        input_text, focus_area = record[0], record[2]
        gold_comm_1, gold_comm_2 = record[4], record[5]

        sample = Sample(
            id="train_" + str(position),
            prompt_or_input_text=str(input_text),
            references=focus_area,
        )
        # The gold communities are known, so keep them on the sample.
        sample.comm_1 = gold_comm_1
        sample.comm_2 = gold_comm_2

        weighted_train.append((sample, 1.0))

    return {"train": weighted_train, "val": [], "test": []}

def build_env(env_config: Dict[str, Any],
              reward_fn: RewardFunction,
              tokenizer: AutoTokenizer,
              train_samples: List[Sample]):
    """Create the vectorized text-generation environment.

    Args:
        env_config: Optional "args" entry is merged over the base keyword
            arguments (and wins on key clashes); optional "n_envs"
            (default 1) sets the number of environments.
        reward_fn: Passed to each environment as ``reward_function``.
        tokenizer: Passed to each environment as ``tokenizer``.
        train_samples: Passed to each environment as ``samples``.

    Returns:
        The result of ``make_vec_env`` for ``TextGenEnv`` using
        ``SubprocVecEnv`` as the vectorized environment class.
    """
    env_kwargs = dict(
        reward_function=reward_fn,
        tokenizer=tokenizer,
        samples=train_samples,
    )
    # User-supplied args override the defaults above.
    env_kwargs.update(env_config.get("args", {}))

    n_envs = env_config.get("n_envs", 1)
    return make_vec_env(TextGenEnv,
                        n_envs=n_envs,
                        vec_env_cls=SubprocVecEnv,
                        env_kwargs=env_kwargs)


def build_alg(alg_config: Dict[str, Any],
              env: TextGenEnv,
              tracker: Tracker,
              policy_state: Dict[str, Any],
              alg_state: Dict[str, Any]):
    """Instantiate the wrapped on-policy algorithm and restore its state.

    Args:
        alg_config: Must contain "id", "policy" (with "id" and "args") and
            "kl_div" (with "coeff" and optional "target_kl" / "norm_reward").
            The optional "args" entry is merged into the algorithm keyword
            arguments. The config is not modified.
        env: Environment handed to the algorithm.
        tracker: Tracker handed to the wrapper.
        policy_state: Passed to the policy as its ``state_dict`` keyword.
        alg_state: Passed to ``alg.load_from_dict`` after construction.

    Returns:
        The wrapped algorithm instance.
    """
    # TBD - move these to a registry once the experimentation is done
    # Also switch to Sb3 algos when possible with minimal code adaptations
    policy_config = alg_config["policy"]
    policy_cls = PolicyRegistry.get(policy_config["id"])
    alg_cls = AlgorithmRegistry.get(alg_config["id"])

    # Copy the policy args so the state dict is not written back into the
    # caller's configuration.
    policy_args = {**policy_config["args"], "state_dict": policy_state}
    alg_kwargs = {
        "policy": policy_cls,
        "env": env,
        "policy_kwargs": policy_args,
        # A missing or null "args" entry means "no extra arguments".
        **(alg_config.get("args") or {}),
    }
    wrapper = WrapperRegistry.get(alg_config["id"])
    kl_config = alg_config["kl_div"]
    alg = wrapper(alg_cls, alg_kwargs,
                  kl_config["coeff"], tracker,
                  kl_config.get("target_kl", None),
                  kl_config.get("norm_reward", False))
    alg.load_from_dict(alg_state)
    return alg


class OnPolicyTrainer(TrainerWarmStartMixin):
    """
    A generic trainer for training LMs with onpolicy algorithms from SB3
    """

    def __init__(self,
                 tokenizer_config: Dict[str, Any],
                 datapool_config: Dict[str, Any],
                 reward_config: Dict[str, Any],
                 env_config: Dict[str, Any],
                 on_policy_alg_config: Dict[str, Any],
                 train_eval_config: Dict[str, Any],
                 tracker: Tracker = None,
                 experiment_name: str = ''
                 ):
        """Store the configuration dictionaries and build all components."""
        self._tokenizer_config = tokenizer_config
        self._datapool_config = datapool_config
        self._reward_config = reward_config
        self._env_config = env_config
        self._on_policy_alg_config = on_policy_alg_config
        self._train_eval_config = train_eval_config
        self._tracker = tracker
        self._experiment_name = experiment_name
        self._setup()

    def _setup(self):
        """Build tokenizer, reward, datapool, env, algorithm and metrics, then cache train/eval parameters."""
        # load trainer state from available previous checkpoint if available
        self.load_trainer_state(self._tracker)

        # build components
        self._tokenizer = build_tokenizer(self._tokenizer_config)
        self._reward_fn = build_reward_fn(self._reward_config)
        self._samples_by_split = build_datapool(
            self._datapool_config)
        self._env = build_env(self._env_config, self._reward_fn,
                              self._tokenizer, self._samples_by_split["train"])
        torch.cuda.empty_cache()
        self._alg = build_alg(self._on_policy_alg_config,
                              self._env, self._tracker,
                              self._policy_state_dict,
                              self._alg_state_dict)
        self._metrics = build_metrics(
            self._train_eval_config.get("metrics", []))

        # extract train params
        self._max_episode_length = self._env_config["args"]["max_episode_length"]
        self._max_prompt_length = self._env_config["args"]["max_prompt_length"]
        self._eval_batch_size = self._train_eval_config["eval_batch_size"]
        self._n_iters = int(self._train_eval_config["n_iters"])
        self._n_steps_per_iter = self._env.num_envs * self._alg.n_steps

        # gen kwargs for evaluation (if it is different from rollout gen kwargs)
        self._eval_gen_kwargs = self._train_eval_config.get(
            "generation_kwargs", None)

    def _evaluate_on_datapools(self, epoch: int,
                               splits: List[str] = None,
                               samples_to_eval=None,
                               return_predictions_dict=False,
                               prev_success_dict=None,
                               prev_failure_dict=None,
                               reward_fn=None,
                               not_run_gpt_evaluation=False,
                               return_reward=False,
                               run_no_focus_area=False,
                               run_gold_focus_area=False,
                               run_curriculum_learning=False,
                               curr_curriculum=None,
                               openai_url=None):
        """Run ``evaluate_on_samples`` with the current policy on the given splits.

        Args:
            epoch: Epoch number forwarded to the evaluation helper.
            splits: Names of the splits to evaluate; defaults to
                ``["val", "test"]``.
            return_predictions_dict: When true, only the FIRST split is
                evaluated and its (success, failure) prediction dictionaries
                are returned immediately.
            return_reward: When true (and predictions are not requested), the
                value produced for the LAST evaluated split is returned.
            run_gold_focus_area, run_no_focus_area, prev_success_dict,
            prev_failure_dict: Forwarded only in prediction mode.
            Remaining arguments are forwarded to ``evaluate_on_samples``
            unchanged.

        Returns:
            ``(success_predictions_dict, failure_predictions_dict)`` in
            prediction mode; the last split's result when ``return_reward`` is
            true (``None`` if ``splits`` is empty); otherwise ``None``.
        """
        if splits is None:
            splits = ["val", "test"]

        # Keyword arguments common to both evaluation modes.
        shared_kwargs = dict(
            policy=self._alg.policy,
            tokenizer=self._tokenizer,
            batch_size=self._eval_batch_size,
            max_prompt_length=self._max_prompt_length,
            metrics=self._metrics,
            epoch=epoch,
            tracker=self._tracker,
            gen_kwargs=self._eval_gen_kwargs,
            samples_to_eval=samples_to_eval,
            return_predictions_dict=return_predictions_dict,
            reward_fn=reward_fn,
            openai_url=openai_url,
            not_run_gpt_evaluation=not_run_gpt_evaluation,
            run_curriculum_learning=run_curriculum_learning,
            curr_curriculum=curr_curriculum,
        )

        curr_reward = None
        for split in splits:
            if return_predictions_dict:
                # The caller's run_gold_focus_area is forwarded here; it used
                # to be silently replaced by False.
                success_predictions_dict, failure_predictions_dict = evaluate_on_samples(
                    samples=self._samples_by_split[split],
                    split_name=split,
                    prev_success_dict=prev_success_dict,
                    prev_failure_dict=prev_failure_dict,
                    running_evaluation=True,
                    run_no_focus_area=run_no_focus_area,
                    run_gold_focus_area=run_gold_focus_area,
                    **shared_kwargs)
                return success_predictions_dict, failure_predictions_dict

            curr_reward = evaluate_on_samples(
                samples=self._samples_by_split[split],
                split_name=split,
                return_reward=return_reward,
                **shared_kwargs)

        if return_reward:
            return curr_reward

    def train_and_eval(self, only_run_eval=False, continue_rl_training=False, epoch_to_start_training=None, reward_fn=None, not_run_gpt_evaluation=None, just_eval=False, should_save=False, run_curriculum_learning=False, run_supervised_training_in_between=False, iters_to_run=None, args=None, openai_url=None):
        """Evaluate once on the test split, then alternate RL training and evaluation.

        Args:
            only_run_eval: Not used by this method.
            continue_rl_training: Resume from ``epoch_to_start_training + 1``
                instead of the stored trainer state; the initial test
                evaluation then only runs when ``just_eval`` is set.
            epoch_to_start_training: Required when ``continue_rl_training``
                is true.
            reward_fn, openai_url, not_run_gpt_evaluation: Forwarded to the
                evaluation calls (the validation call does not receive
                ``not_run_gpt_evaluation``).
            just_eval: Evaluate and terminate the process via ``sys.exit(1)``.
            should_save: Save trainer state and model on each new best
                validation reward.
            run_curriculum_learning: Start at curriculum stage 0 and advance
                one stage whenever the stall counter reaches 3.
            run_supervised_training_in_between: Not used by this method.
            iters_to_run: Exclusive upper bound of the epoch range; defaults
                to the configured ``n_iters``.
            args: Optional namespace; ``save_predicted_comms``,
                ``eval_no_focus_area`` and ``eval_gold_focus_area`` are read
                from it and treated as False when absent.

        Raises:
            ValueError: If ``continue_rl_training`` is set without
                ``epoch_to_start_training``.
        """
        # Curriculum stage starts at 0 when enabled; None means disabled.
        new_curriculum = 0 if run_curriculum_learning else None
        curr_curriculum = new_curriculum
        self._env.env_method("update_curriculum", new_curriculum)
        print("curr_curriculum is " + str(curr_curriculum))
        print("openai_url is being passed to train_and_eval" + str(openai_url))

        iter_start = self._trainer_state["current_iter"]
        num_samples_to_eval = 100
        test_num_samples_to_eval = 5000
        print("Going to start the training by evaluation")
        sys.stdout.flush()

        # Arguments of the pre-training test-set evaluation.
        baseline_eval_kwargs = dict(
            epoch=-1,
            splits=["test"],
            samples_to_eval=test_num_samples_to_eval,
            return_predictions_dict=True,
            reward_fn=reward_fn,
            openai_url=openai_url,
            not_run_gpt_evaluation=not_run_gpt_evaluation,
            run_curriculum_learning=run_curriculum_learning,
            curr_curriculum=curr_curriculum,
        )

        if continue_rl_training:
            if epoch_to_start_training is None:
                raise ValueError(
                    "epoch_to_start_training must be given when continue_rl_training is True")
            iter_start = epoch_to_start_training + 1
            if just_eval:
                self._evaluate_on_datapools(**baseline_eval_kwargs)
        else:
            # `args` defaults to None, so read the flags defensively.
            exit_after_baseline = just_eval and not getattr(args, "save_predicted_comms", False)
            if not exit_after_baseline:
                print("Evaluating on the test set")
            prev_success_dict, prev_failure_dict = self._evaluate_on_datapools(
                run_no_focus_area=getattr(args, "eval_no_focus_area", False),
                run_gold_focus_area=getattr(args, "eval_gold_focus_area", False),
                **baseline_eval_kwargs)
            if exit_after_baseline:
                print("We now have these many prev_success_dict " + str(len(prev_success_dict)) + " and these many prev_failure_dict " + str(len(prev_failure_dict)) + " at the start")
                # Exit status 1 is kept for compatibility with existing callers.
                sys.exit(1)

        if iters_to_run is None:
            iters_to_run = self._n_iters
        best_validation_reward = 0.0
        # NOTE(review): the baseline prediction dictionaries computed above
        # are not carried into the loop; it starts from empty ones. Confirm
        # this reset is intended.
        prev_success_dict, prev_failure_dict = {}, {}

        print("run_curriculum_learning is " + str(run_curriculum_learning))

        # Counts epochs without a new best validation reward. It is only
        # reset when the curriculum advances, not when validation improves.
        iterations_without_improvement = 0

        # train for given number of iters
        for epoch in range(iter_start, iters_to_run):

            self._env.env_method("update_epoch", epoch)

            # current state
            self._trainer_state["current_iter"] = epoch
            # NOTE(review): this training step also runs when just_eval is
            # set; the process only exits at the end of the iteration.
            self._alg.learn(self._n_steps_per_iter)
            print("Done training the model at RL epoch " + str(epoch))

            # always eval the validation set (unless only evaluating)
            evaluate_because_best_validation = False
            if not just_eval:
                curr_validation_reward = self._evaluate_on_datapools(
                    epoch=epoch, splits=["val"],
                    samples_to_eval=num_samples_to_eval,
                    reward_fn=reward_fn, openai_url=openai_url,
                    return_reward=True,
                    run_curriculum_learning=run_curriculum_learning,
                    curr_curriculum=curr_curriculum)
                print("curr_validation_reward is " + str(curr_validation_reward) + " at epoch " + str(epoch) + " and best_validation_reward is " + str(best_validation_reward))
                if curr_validation_reward > best_validation_reward:
                    print("Saving the model as we got a new best validation reward at epoch " + str(epoch))
                    best_validation_reward = curr_validation_reward
                    # a new best validation reward triggers a test evaluation below
                    evaluate_because_best_validation = True

                    if should_save:
                        self.save_trainer_state(
                            self._tracker, self._alg.policy, self._trainer_state, given_checkpoint_id=epoch, remove_previous=False)
                        self._tracker.save_auto_model(self._alg.policy.get_language_model(), epoch=epoch, remove_previous=False)
                        print("Done saving the model")
                        sys.stdout.flush()
                else:
                    # no increase in validation performance this epoch
                    iterations_without_improvement += 1
                    print("Validation performance didn't increase and it is now " + str(iterations_without_improvement) + " at epoch " + str(epoch))

            # for curriculum learning, advance the stage once the stall counter reaches 3
            if iterations_without_improvement >= 3 and run_curriculum_learning:
                print("Time to update curriculum")
                new_curriculum = new_curriculum + 1
                print("At epoch " + str(epoch) + " we are updating the curriculum to " + str(new_curriculum))
                self._env.env_method("update_curriculum", new_curriculum)
                iterations_without_improvement = 0
                curr_curriculum = new_curriculum

            if evaluate_because_best_validation:
                # got a new best validation performance, evaluate the model on
                # the test set and then on a small slice of the training set
                prev_success_dict, prev_failure_dict = self._evaluate_on_datapools(
                    epoch=epoch, splits=["test"],
                    samples_to_eval=test_num_samples_to_eval,
                    return_predictions_dict=True,
                    prev_success_dict=prev_success_dict,
                    prev_failure_dict=prev_failure_dict,
                    reward_fn=reward_fn, openai_url=openai_url,
                    not_run_gpt_evaluation=not_run_gpt_evaluation,
                    run_curriculum_learning=run_curriculum_learning,
                    curr_curriculum=curr_curriculum)
                self._evaluate_on_datapools(
                    epoch=epoch, splits=["train"], samples_to_eval=50,
                    reward_fn=reward_fn, openai_url=openai_url,
                    not_run_gpt_evaluation=not_run_gpt_evaluation,
                    run_curriculum_learning=run_curriculum_learning,
                    curr_curriculum=curr_curriculum)

            if epoch % 10 == 0 and epoch != 0:
                # every 10 iterations, eval the training set just to see how things are progressing
                self._evaluate_on_datapools(
                    epoch=epoch, splits=["train"], samples_to_eval=50,
                    reward_fn=reward_fn, openai_url=openai_url,
                    not_run_gpt_evaluation=not_run_gpt_evaluation,
                    run_curriculum_learning=run_curriculum_learning,
                    curr_curriculum=curr_curriculum)

            if just_eval:
                print("Exiting because just_eval")
                sys.exit(1)
          
class SaveModelCallback(TrainerCallback):
    """Trainer callback that saves the owning trainer's model at selected epochs.

    A save happens at the end of an epoch when the epoch number is a multiple
    of 20 up to 100 (including 0), or a positive multiple of 50. Mainly useful
    as a debugging aid.
    """

    def __init__(self, my_model_instance):
        # Object exposing `_tracker` and `_model`; both are read at save time.
        self.my_model_instance = my_model_instance

    def on_epoch_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        """Hook called at the end of every epoch; delegates to the save routine."""
        self.save_model_function(model, state.epoch)

    def save_model_function(self, model, epoch):
        """Save the owner's model through its tracker when `epoch` is on the schedule.

        The `model` argument is accepted but the model that gets saved is
        `my_model_instance._model`.
        """
        print("Calling save_model_function")
        on_schedule = (
            epoch == 0
            or (epoch <= 100 and epoch % 20 == 0)
            or (epoch > 0 and epoch % 50 == 0)
        )
        if not on_schedule:
            return

        owner = self.my_model_instance
        if owner._tracker is not None:
            owner._tracker.save_auto_model(owner._model, epoch=epoch)

class SupervisedTrainer:
    """
    A supervised trainer to train LMs (causal and seq2seq) on text generation tasks (wrapper on HF trainer)

    Note: ``tracker`` and ``args`` default to None in the signatures, but
    ``_setup`` reads ``tracker.checkpoint_base_path`` and
    ``args.not_run_bfloat16`` unconditionally, so both must be provided.
    """

    def __init__(self,
                 tokenizer_config: Dict[str, Any],
                 datapool_config: Dict[str, Any],
                 train_eval_config: Dict[str, Any],
                 alg_config: Dict[str, Any],
                 tracker: Tracker = None,
                 build_samples_from_given_samples=False,
                 samples_dict_to_train=None, 
                 args=None,
                 ):
        """Store the configuration and build tokenizer, data, model and HF trainer.

        Args:
            build_samples_from_given_samples: When true, the training data is
                built from ``samples_dict_to_train`` instead of the datapool
                config.
            samples_dict_to_train: Records used in that case.
            args: Namespace providing ``not_run_bfloat16``.
        """
        self._tokenizer_config = tokenizer_config
        self._datapool_config = datapool_config
        self._train_eval_config = train_eval_config
        self._alg_config = alg_config
        self._tracker = tracker
        self.build_samples_from_given_samples = build_samples_from_given_samples
        self.samples_dict_to_train = samples_dict_to_train
        self._setup(args)

    def _evaluate_on_datapools(self, epoch: int,
                               splits: List[str] = None, samples_to_eval=None):
        """Evaluate the current model on each split (default: "val" and "test")."""
        if splits is None:
            splits = ["val", "test"]
        for split in splits:
            evaluate_supervised(model=self._model,
                                tokenizer=self._tokenizer,
                                samples=self._samples_by_split[split],
                                batch_size=self._eval_batch_size,
                                max_prompt_length=self._max_prompt_length,
                                metrics_config_dict=self._metrics_config_dict,
                                epoch=epoch,
                                split_name=split,
                                tracker=self._tracker,
                                generation_kwargs=self._gen_kwargs, 
                                samples_to_eval=samples_to_eval
                                )

    def _evaluate_on_and_return(self, epoch: int, splits: List[str] = None, reddit_data=False):
        """Evaluate each split and return the data produced for the LAST one.

        Args:
            epoch: Epoch number forwarded to the evaluation helper.
            splits: Split names to evaluate; defaults to ``["val", "test"]``.
            reddit_data: Forwarded to the evaluation helper.

        Returns:
            A 6-tuple ``(all_ref_texts, all_generated_text, all_prompts,
            all_entities, all_comms_1, all_comms_2)`` from the last split.

        Raises:
            ValueError: If ``splits`` is empty (there would be nothing to return).
        """
        if splits is None:
            splits = ["val", "test"]
        if not splits:
            raise ValueError("splits must contain at least one split name")

        for split in splits:
            # The helper is asked to modify samples only for the train split.
            need_to_modify_samples = split == 'train'
            # Each split is reported under its own name (previously every
            # split was reported under the first entry of `splits`).
            all_ref_texts, all_generated_text, all_prompts, all_entities, all_comms_1, all_comms_2 = evaluate_supervised(
                model=self._model,
                tokenizer=self._tokenizer,
                samples=self._samples_by_split[split],
                batch_size=self._eval_batch_size,
                max_prompt_length=self._max_prompt_length,
                metrics_config_dict=self._metrics_config_dict,
                epoch=epoch,
                split_name=split,
                tracker=self._tracker,
                generation_kwargs=self._gen_kwargs,
                return_data=True,
                reddit_data=reddit_data,
                need_to_modify_samples=need_to_modify_samples
            )

        return all_ref_texts, all_generated_text, all_prompts, all_entities, all_comms_1, all_comms_2

    def _model_cls(self):
        """Return the HF auto-model class matching the configured model type."""
        if self._alg_config["model_type"] == "causal":
            return AutoModelForCausalLM
        return AutoModelForSeq2SeqLM

    def _load_checkpoint(self, checkpoint_path, not_run_bfloat16):
        """Load a model from `checkpoint_path`, in bfloat16 unless disabled."""
        model_cls = self._model_cls()
        if not_run_bfloat16:
            return model_cls.from_pretrained(checkpoint_path, device_map='balanced')
        return model_cls.from_pretrained(checkpoint_path, torch_dtype=torch.bfloat16, device_map='balanced')

    def _setup(self, args):
        """Build tokenizer, datasets, model, evaluation callback and the HF trainer."""
        self._tokenizer = build_tokenizer(self._tokenizer_config)
        self._metrics_config_dict = self._train_eval_config.get("metrics")

        # Build the datapool exactly once, from whichever source is selected.
        if self.build_samples_from_given_samples:
            # we are given samples, use those to build the datapool for training
            self._samples_by_split = build_datapool_given_samples(self.samples_dict_to_train)
        else:
            self._samples_by_split = build_datapool(
                self._datapool_config)

        is_causal = self._alg_config["model_type"] == "causal"

        if is_causal:
            self._train_dataset = get_datasets_for_causal(self._samples_by_split["train"])
            preprocess_fn = tokenize_causal
        else:
            self._train_dataset = get_datasets_for_seq2seq(self._samples_by_split["train"])
            preprocess_fn = tokenize_seq2seq
        preprocess_fn = partial(preprocess_fn, tokenizer=self._tokenizer)
        self._tokenized_dataset = self._train_dataset.map(
            preprocess_fn, batched=True,
            remove_columns=self._train_dataset.column_names)

        model_cls = self._model_cls()
        self._gen_kwargs = self._alg_config["generation_kwargs"]
        self._model_epoch_to_load = self._alg_config["model_epoch"]
        if self._model_epoch_to_load is not None:
            if args.not_run_bfloat16:
                print("Not signifying bfloat16")
                self._model = model_cls.from_pretrained(self._alg_config["model_name"] + str('_') + str(float(self._model_epoch_to_load)))
            else:
                # NOTE(review): this branch loads the base model name and does
                # not use model_epoch, unlike the branch above - confirm intent.
                self._model = model_cls.from_pretrained(self._alg_config["model_name"], torch_dtype=torch.bfloat16, device_map='balanced')
        else:
            self._model = self._load_checkpoint(self._alg_config["model_name"], args.not_run_bfloat16)

        self._eval_batch_size = self._train_eval_config["eval_batch_size"]

        # setting max prompt length
        self._max_prompt_length = self._tokenizer_config.get(
            "max_length",  self._tokenizer.model_max_length)

        # For causal models, prompt and generated tokens share the context
        # window, so leave room for the new tokens.
        if is_causal and ((self._max_prompt_length + self._gen_kwargs["max_new_tokens"]) > self._tokenizer.model_max_length):
            self._max_prompt_length = self._max_prompt_length - \
                self._gen_kwargs["max_new_tokens"]

        self._eval_callback = EvalCallack(self._samples_by_split["val"],
                                          self._gen_kwargs,
                                          self._eval_batch_size,
                                          self._tokenizer,
                                          self._metrics_config_dict,
                                          self._max_prompt_length,
                                          self._tracker)
        train_args = self._alg_config["training_args"]
        train_args["output_dir"] = self._tracker.checkpoint_base_path
        train_args["seed"] = np.random.randint(100)  # random seed in [0, 100)
        if self._alg_config["run_pft"] == True:
            print("Setting ddp_find_unused_parameters False for peft")
            train_args["ddp_find_unused_parameters"] = False
        self._train_args = TrainingArguments(**train_args)
        if is_causal:
            data_collator = DataCollatorForLanguageModeling(self._tokenizer, mlm=False)
        else:
            data_collator = DataCollatorForSeq2Seq(self._tokenizer, self._model)
        self._trainer = Trainer(model=self._model,
                                tokenizer=self._tokenizer,
                                args=self._train_args,
                                data_collator=data_collator,
                                train_dataset=self._tokenized_dataset,
                                callbacks=[self._eval_callback, SaveModelCallback(self)])

    def train_and_eval(self, only_run_eval=False, specific_epochs_to_eval=None, splits_to_eval=None, continue_rl_training=None, epoch_to_start_training=None, reward_fn=None, not_run_gpt_evaluation=False, just_eval=False, should_save=False, run_curriculum_learning=False, run_supervised_training_in_between=False, args=None, iters_to_run=None, openai_url=None):
        """Train with the HF trainer and/or evaluate saved per-epoch checkpoints.

        Args:
            only_run_eval: When false, (optionally) train, save the model and
                evaluate the checkpoints found under the tracker's run path.
                When true, only evaluate checkpoints derived from the
                configured ``model_name``.
            specific_epochs_to_eval: Epochs considered in eval-only mode;
                defaults to ``range(num_train_epochs)``.
            iters_to_run: Number of epochs scanned for checkpoints after
                training; defaults to ``num_train_epochs``.
            args: Required namespace providing ``not_run_bfloat16`` and
                ``just_eval``.
            The remaining parameters are not used by this method.
        """
        if not only_run_eval:

            # train using HF trainer
            print("Training the model now!")
            sys.stdout.flush()
            # NOTE(review): training only runs when not_run_bfloat16 is set
            # and args.just_eval is unset; the `just_eval` parameter of this
            # method is not consulted - confirm intent.
            if args.not_run_bfloat16 and not args.just_eval:
                self._trainer.train()
            print("Done training the model")
            sys.stdout.flush()

            if self._tracker is not None:
                self._tracker.save_auto_model(
                    self._model)

            new_model_name = os.path.join(self._tracker._run_path, "model")

            if iters_to_run is not None:
                iterations_to_run = iters_to_run
            else:
                iterations_to_run = self._train_args.num_train_epochs
            # num_train_epochs may be a float, which range() rejects.
            for given_epoch in range(int(iterations_to_run)):

                # we don't use this, but it helps to eval and debug the model:
                # multiples of 20 up to 100 (including 0), then multiples of 50
                should_eval_now = (
                    (given_epoch <= 100 and given_epoch % 20 == 0)
                    or (given_epoch > 0 and given_epoch % 50 == 0)
                )
                if not should_eval_now:
                    continue

                checkpoint_path = new_model_name + str('_') + str(float(given_epoch))
                if not os.path.exists(checkpoint_path):
                    print("Path doesn't exist " + str(checkpoint_path))
                    continue
                print("After training is completed, Loading model from " + str(checkpoint_path))

                self._model = self._load_checkpoint(checkpoint_path, args.not_run_bfloat16)
                self._evaluate_on_datapools(epoch=given_epoch)

        else:
            if specific_epochs_to_eval is None:
                # num_train_epochs may be a float, which range() rejects.
                specific_epochs_to_eval = range(int(self._train_args.num_train_epochs))
            for given_epoch in specific_epochs_to_eval:

                if given_epoch == 0:
                    continue

                # multiples of 10 up to 100, then multiples of 50
                should_eval_now = (
                    (given_epoch <= 100 and given_epoch % 10 == 0)
                    or (given_epoch > 100 and given_epoch % 50 == 0)
                )
                if not should_eval_now:
                    continue

                checkpoint_path = self._alg_config["model_name"] + str('_') + str(float(given_epoch))
                print("Loading model from " + str(checkpoint_path))
                self._model = self._load_checkpoint(checkpoint_path, args.not_run_bfloat16)

                self._evaluate_on_datapools(epoch=given_epoch)
