import json
import os
import pickle
from collections import Counter
from dataclasses import asdict
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
from sklearn.metrics import recall_score
from torch.utils.data import DataLoader
from tqdm import tqdm

from common.loggable import Loggable
from common.utils import EoEDataset, EnsembleSample, TestDataset, ensemble_collate_fn, test_collate_fn, TestSample
from config.args import (
    BaseTrainingArguments, BaseDataArguments
)
from model.base import BasePolicy
from model.model import Classifier
from common.invariants import max_epi_len



class Trainer(Loggable):
    """Train a classifier on an ``EoEDataset`` and optionally evaluate it on a ``TestDataset``.

    On construction the run configuration is written to
    ``<output_dir>/<YYYY-MM-DD>/cfg/cfg_<output_filename>``, wandb is initialised
    through ``Loggable.init_wandb`` and the train/eval dataloaders are built.
    Checkpoints are saved under ``<output_dir>/<YYYY-MM-DD>/``.
    """

    def __init__(
        self,
        classifier: Union[BasePolicy, Classifier],
        train_args: BaseTrainingArguments,
        model_args: BaseTrainingArguments,
        data_args: BaseDataArguments,
        save_dir: Optional[Union[str, Path]] = None,
    ):
        """
        Args:
            classifier: model exposing ``update``, ``predict`` and ``save``.
            train_args: training hyper-parameters (batch sizes, epochs,
                logging / evaluation / save strategies, output location).
            model_args: model configuration; in this class it is only
                serialised into the saved config file.
            data_args: dataset paths and prompt format.
            save_dir: stored as ``self.save_dir``; not read by this class
                (checkpoints go under ``train_args.output_dir``).

        Raises:
            RuntimeError: if ``evaluation_strategy == "steps"`` and
                ``eval_steps`` is not a multiple of ``logging_steps``.
        """
        super().__init__(cfg=train_args.wandb_cfg)
        self.classifier = classifier
        self.train_args = train_args

        self.cfgs = None
        self.train_dataloader = None  # type: DataLoader
        self.eval_dataloader = None  # type: DataLoader

        # Step-wise evaluation only runs inside the logging branch of ``run``,
        # so eval_steps must be a multiple of logging_steps or it never fires.
        if self.train_args.evaluation_strategy == "steps":
            if (self.train_args.eval_steps % self.train_args.logging_steps) != 0:
                raise RuntimeError("Evaluation steps should be divied by the logging steps")

        # Flipped to True by build_dataset() when eval paths are configured.
        self.do_eval = False
        self.data_args = data_args
        self.model_args = model_args

        self.today = None
        self.start = None
        self.date_prefix_save_path = None
        self.n_update = 0

        self.required_total_update = None
        self.setup_learn()
        self.update_fn = self.update

        # One entry per evaluation call (the second value returned by evaluate()).
        self.weights = []

        self.save_dir = save_dir

    def dump_logs(self, step: int):
        """Record throughput / ETA statistics for ``step``, then flush via ``Loggable.dump_logs``.

        ``time/fps`` is the number of optimiser updates per wall-clock second
        since ``setup_learn`` started the clock.
        """
        now = datetime.now()
        # total_seconds(): ``timedelta.seconds`` drops whole days, which broke
        # fps / ETA for runs longer than 24 hours.
        elapsed = max(int((now - self.start).total_seconds()), 1)
        fps = step / elapsed
        # Clamp so overshooting the update estimate never yields a negative ETA.
        steps_left = max(self.required_total_update - step, 0)
        remain = int(steps_left / fps) if fps > 0 else 0
        eta = now + timedelta(seconds=remain)

        self.record({
            "info/suffix": self.train_args.output_filename,
            "time/fps": fps,
            "time/elapsed": str(timedelta(seconds=elapsed)),
            "time/remain": str(timedelta(seconds=remain)),
            "time/eta": eta.strftime("%m.%d / %H:%M:%S")
        })
        super().dump_logs(step=step)

    def build_dataset(self):
        """Build the train dataloader and, if eval paths are configured, the eval dataloader.

        Also sets ``required_total_update`` (used for ETA logging) and
        ``do_eval``.
        """
        train_dataset = EoEDataset(
            dataset_path=self.data_args.train_dataset_path,
            instruction_path=self.data_args.train_instruction_path,
            prompt_format=self.data_args.prompt_format,
            temporal_reward_path=self.data_args.temporal_reward_path,
            relational_reward_path=self.data_args.relational_reward_path,
            procedure_reward_path=self.data_args.procedure_reward_path,
            expert_reward_path=self.data_args.expert_reward_path
        )
        self.train_dataloader = DataLoader(
            dataset=train_dataset,
            batch_size=int(self.train_args.per_device_train_batch_size),
            collate_fn=ensemble_collate_fn,
            shuffle=True
        )

        # len(DataLoader) counts the final partial batch (drop_last is False);
        # ``len(dataset) // batch_size`` under-counted it and made the ETA go negative.
        n_updates_per_dataset = len(self.train_dataloader)
        self.required_total_update = self.train_args.num_train_epochs * n_updates_per_dataset

        if (self.data_args.eval_dataset_path is not None) and (self.data_args.eval_instruction_path is not None):
            eval_dataset = TestDataset(
                dataset_path=self.data_args.eval_dataset_path,
                instruction_path=self.data_args.eval_instruction_path,
                prompt_format=self.data_args.prompt_format,
                temporal_reward_path=self.data_args.temporal_reward_path,
                relational_reward_path=self.data_args.relational_reward_path,
                procedure_reward_path=self.data_args.procedure_reward_path,
                expert_reward_path=self.data_args.expert_reward_path,
                num_data_limit=self.data_args.num_eval_data_limit,
                for_eval=True
            )
            self.eval_dataloader = DataLoader(
                dataset=eval_dataset,
                batch_size=int(self.train_args.per_device_eval_batch_size),
                collate_fn=test_collate_fn,
                shuffle=True
            )
            self.do_eval = True

    def setup_learn(self):
        """Create the dated output directory, save the run config, init wandb and build the data.

        The config JSON (``data_args`` / ``model_args`` / ``train_args``) is
        written to ``<output_dir>/<YYYY-MM-DD>/cfg/cfg_<output_filename>``.
        """
        save_dir = self.train_args.output_dir
        save_filename = self.train_args.output_filename

        self.today = datetime.today()
        today_str = self.today.strftime('%Y-%m-%d')
        date_prefix = Path(save_dir) / today_str

        self.date_prefix_save_path = date_prefix

        cfg_prefix = date_prefix / "cfg"
        cfg_prefix.mkdir(parents=True, exist_ok=True)

        # Save configuration files
        with open(str(cfg_prefix / f"cfg_{save_filename}"), "w") as fp:
            cfgs = {
                "data_args": asdict(self.data_args),
                "model_args": asdict(self.model_args),
                "train_args": asdict(self.train_args)
            }
            json.dump(cfgs, fp)
            super().init_wandb(program_cfg=cfgs)

        # Reference point for the fps / ETA numbers in dump_logs().
        self.start = datetime.now()
        self.build_dataset()

    def update(self, data: EnsembleSample) -> Dict[str, Any]:
        """Run one classifier update on a training mini-batch and return its info dict."""
        info = self.classifier.update(
            prompts=data.prompts,
            rewards=data.rewards,
            dummys=data.dummys,
            succs=data.succs,
            reward_masks=data.reward_masks
        )
        return info

    def evaluate(self, n_update: Optional[int] = None) -> Tuple[Dict[str, Any], list]:
        """Predict on the whole eval set and compute rounded-prediction accuracy.

        Args:
            n_update: current update count; accepted for logging symmetry but
                not used in the computation.

        Returns:
            ``(eval_dict, weights)`` where ``eval_dict`` holds ``accuracy``
            (percentage of clipped, rounded predictions equal to the label),
            the flat ``predictions`` list and the ``labels`` list, and
            ``weights`` is a list (currently always empty — see note below).
        """
        n_correct = 0
        n_data = 0
        predictions = []
        labels = []
        # NOTE(review): the ``weight`` returned by classifier.predict is not
        # collected, so this list is always empty — confirm whether it should be.
        weights = []

        for data in self.eval_dataloader:
            data: TestSample

            prediction, weight = self.classifier.predict(
                prompts=data.prompts,
                rewards=data.rewards,
                deterministic=True,
            )  # (batch * max length, 1)
            prediction = prediction.reshape(-1)
            prediction = np.clip(prediction, -2, 2)

            batch_predictions = prediction.tolist()
            batch_labels = data.labels.tolist()
            predictions.extend(batch_predictions)
            labels.extend(batch_labels)

            # Score only this batch; scoring the cumulative lists on every
            # batch counted earlier samples many times over.
            pred_arr = np.round(np.asarray(batch_predictions)).reshape(-1)
            label_arr = np.asarray(batch_labels).reshape(-1)
            n_correct += int(np.sum(pred_arr == label_arr))
            n_data += label_arr.size

        eval_dict = {
            "accuracy": (n_correct / n_data) * 100 if n_data > 0 else 0,
            "predictions": predictions,
            "labels": labels,
        }

        return eval_dict, weights

    def run(self):
        """Main training loop: update, log, evaluate and checkpoint per the configured strategies."""
        for epoch in tqdm(range(self.train_args.num_train_epochs), desc="Epoch"):
            for data in tqdm(self.train_dataloader, desc="Mini Batch"):
                update_info = self.update_fn(data)
                self.n_update += 1
                self.record_from_dicts(update_info, mode="train")

                if (self.n_update % self.train_args.logging_steps) == 0:
                    if self.do_eval and self.train_args.evaluation_strategy == "steps":
                        if (self.n_update % self.train_args.eval_steps) == 0:
                            eval_info, weights = self.evaluate(self.n_update)
                            self.weights.append(weights)
                            self.record_from_dicts(eval_info, mode="eval")
                    self.dump_logs(step=self.n_update)

                if (self.train_args.save_strategy == "steps") and ((self.n_update % self.train_args.save_steps) == 0):
                    self.classifier.save(os.path.join(self.date_prefix_save_path, f"ckpt_step_{self.n_update}"))

            if self.do_eval and self.train_args.evaluation_strategy == "epoch":
                # evaluate() returns (eval_dict, weights); the old call passed no
                # argument and treated the tuple as a dict, so this path crashed.
                eval_info, weights = self.evaluate(self.n_update)
                self.weights.append(weights)
                self.record_from_dicts(eval_info, mode="eval")
                self.dump_logs(step=self.n_update)

            if self.train_args.save_strategy == "epoch":
                self.classifier.save(os.path.join(self.date_prefix_save_path, f"ckpt_epoch_{epoch}"))
