from typing import Any, Dict, List

from stable_baselines3.common.policies import BasePolicy
from tqdm import tqdm
from transformers import AutoTokenizer

from rl4lms.data_pools.custom_text_generation_pools import Sample
from rl4lms.envs.text_generation.logging_utils import Tracker
from rl4lms.envs.text_generation.metric import BaseMetric
import sys
import rl4lms_custom_functions
sys.path.append('../../')
sys.path.append('../')
import llm_prompts 
import jsonlines

def get_batch(samples: List[Any], batch_size: int):
    """Yield consecutive slices of ``samples`` containing ``batch_size`` elements.

    The final slice is shorter when ``len(samples)`` is not a multiple of
    ``batch_size``. Elements are yielded untouched, so this works both for plain
    ``Sample`` lists and for the train split, whose elements wrap the ``Sample``
    (they are unwrapped later with ``x[0]``).

    Args:
        samples: a sliceable sequence; each batch is ``samples[i:i + batch_size]``.
        batch_size: maximum number of elements per batch; must be positive.

    Raises:
        ValueError: if ``batch_size`` is not positive (raised when iteration
            starts, since this is a generator).
    """
    # Without this check a zero or negative batch size never advances past the
    # end of ``samples`` and the generator loops forever.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    for start in range(0, len(samples), batch_size):
        yield samples[start : start + batch_size]

def _merge_with_previous_predictions(
    success_predictions_dict,
    failure_predictions_dict,
    prev_success_dict,
    prev_failure_dict,
):
    """Re-classify the current predictions relative to a previous evaluation.

    Inputs not present in the previous predictions keep the metric's own
    success/failure classification. For inputs present in both, the last tuple
    element (the purity) is compared as a float:
    - a lower current purity is a failure, and the previous tuple is kept;
    - a higher current purity is a success;
    - an equal purity goes to the no-change dict.

    As a side effect, the history lists stored inside the tuples are extended in
    place. The previous purity is appended to the current tuple's ``[-2]`` list,
    and the previous focus area (element ``[2]``) to the previous tuple's
    ``[-3]`` list.

    Args:
        success_predictions_dict: current successes, keyed by input.
        failure_predictions_dict: current failures, keyed by input.
        prev_success_dict: previous successes, keyed by input.
        prev_failure_dict: previous failures, keyed by input; ``None`` is
            treated as empty.

    Returns:
        Tuple ``(success_predictions_dict, failure_predictions_dict,
        no_change_predictions_dict)`` of newly built dicts.
    """
    prev_pred_dict = prev_success_dict.copy()
    prev_pred_dict.update(prev_failure_dict or {})
    curr_pred_dict = success_predictions_dict.copy()
    curr_pred_dict.update(failure_predictions_dict)

    new_success, new_failure, no_change = {}, {}, {}
    for given_input, predictions_tuple in curr_pred_dict.items():
        if given_input not in prev_pred_dict:
            # First time this input is seen: keep the metric's classification.
            if given_input in success_predictions_dict:
                new_success[given_input] = success_predictions_dict[given_input]
            if given_input in failure_predictions_dict:
                new_failure[given_input] = failure_predictions_dict[given_input]
            continue

        prev_tuple = prev_pred_dict[given_input]
        prev_purity = float(prev_tuple[-1])
        # Compare as floats in every branch (previously only the ">" branch did).
        curr_purity = float(predictions_tuple[-1])

        # Store the old purity in the running purity list of the current tuple.
        all_prev_purity_list = predictions_tuple[-2]
        all_prev_purity_list.append(prev_purity)
        # Append the old focus area to the running list of focus areas.
        all_prev_focus_area_list = prev_tuple[-3]
        all_prev_focus_area_list.append(str(prev_tuple[2]))
        # Rebuild the current tuple with both updated history lists.
        new_predictions_tuple = (
            predictions_tuple[:-3]
            + (all_prev_focus_area_list, all_prev_purity_list)
            + predictions_tuple[-1:]
        )

        if prev_purity > curr_purity:
            # The purity was higher before: failure; keep the previous tuple.
            new_failure[given_input] = prev_tuple
        elif prev_purity < curr_purity:
            # The purity is higher now: success.
            new_success[given_input] = new_predictions_tuple
        elif prev_purity == curr_purity:
            # The purity is unchanged.
            no_change[given_input] = new_predictions_tuple
        # A NaN purity matches no branch, so the input is dropped.
    return new_success, new_failure, no_change


def evaluate_on_samples(
    policy: BasePolicy,
    tokenizer: AutoTokenizer,
    samples: List[Sample],
    batch_size: int,
    max_prompt_length: int,
    metrics: List[BaseMetric],
    epoch: int,
    split_name: str,
    tracker: Tracker = None,
    dt_control_token: str = "",
    gen_kwargs: Dict[str, Any] = None,
    samples_to_eval=None,
    return_predictions_dict=False,
    prev_success_dict=None,
    prev_failure_dict=None,
    reward_fn=None,
    not_run_gpt_evaluation=False,
    return_reward=False,
    running_evaluation=False,
    run_no_focus_area=False,
    run_gold_focus_area=False,
    run_curriculum_learning=False,
    curr_curriculum=0,
    openai_url=None,
):
    """Generate text for ``samples``, compute ``metrics`` on it and log the results.

    Args:
        policy: generation policy; ``policy.get_language_model()`` is passed to
            every metric.
        tokenizer: tokenizer forwarded to ``policy.generate``.
        samples: samples to evaluate. For the ``'train'`` split each element is
            indexable and its first item is the ``Sample``.
        batch_size: number of samples per generation call.
        max_prompt_length: forwarded to ``policy.generate``.
        metrics: metrics to compute; ``None`` skips metric computation.
        epoch: epoch number used for logging.
        split_name: split being evaluated (``'train'`` samples are unwrapped).
        tracker: optional logger; nothing is logged when it is ``None``.
        dt_control_token: string prepended to every prompt.
        gen_kwargs: generation kwargs forwarded to ``policy.generate``.
        samples_to_eval: if set, generation stops once at least this many samples
            have been processed. Whole batches are always processed.
        return_predictions_dict: return the success/failure prediction dicts.
        prev_success_dict: predictions from an earlier evaluation. When given,
            the current predictions are re-classified relative to them.
        prev_failure_dict: earlier failures, used together with
            ``prev_success_dict``.
        reward_fn: forwarded to ``CommunityAnalysisMetric.compute``, as are
            ``not_run_gpt_evaluation``, ``running_evaluation``,
            ``run_no_focus_area``, ``run_gold_focus_area``, ``curr_curriculum``
            and ``openai_url``. ``run_curriculum_learning`` is forwarded as
            ``run_curriculum_reward``.

    Returns:
        If ``return_reward``, the corpus-level purity reward. The process exits
        with status 1 if no purity metric was computed.
        Otherwise, if ``return_predictions_dict``, a
        ``(success_predictions_dict, failure_predictions_dict)`` tuple in which
        unchanged predictions are counted as failures.
        Otherwise ``None``.
    """
    # generate text by batch
    all_generated_texts = []
    all_ref_texts = []
    all_prompt_texts = []
    all_meta_infos = []
    all_comm_1, all_comm_2, all_entities_discussed = [], [], []
    samples_count = 0
    for batch in tqdm(list(get_batch(samples, batch_size)), desc="Evaluating"):
        if samples_to_eval is not None and samples_count >= samples_to_eval:
            # we evaluated enough samples; skip the remaining batches
            break
        samples_count += len(batch)
        if split_name == 'train':
            # train samples are wrapped; the Sample is the first element
            batch = [x[0] for x in batch]
        batch_generated_texts = generate_text(
            policy, tokenizer, batch, max_prompt_length, dt_control_token, gen_kwargs
        )
        all_generated_texts.extend(batch_generated_texts)
        all_ref_texts.extend(sample.references for sample in batch)
        all_prompt_texts.extend(sample.prompt_or_input_text for sample in batch)
        all_meta_infos.extend(sample.meta_data for sample in batch)
        all_comm_1.extend(sample.comm_1 for sample in batch)
        all_comm_2.extend(sample.comm_2 for sample in batch)
        all_entities_discussed.extend(sample.entity for sample in batch)

    print("Generated text in evaluation")
    sys.stdout.flush()

    # compute metrics
    corpus_level_metrics = {}
    sample_scores_by_metric = {}
    # Initialised up front so return_predictions_dict also works when no
    # CommunityAnalysisMetric is present (previously an UnboundLocalError).
    success_predictions_dict, failure_predictions_dict = {}, {}
    no_change_predictions_dict = None
    n_generated = len(all_generated_texts)
    if metrics is not None:
        for metric in metrics:
            if isinstance(metric, rl4lms_custom_functions.CommunityAnalysisMetric):
                print("Computing reward for CommunityAnalysisMetric and openai_url is " + str(openai_url))
                sys.stdout.flush()

                metric_dict, success_predictions_dict, failure_predictions_dict = metric.compute(
                    all_prompt_texts,
                    all_generated_texts,
                    all_ref_texts,
                    all_meta_infos,
                    policy.get_language_model(),
                    split_name,
                    gold_comm_1=all_comm_1,
                    gold_comm_2=all_comm_2,
                    entities_discussed=all_entities_discussed,
                    return_success_failure=True,
                    evaluating_test=True,
                    reward_fn=reward_fn,
                    not_run_gpt_evaluation=not_run_gpt_evaluation,
                    running_evaluation=running_evaluation,
                    run_no_focus_area=run_no_focus_area,
                    run_gold_focus_area=run_gold_focus_area,
                    run_curriculum_reward=run_curriculum_learning,
                    curr_curriculum=curr_curriculum,
                    openai_url=openai_url
                )

                no_change_predictions_dict = None
                if prev_success_dict is not None:
                    (
                        success_predictions_dict,
                        failure_predictions_dict,
                        no_change_predictions_dict,
                    ) = _merge_with_previous_predictions(
                        success_predictions_dict,
                        failure_predictions_dict,
                        prev_success_dict,
                        prev_failure_dict,
                    )

                # tracker defaults to None; guard like the final logging below
                if tracker is not None:
                    tracker.log_metrics(epoch, split_name, metric_dict, logging_reward=True, samples_logged=samples_count)
                    # log the success and failure predictions
                    tracker.log_success_failure_predictions(success_predictions_dict=success_predictions_dict, failure_predictions_dict=failure_predictions_dict, epoch_to_log=epoch, split_name=split_name, no_change_predictions_dict=no_change_predictions_dict)

            else:
                metric_dict = metric.compute(
                    all_prompt_texts,
                    all_generated_texts,
                    all_ref_texts,
                    all_meta_infos,
                    policy.get_language_model(),
                    split_name,
                )

            for metric_key, (sample_scores, corpus_score) in metric_dict.items():
                if sample_scores is None:
                    # one placeholder per generated sample (may be fewer than
                    # len(samples) when samples_to_eval limits generation)
                    sample_scores = ["n/a"] * n_generated
                corpus_level_metrics[metric_key] = corpus_score
                sample_scores_by_metric[metric_key] = sample_scores

    if split_name == 'train':
        samples = [x[0] for x in samples]

    # aggregate sample metric scores
    sample_predictions_dict = []
    for ix, (sample, prompt_text, generated_text, ref_texts) in enumerate(
        zip(samples, all_prompt_texts, all_generated_texts, all_ref_texts)
    ):
        sample_prediction = {
            "split_name": split_name,
            "sample_id": sample.id,
            "prompt_text": prompt_text,
            "generated_text": generated_text,
            "ref_text": "".join(
                f"<START-{ref_ix+1}>" + ref_text + f"<END-{ref_ix+1}>"
                for ref_ix, ref_text in enumerate(ref_texts)
            ),
        }
        for metric_key, sample_scores in sample_scores_by_metric.items():
            sample_prediction[metric_key] = sample_scores[ix]
        sample_predictions_dict.append(sample_prediction)

    if tracker is not None:
        # log the entire predictions
        tracker.log_predictions(epoch, split_name, sample_predictions_dict)
        # log the corpus level scores
        tracker.log_metrics(epoch, split_name, corpus_level_metrics)

    if return_reward:
        for reward_key in (
            'custom_metrics/community_metric_purity_rewards_all_match',
            'community_metric_purity_rewards_all_match',
        ):
            if reward_key in corpus_level_metrics:
                return corpus_level_metrics[reward_key]
        print("Problem, no metric to return! " + str(corpus_level_metrics))
        # sys.exit rather than the site-module ``exit`` builtin, which is not
        # guaranteed to exist (e.g. under ``python -S``).
        sys.exit(1)

    if return_predictions_dict:
        if no_change_predictions_dict is not None:
            # for next time, consider no_change_predictions_dict as failures
            failure_predictions_dict.update(no_change_predictions_dict)
        return success_predictions_dict, failure_predictions_dict


def generate_text(
    policy: BasePolicy,
    tokenizer: AutoTokenizer,
    samples: List[Sample],
    max_prompt_length: int,
    dt_control_token: str,
    gen_kwargs: Dict[str, Any],
):
    """Generate one continuation per sample with ``policy``.

    Each sample's ``prompt_or_input_text`` is prefixed with ``dt_control_token``.
    All resulting prompts are sent to ``policy.generate`` in a single call.

    Returns:
        The ``gen_texts`` attribute of the object returned by ``policy.generate``.
    """
    control_prefixed_prompts = [
        dt_control_token + item.prompt_or_input_text for item in samples
    ]
    generation_output = policy.generate(
        tokenizer,
        control_prefixed_prompts,
        max_prompt_length,
        gen_kwargs=gen_kwargs,
    )
    return generation_output.gen_texts
