import json
from dataclasses import asdict
from datetime import datetime, timedelta
from pathlib import Path
from typing import Union, Dict, Any
from tqdm import tqdm

import os
import numpy as np
from torch.utils.data import DataLoader

from common.loggable import Loggable
from common.utils import EalfDataset, ealf_collate_fn, EalfDataSample
from config.args import (
    BaseDataArguments,
    BaseModelArguments,
    BaseTrainingArguments
)
from models.base import BasePolicy
from models.mm_student import SkillDecoder


class EalfTrainer(Loggable):
    def __init__(
        self,
        agent: Union[BasePolicy, SkillDecoder],
        data_args: BaseDataArguments,
        model_args: BaseModelArguments,
        train_args: BaseTrainingArguments,
    ):
        """Store the agent and argument bundles, validate them, and run setup.

        Args:
            agent: Object whose ``imitation_update`` / ``rl_update`` /
                ``irl_update`` / ``predict`` / ``save`` methods are called
                by this trainer.
            data_args: Dataset-related arguments (paths, eval data limit).
            model_args: Model-related arguments (multimodal prompt).
            train_args: Training arguments (batch sizes, strategies, steps,
                output location, wandb configuration, training mode).

        Raises:
            RuntimeError: If the evaluation strategy is ``"steps"`` and
                ``eval_steps`` is not a multiple of ``logging_steps``.
            NotImplementedError: If ``train_args.training_mode`` is not one
                of ``"imitation"``, ``"rl"`` or ``"irl"``.
        """
        super().__init__(cfg=train_args.wandb_cfg)
        self.agent = agent
        self.data_args = data_args
        self.model_args = model_args
        self.train_args = train_args

        self.cfgs = None  # type: Dict[str, Union[BaseDataArguments, BaseModelArguments, BaseTrainingArguments]]
        self.train_dataloader = None  # type: DataLoader
        self.eval_dataloader = None  # type: DataLoader

        # Evaluation is only triggered from inside the logging branch of
        # `run`, so eval_steps must be a multiple of logging_steps.
        if self.train_args.evaluation_strategy == "steps":
            if (self.train_args.eval_steps % self.train_args.logging_steps) != 0:
                raise RuntimeError("Evaluation steps should be divied by the logging steps")

        # Resolve the update function BEFORE `setup_learn`, so an invalid
        # training mode fails fast instead of after directories were
        # created, the config was written, wandb was initialised and the
        # datasets were built.
        if self.train_args.training_mode == "imitation":
            self.update_fn = self.imitation_update
        elif self.train_args.training_mode == "rl":
            self.update_fn = self.rl_update
        elif self.train_args.training_mode == "irl":
            self.update_fn = self.irl_update
        else:
            raise NotImplementedError(f"Undefined training mode: {self.train_args.training_mode}.")

        # Set to True by `build_dataset` when an eval dataset is configured.
        self.do_eval = False

        self.today = None
        self.start = None
        self.date_prefix_save_path = None
        self.n_update = 0

        self.required_total_update = None
        self.setup_learn()

    def dump_logs(self, step: int):
        """Record timing statistics and flush the logs for ``step``.

        Throughput is computed as ``step`` divided by the wall-clock seconds
        since ``self.start``; the remaining time and ETA are extrapolated
        from ``self.required_total_update``.

        Args:
            step: Number of updates performed so far.
        """
        now = datetime.now()
        # `timedelta.seconds` is only the seconds *component* (0..86399) and
        # wraps every 24 hours; `total_seconds()` is the real elapsed time.
        elapsed = max(int((now - self.start).total_seconds()), 1)
        fps = step / elapsed

        # Clamp at zero: more updates than estimated must not produce a
        # negative remaining time / ETA in the past.
        remaining_updates = max(self.required_total_update - step, 0)
        # Guard against step == 0, which would make fps zero.
        remain = int(remaining_updates / fps) if fps > 0 else 0
        eta = now + timedelta(seconds=remain)

        self.record({
            "info/suffix": self.train_args.output_filename,
            "time/fps": fps,
            "time/elapsed": str(timedelta(seconds=elapsed)),
            "time/remain": str(timedelta(seconds=remain)),
            "time/eta": eta.strftime("%m.%d / %H:%M:%S")
        })
        super().dump_logs(step=step)

    def build_dataset(self):
        """Build the training dataloader and, if configured, the eval one.

        Sets ``self.train_dataloader`` and ``self.required_total_update``.
        When both ``eval_dataset_path`` and ``eval_instruction_path`` are
        set on ``self.data_args``, also sets ``self.eval_dataloader`` and
        turns ``self.do_eval`` on.
        """
        train_dataset = EalfDataset(
            dataset_path=self.data_args.train_dataset_path,
            instruction_path=self.data_args.train_instruction_path,
            prompt=self.model_args.multimodal_prompt,
            reward_path=self.data_args.reward_path
        )
        train_dataloader = DataLoader(
            dataset=train_dataset,
            batch_size=int(self.train_args.per_device_train_batch_size),
            collate_fn=ealf_collate_fn,
            shuffle=True
        )
        self.train_dataloader = train_dataloader

        # The loader is built without `drop_last`, so it also yields the
        # final partial batch. `len(train_dataloader)` counts that batch,
        # whereas `len(dataset) // batch_size` would under-count the updates
        # per epoch and skew the ETA reported by `dump_logs`.
        n_updates_per_dataset = len(train_dataloader)
        self.required_total_update = self.train_args.num_train_epochs * n_updates_per_dataset

        if (self.data_args.eval_dataset_path is not None) and (self.data_args.eval_instruction_path is not None):
            eval_dataset = EalfDataset(
                dataset_path=self.data_args.eval_dataset_path,
                instruction_path=self.data_args.eval_instruction_path,
                prompt=self.model_args.multimodal_prompt,
                num_data_limit=self.data_args.num_eval_data_limit,
                for_eval=True
            )
            # NOTE(review): the eval loader is shuffled, so the order of the
            # predictions/labels returned by `evaluate` differs between
            # calls. Confirm whether this is intentional.
            eval_dataloader = DataLoader(
                dataset=eval_dataset,
                batch_size=int(self.train_args.per_device_eval_batch_size),
                collate_fn=ealf_collate_fn,
                shuffle=True
            )

            self.eval_dataloader = eval_dataloader
            self.do_eval = True

    def setup_learn(self):
        """Prepare the output directory, persist the config, and build data.

        Creates ``<output_dir>/<YYYY-MM-DD>/cfg``, writes the data / model /
        training arguments there as JSON, initialises wandb with the same
        configuration, records the start time and builds the dataloaders.
        """
        save_dir = self.train_args.output_dir
        save_filename = self.train_args.output_filename

        self.today = datetime.today()
        date_prefix = Path(save_dir) / self.today.strftime('%Y-%m-%d')
        self.date_prefix_save_path = date_prefix

        cfg_prefix = date_prefix / "cfg"
        cfg_prefix.mkdir(parents=True, exist_ok=True)

        # Build the config before opening the file, so a failure in `asdict`
        # cannot leave an empty, truncated config file behind.
        cfgs = {
            "data_args": asdict(self.data_args),
            "model_args": asdict(self.model_args),
            "train_args": asdict(self.train_args)
        }

        # Save configuration files. The file is closed (and flushed) before
        # wandb is initialised, so the handle is not held during that call.
        with open(str(cfg_prefix / f"cfg_{save_filename}"), "w") as fp:
            json.dump(cfgs, fp)
        super().init_wandb(program_cfg=cfgs)

        self.start = datetime.now()
        self.build_dataset()

    def imitation_update(self, data: EalfDataSample) -> Dict[str, Any]:
        """Run one imitation update of the agent on a batch.

        Forwards the batch's prompts and captions, with its actions as the
        label, to ``self.agent.imitation_update`` and returns its result.
        """
        batch_kwargs = {
            "prompts": data.prompts,
            "captions": data.captions,
            "label": data.actions,
        }
        return self.agent.imitation_update(**batch_kwargs)


    def rl_update(self, data: EalfDataSample) -> Dict[str, Any]:
        """Run one RL update of the agent on a batch.

        Forwards the current/next prompts and captions, the actions, rewards
        and done flags of the batch to ``self.agent.rl_update`` with
        ``deterministic=False`` and returns its result.
        """
        batch_kwargs = {
            "prompts": data.prompts,
            "next_prompts": data.next_prompts,
            "captions": data.captions,
            "next_captions": data.next_captions,
            "actions": data.actions,
            "rewards": data.rewards,
            "dones": data.dones,
        }
        return self.agent.rl_update(deterministic=False, **batch_kwargs)


    def irl_update(self, data: EalfDataSample) -> Dict[str, Any]:
        """Run one IRL update of the agent on a batch.

        Forwards the prompts, current/next image observations, actions and
        done flags of the batch to ``self.agent.irl_update`` with
        ``deterministic=False`` and returns its result.
        """
        batch_kwargs = {
            "prompts": data.prompts,
            "image_observations": data.image_observations,
            "actions": data.actions,
            "next_image_observations": data.next_image_observations,
            "dones": data.dones,
        }
        return self.agent.irl_update(deterministic=False, **batch_kwargs)

    

    def evaluate(self, n_update: Union[int, None] = None) -> Dict[str, Any]:
        """Evaluate the agent on ``self.eval_dataloader``.

        Args:
            n_update: Current update count. Not used by this method; it is
                optional so that callers that omit it do not raise
                ``TypeError``.

        Returns:
            A dict with:
              * ``"accuracy"``: percentage of predictions equal to the
                batch actions (``0.0`` when the loader yields no samples),
              * ``"predictions"`` / ``"labels"``: flat lists collected over
                all batches,
              * ``"success_rate"``, ``"cgc"``, ``"plan"``: always ``0``
                here; they are never updated by this method.
        """
        predictions = []
        labels = []
        matched = 0
        n_data = 0

        for data in self.eval_dataloader:
            prediction = self.agent.predict(
                prompts=data.prompts,
                captions=data.captions,
                deterministic=True,
            )

            predictions.extend(prediction.tolist())
            labels.extend(data.actions.tolist())
            matched += np.sum(prediction == data.actions)
            n_data += len(prediction)

        # An empty eval loader would otherwise divide by zero.
        accuracy = (matched / n_data) * 100 if n_data > 0 else 0.0

        eval_dict = {
            "accuracy": accuracy,
            "predictions": predictions,
            "labels": labels,
            # Placeholder metrics: reported as constant zeros.
            "success_rate": 0,
            "cgc": 0,
            "plan": 0
        }

        return eval_dict

    def run(self, min_save_step: int = 130000):
        """Run the training loop for ``train_args.num_train_epochs`` epochs.

        For every mini batch the configured update function is applied and
        its info recorded. Every ``logging_steps`` updates the logs are
        dumped; with the ``"steps"`` evaluation strategy the agent is also
        evaluated every ``eval_steps`` updates and saved as ``best``
        whenever the accuracy improves. Checkpoints are written according
        to ``train_args.save_strategy`` (``"steps"`` or ``"epoch"``).

        Args:
            min_save_step: With the ``"steps"`` save strategy, step
                checkpoints are only written once ``self.n_update`` has
                reached this value. Defaults to the previously hard-coded
                130000.
        """
        best = 0
        for epoch in tqdm(range(self.train_args.num_train_epochs), desc="Epoch"):
            for data in tqdm(self.train_dataloader, desc="Mini Batch"):
                update_info = self.update_fn(data)
                self.n_update += 1
                self.record_from_dicts(update_info, mode="train")

                if (self.n_update % self.train_args.logging_steps) == 0:
                    if self.do_eval and self.train_args.evaluation_strategy == "steps":
                        if (self.n_update % self.train_args.eval_steps) == 0:
                            eval_info = self.evaluate(self.n_update)
                            self.record_from_dicts(eval_info, mode="eval")
                            # Keep the checkpoint with the best accuracy.
                            if best < eval_info["accuracy"]:
                                best = eval_info["accuracy"]
                                self.agent.save(os.path.join(self.date_prefix_save_path, "best"))

                    self.dump_logs(step=self.n_update)

                is_save_step = (
                    (self.train_args.save_strategy == "steps")
                    and ((self.n_update % self.train_args.save_steps) == 0)
                )
                if is_save_step and self.n_update >= min_save_step:
                    self.agent.save(os.path.join(self.date_prefix_save_path, f"ckpt_step_{self.n_update}"))

            if self.do_eval and self.train_args.evaluation_strategy == "epoch":
                # `evaluate` declares an update-count parameter; calling it
                # without one raised TypeError.
                eval_info = self.evaluate(self.n_update)
                self.record_from_dicts(eval_info, mode="eval")
                self.dump_logs(step=self.n_update)

            if self.train_args.save_strategy == "epoch":
                self.agent.save(os.path.join(self.date_prefix_save_path, f"ckpt_epoch_{epoch}"))
