from rl4lms.data_pools.text_generation_pool import Sample, TextGenPool
from rl4lms.envs.text_generation.observation import Observation
from rl4lms.envs.text_generation.reward import RewardFunction
from rl4lms.envs.text_generation.metric import BaseMetric, RougeMetric
from typing import Dict, Any, List
from transformers import AutoTokenizer
import os
from transformers import PreTrainedModel
import json
import sys 
sys.path.append('..')
import llm_prompts
import llm_reward_fns
import numpy as np
import torch
import spacy
import flair, torch

# additional metrics to regularize the RL model
def rouge1_metric(pred: List[str], ref: List[List[str]], return_dict=False):
    """Score ``pred`` against ``ref`` with rl4lms' ``RougeMetric`` and return ROUGE-1.

    The value is the last element of the metric's ``lexical/rouge_rouge1`` entry
    (presumably the aggregate score over all predictions).

    Args:
        pred: generated texts.
        ref: reference texts, one list of references per prediction.
        return_dict: if True, wrap the score as ``{'rouge1': score}``.
    """
    metric_output = RougeMetric().compute(
        prompt_texts=[], generated_texts=pred, reference_texts=ref
    )
    rouge1 = metric_output["lexical/rouge_rouge1"][-1]
    return {'rouge1': rouge1} if return_dict else rouge1

def rouge_combined(pred: List[str], ref: List[List[str]]):
    """Return multi-reference ROUGE-1/2/L plus their mean.

    Args:
        pred: generated texts.
        ref: reference texts, one list of references per prediction.

    Returns:
        Dict with keys ``rouge1``, ``rouge2``, ``rougeL`` (last element of the
        corresponding ``RougeMetric`` entry) and ``rouge_combined`` (their mean).
    """
    rouge_keys = ["rouge1", "rouge2", "rougeL"]
    res = RougeMetric(use_single_ref=False).compute(
        prompt_texts=[], generated_texts=pred, reference_texts=ref
    )
    rouge_scores = [res["lexical/rouge_" + k][-1] for k in rouge_keys]
    scores = dict(zip(rouge_keys, rouge_scores))
    # Fall back to 0 only when there is nothing to average; the mean must not be
    # overwritten afterwards.
    scores["rouge_combined"] = np.mean(rouge_scores) if rouge_scores else 0
    return scores



class CommunityAnalysisDataset(TextGenPool):
    """Custom data pool that loads the (Reddit / Twitter) community-analysis JSONL files.

    Each JSONL record is read for ``text`` (model input), ``gold`` (reference focus
    area), optionally ``entity``, and -- when ``eval_gold_comms`` is set --
    ``comm_1`` / ``comm_2`` (gold community member lists).
    """

    @classmethod
    def prepare(cls, split: str, stage: str, prompt_prefix: str = "", truncate_article: int = None, max_size: int = None, data_samples=None, eval_gold_comms=False, path_to_save_dataset=None, path_to_save_test_dataset=None, samples_to_use=None, path_to_save_twitter_dataset=None, test_samples_to_use=None, train_on_test_set=False):
        """Load ``<dataset path><split>.jsonl`` and wrap each record in a ``Sample``.

        The input is the community description prompt given to the LLM; the reference
        is the gold focus area.

        - ``split='test'`` reads from ``path_to_save_test_dataset``.
        - ``split='twitter'`` reads ``<path_to_save_twitter_dataset>test.jsonl`` and
          labels the samples as ``test``.
        - ``samples_to_use`` caps the number of samples for every split, and
          ``test_samples_to_use`` additionally caps the test split. These caps help
          with quick reward-function checks.
        - ``stage``, ``prompt_prefix``, ``truncate_article``, ``max_size``,
          ``data_samples`` and ``train_on_test_set`` are accepted for config
          compatibility but are not used here.
        """
        print("Current split is " + str(split))
        if split == 'test':
            # test set lives in a different file
            path_to_save_dataset = path_to_save_test_dataset
        if split == 'twitter':
            # twitter data is only used for testing
            path_to_save_dataset = path_to_save_twitter_dataset
            split = 'test'
            print("Loading data from twitter " + str(path_to_save_dataset + str(split.replace("'", "")) + '.jsonl'))
        pth = path_to_save_dataset + str(split.replace("'", "")) + '.jsonl'
        with open(pth, "r", encoding="utf-8") as f:
            # skip blank lines (e.g. a stray empty line) instead of failing in json.loads
            data = [json.loads(line) for line in f if line.strip()]

        print("data length is " + str(len(data)))

        samples = []
        for ix, item in enumerate(data):
            # The caps are upper bounds: stop once exactly that many samples are collected.
            if samples_to_use is not None and len(samples) >= samples_to_use:
                break
            if test_samples_to_use is not None and split == 'test' and len(samples) >= test_samples_to_use:
                break

            gold = item["gold"]
            # Only the first gold reference of a list is kept (stringified); a scalar is wrapped as-is.
            references = [str(gold[0])] if isinstance(gold, list) else [gold]

            sample = Sample(
                id=f"{split}_{ix}",
                prompt_or_input_text=str(item["text"]),
                references=references,
            )
            # Stored as a string, so a missing entity becomes "None".
            sample.entity = str(item.get("entity"))

            if eval_gold_comms:
                # gold communities are available, keep them (whitespace-trimmed) on the sample
                sample.comm_1 = [x.strip() for x in item['comm_1']]
                sample.comm_2 = [x.strip() for x in item['comm_2']]

            samples.append(sample)
        print("We have these many samples " + str(len(samples)) + " in split " + str(split))
        sys.stdout.flush()
        return cls(samples)
   
# rl4lms_registry.DataPoolRegistry.add("community_analysis_dataset", CommunityAnalysisDataset)

# Downstream reference-based metrics, looked up by name in CommunityAnalysisMetric.
metric_map = dict(
    rouge1=rouge1_metric,
    rouge_combined=rouge_combined,
)

def get_gpu_memory_usage():
    """Print memory statistics for every visible CUDA device (training debug helper).

    For each device, print its total memory and the free space inside the memory
    reserved by PyTorch (reserved minus allocated). Nothing is printed when
    ``torch.cuda.device_count()`` is 0. Returns None.
    """
    for idx in range(torch.cuda.device_count()):
        dev = torch.device(f"cuda:{idx}")
        total_bytes = torch.cuda.get_device_properties(device=dev).total_memory
        free_in_reserved = torch.cuda.memory_reserved(device=dev) - torch.cuda.memory_allocated(device=dev)
        print(f"Out of total memory {total_bytes} we have this much free {free_in_reserved} on device {idx}")


class CommunityAnalysisMetric(BaseMetric):
    """Metric, also used as the RL reward, for generated community "focus areas".

    ``compute`` does the following:

    1. Prepends each generated focus area to the community-detection question.
    2. Asks an LLM (``llm_prompts``) to split the users into two communities.
    3. Scores those communities with ``llm_reward_fns``.
    4. Scores the focus areas with ROUGE against the references.
    5. Combines the pieces into ``custom_metrics/community_metric_analysis_reward``.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Store the configuration passed through ``kwargs`` and load the spaCy NER model.

        Required kwargs: ``prompt_path``, ``separator``, ``cache_path``, ``save_path``
        and ``gpt3_model_name``.

        Optional kwargs, with defaults:

        - ``lambda_rouge_input`` (0.3)
        - ``rouge_weight`` (0.4)
        - ``penalize_doing_worse`` (False)
        - ``not_run_gpt_evaluation`` (False)
        - ``reward_fn`` ("only_purity")
        - ``openai_url`` (None)
        """
        super().__init__()
        self.prompt = kwargs["prompt_path"]
        self.separator = kwargs["separator"]
        self.cache_path = kwargs["cache_path"]
        self.save_path = kwargs["save_path"]
        self.lambda_rouge_input = kwargs.get("lambda_rouge_input", 0.3)
        self.model_name = kwargs["gpt3_model_name"]

        # ROUGE-1 against the reference focus areas is the downstream metric.
        self.downstream_metric_name = "rouge1"
        self.downstream_metric = metric_map[self.downstream_metric_name]
        # How much importance the ROUGE term gets in the final reward. This is a tradeoff:
        # more weight here means less weight on the RL rewards, but too little lets the
        # policy drift into ungrammatical focus areas.
        self.rouge_weight = kwargs.get("rouge_weight", 0.4)
        self.penalize_doing_worse = kwargs.get("penalize_doing_worse", False)

        self.not_run_gpt_evaluation = kwargs.get("not_run_gpt_evaluation", False)

        # Reward-component switches; compute() re-derives frequency_based and
        # sentiment_based from self.reward_fn.
        self.frequency_based = False
        self.entity_sentiment_reward = False
        self.sentiment_based = False
        self.similarity_based = False

        self.reward_fn = kwargs.get("reward_fn", "only_purity")

        print("Initialized kwargs is " + str(kwargs))
        self.openai_url = kwargs.get("openai_url", None)
        self._sentence_sim_model = None
        print("in rl4lms_custom_functions self.reward_fn is " + str(self.reward_fn))

        # Keep flair on the CPU; the flair tagger itself is not loaded here (stays None).
        flair.device = torch.device('cpu')
        self.flair_tagger = None
        self.ner_tagger = spacy.load("en_core_web_sm")
        torch.cuda.empty_cache()
        self.entailment_tokenizer = None
        self.entailment_model = None

        # Load the LLM response cache from cache_path. Default to an empty cache so
        # the attribute always exists.
        self.GPT3_CACHE = {}
        if os.path.exists(self.cache_path):
            with open(self.cache_path, "r", encoding="utf-8") as f:
                self.GPT3_CACHE = json.load(f)

    def remove_prefix(self, text: str, prefixes=("Critique: ", "critique: ", "passage:")) -> str:
        """Strip the first matching prefix (checked in order) from ``text``.

        At most one prefix is removed. ``text`` is returned unchanged if none match.
        """
        for prefix in prefixes:
            if text.startswith(prefix):
                return text[len(prefix):]
        return text

    @staticmethod
    def _focus_text(reference) -> str:
        """Return a focus-area string from one entry of ``reference_texts``.

        Entries are normally lists of references (``List[List[str]]``). For a list,
        the first reference is used, or "" if the list is empty. A plain string is
        returned unchanged.
        """
        if isinstance(reference, str):
            return reference
        return str(reference[0]) if reference else ""

    def compute(
        self,
        prompt_texts: List[str],
        generated_texts: List[str],
        reference_texts: List[List[str]],
        meta_infos: List[Dict[str, Any]] = None,
        model: PreTrainedModel = None,
        split_name: str = None,
        epoch: int = None,
        gold_comm_1 = None,
        gold_comm_2 = None,
        entities_discussed = None,
        only_consider_purity=False,
        return_success_failure=False,
        run_curriculum_reward=False,
        evaluating_test=False,
        reward_fn=None,
        not_run_gpt_evaluation=False,
        running_evaluation=False,
        run_no_focus_area=False,
        run_gold_focus_area=False,
        curr_curriculum=None,
        curr_epoch=None,
        openai_url=None
    ):
        """Compute the community-analysis metrics, including the RL reward.

        Args:
            prompt_texts: Model inputs; task prefixes are stripped with ``remove_prefix``.
            generated_texts: Generated focus areas, one per prompt.
            reference_texts: Reference focus areas, one list per prompt.
            split_name, epoch: When ``save_path`` is non-empty and ``split_name`` is
                "test" or "val", the LLM communities are written to
                ``<save_path>/<split_name>community_metric_<epoch>.json``.
            gold_comm_1, gold_comm_2, entities_discussed: Forwarded to ``llm_reward_fns``.
            only_consider_purity: Ignored; it is re-derived from the reward function.
            return_success_failure: Also return the success/failure prediction dicts.
            run_curriculum_reward, curr_curriculum: Use the curriculum reward mix.
            evaluating_test: Compute the purity reward even if ``reward_fn`` lacks it.
            reward_fn: ';'-joined reward names; replaces ``self.reward_fn``.
            not_run_gpt_evaluation: If True, sets ``self.not_run_gpt_evaluation``. The
                flag stays set for later calls.
            run_no_focus_area, run_gold_focus_area: Prompt with no focus area, or with
                the reference focus area, instead of the generated one.
            openai_url: Replaces ``self.openai_url`` when given.
            meta_infos, model, curr_epoch: Unused.

        Returns:
            A dict mapping ``custom_metrics/community_metric_<key>`` to ``(None, value)``.
            The reward is under the key ``..._analysis_reward``. When
            ``return_success_failure`` is True, returns
            ``(metric_dict, success_predictions_dict, failure_predictions_dict)``.
        """
        if not_run_gpt_evaluation:
            self.not_run_gpt_evaluation = True

        if reward_fn is not None:
            self.reward_fn = reward_fn

        if openai_url is not None:
            print("We are passed in openai_url " + str(openai_url))
            self.openai_url = openai_url
        success_predictions_dict, failure_predictions_dict = {}, {}

        # NOTE(review): not read anywhere in this class; kept in case outside code reads it.
        self.aditional_reward_weight = 0.7
        individual_reward_weights = None
        if run_curriculum_reward or curr_curriculum is not None:
            # Curriculum-learning setup (the paper's best model used curriculum learning).
            self.reward_fn = 'all_purity;adjust_by_complexity;long_sentence;frequency_based'
            # Without a stage (curr_curriculum is None) the rewards stay equally weighted.
            if curr_curriculum is not None:
                if curr_curriculum == 0:
                    individual_reward_weights = {'all_purity': 0.25, 'adjust_by_complexity': 0.25, 'long_sentence': 0.25, 'frequency_based': 0.25}
                elif curr_curriculum == 1:
                    individual_reward_weights = {'all_purity': 0.1, 'adjust_by_complexity': 0.3, 'long_sentence': 0.4, 'frequency_based': 0.2}
                elif curr_curriculum == 2:
                    individual_reward_weights = {'all_purity': 0.1, 'adjust_by_complexity': 0.2, 'long_sentence': 0.3, 'frequency_based': 0.4}
                elif curr_curriculum >= 4:
                    # Checked before ">= 3" so this stage is reachable: also add model difference.
                    individual_reward_weights = {'all_purity': 0.3, 'adjust_by_complexity': 0.3, 'long_sentence': 0.4}
                    self.reward_fn = 'all_purity;adjust_by_complexity;long_sentence;frequency_based;model_difference'
                elif curr_curriculum >= 3:
                    individual_reward_weights = {'all_purity': 0.1, 'adjust_by_complexity': 0.3, 'long_sentence': 0.6}

        only_consider_purity = 'only_purity' in self.reward_fn
        only_consider_all_purity = 'all_purity' in self.reward_fn

        if 'model_difference' in self.reward_fn:
            # this can't be done without purity; use all-purity for the comparison
            self.penalize_doing_worse = True
            only_consider_all_purity = True
        else:
            self.penalize_doing_worse = False

        # entity frequency reward
        self.frequency_based = 'frequency_based' in self.reward_fn
        self.sentiment_based = 'sentiment_based' in self.reward_fn
        # informativeness reward
        self.adjust_by_complexity = 'complexity' in self.reward_fn
        # length reward
        self.long_sentence = 'long_sentence' in self.reward_fn

        # Strip off task prefix
        inputs = [self.remove_prefix(prompt) for prompt in prompt_texts]

        # The question asked to the LLM so that it forms the communities.
        llm_question_text = "Which users have the same perspective? Build two communities, Community 1 and Community 2. Community 1 should be the largest group of users possible that all have the same perspective. Community 2 are all the other users, which may not have the same perspective. First output 'Community 1', then it's usernames, then output 'Community 2', and then it's usernames. Make sure in your output that each FULL USERNAME is on a new line, with no other text on that line."
        # Question line embedded in the raw inputs; removed because llm_question_text replaces it.
        perspective_question = 'Which users have the same perspective?\n'

        # Prepend prompt. zip() keeps the original truncation to the shorter list.
        if run_no_focus_area:
            # test the models when no focus areas are provided
            print("Running no focus area")
            input_wfeed = [
                llm_question_text + '\n' + input_text.replace(perspective_question, '')
                for input_text, _ in zip(inputs, generated_texts)
            ]
        elif run_gold_focus_area:
            print("Running gold focus area")
            input_wfeed = [
                llm_question_text
                + " They are focusing on: "
                + self._focus_text(refs) + '\n'
                + input_text.replace(perspective_question, '')
                for input_text, refs in zip(inputs, reference_texts)
            ]
        else:
            input_wfeed = [
                llm_question_text
                + " They are focusing on: "
                + feedback_pred + '\n'
                + input_text.replace(perspective_question, '')
                for input_text, feedback_pred in zip(inputs, generated_texts)
            ]

        # Query the LLM and get the communities
        if not self.not_run_gpt_evaluation:
            print("openai_url is " + str(self.openai_url))
            communities_gotten = llm_prompts.prompt_lm_determine_communities_given_input(input_wfeed, openai_url=self.openai_url)

        # Reward components to be combined into the final reward (ROUGE handled separately).
        reward_name_mapping = {}

        gpt_input = [input_text.replace(perspective_question, '') for input_text in inputs]
        if ('purity' in self.reward_fn or evaluating_test) and not self.not_run_gpt_evaluation:
            only_consider_purity = False
            only_consider_all_purity = True
            if return_success_failure:
                # ref_texts_to_evaluate are the (generated) focus areas
                scores, success_predictions_dict, failure_predictions_dict, num_users_predicted = llm_reward_fns.gpt_reward(gpt_output=communities_gotten, gpt_input=gpt_input, gold_comm_1=gold_comm_1, gold_comm_2=gold_comm_2, entities_discussed=entities_discussed, disable_tqdm=True, evaluate_bias_reward_accuracy=False, just_eval_purity=only_consider_purity, only_consider_all_purity=only_consider_all_purity, return_predictions_dict=True, input_wfeed=input_wfeed, ref_texts_to_evaluate=generated_texts, running_evaluation=running_evaluation)
            else:
                scores, num_users_predicted = llm_reward_fns.gpt_reward(gpt_output=communities_gotten, gpt_input=gpt_input, gold_comm_1=gold_comm_1, gold_comm_2=gold_comm_2, entities_discussed=entities_discussed, disable_tqdm=True, evaluate_bias_reward_accuracy=False, just_eval_purity=only_consider_purity, only_consider_all_purity=only_consider_all_purity, return_predictions_dict=False, running_evaluation=running_evaluation)

            reward_name_mapping['all_purity'] = scores['analysis_reward']
            # User-count statistics are diagnostics, not rewards. Report them as metrics
            # rather than averaging user counts into the (0..1-scaled) reward below.
            scores['num_users_predicted_average'] = np.mean(num_users_predicted)
            scores['num_users_predicted_std'] = np.std(num_users_predicted)
            scores['num_users_predicted_min'] = np.min(num_users_predicted)
            scores['num_users_predicted_max'] = np.max(num_users_predicted)
        else:
            scores = {}

        if self.downstream_metric_name == "rouge1":
            additional_scores = self.downstream_metric(generated_texts, reference_texts, return_dict=True)
        else:
            additional_scores = self.downstream_metric(generated_texts, reference_texts)
        scores.update(additional_scores)
        new_rouge_score = np.mean(list(additional_scores.values())) if additional_scores else 0
        scores['rouge_score'] = new_rouge_score
        reward_name_mapping['rouge'] = new_rouge_score

        if self.flair_tagger is None:
            flair.device = torch.device('cpu')
            # The spaCy model is loaded in __init__; only reload it if it was cleared,
            # instead of re-loading it on every call.
            if self.ner_tagger is None:
                self.ner_tagger = spacy.load("en_core_web_sm")
        torch.cuda.empty_cache()

        if self.frequency_based:
            # entity reward
            curr_entity_sentiment_reward, entity_match_score, entity_match_score_all_entities, total_similarity, reward_total_entity_only_to_return = llm_reward_fns.new_rewards(self.flair_tagger, input_wfeed, gold_comm_1=gold_comm_1, gold_comm_2=gold_comm_2, entities_discussed=entities_discussed, ref_texts_to_evaluate=generated_texts, frequency_based=True, ner_tagger=self.ner_tagger, evaluating_test=evaluating_test, curr_curriculum=curr_curriculum)
            torch.cuda.empty_cache()
            if curr_entity_sentiment_reward is not None:
                scores['sentiment_frequency_sentiment_similarity_reward'] = curr_entity_sentiment_reward
                reward_name_mapping['frequency_based'] = curr_entity_sentiment_reward
            scores['only_curr_entity_sentiment_reward'] = reward_total_entity_only_to_return
            if len(entity_match_score) > 0:
                scores['entity_match_score'] = np.mean(entity_match_score)
                scores['entity_match_score_all_entities'] = np.mean(entity_match_score_all_entities)
            else:
                scores['entity_match_score'] = 0
                scores['entity_match_score_all_entities'] = 0

        if self.adjust_by_complexity:
            # informativeness reward: adjust the reward by complexity
            curr_entity_sentiment_reward, entity_match_score, entity_match_score_all_entities, total_similarity, reward_total_entity_only_to_return = llm_reward_fns.new_rewards(self.flair_tagger, input_wfeed, gold_comm_1=gold_comm_1, gold_comm_2=gold_comm_2, entities_discussed=entities_discussed, ref_texts_to_evaluate=generated_texts, frequency_based=False, ner_tagger=None, entailment_tokenizer=None, entailment_model=None, evaluating_test=evaluating_test, complexity_based=True)
            scores['complexity_reward'] = curr_entity_sentiment_reward
            reward_name_mapping['adjust_by_complexity'] = curr_entity_sentiment_reward

        if self.long_sentence:
            # adjust the reward by how long the prediction is
            curr_entity_sentiment_reward, entity_match_score, entity_match_score_all_entities, total_similarity, reward_total_entity_only_to_return = llm_reward_fns.new_rewards(self.flair_tagger, input_wfeed, gold_comm_1=gold_comm_1, gold_comm_2=gold_comm_2, entities_discussed=entities_discussed, ref_texts_to_evaluate=generated_texts, frequency_based=False, ner_tagger=None, entailment_tokenizer=None, entailment_model=None, evaluating_test=evaluating_test, complexity_based=False, length_based=True)
            scores['length_reward'] = curr_entity_sentiment_reward
            reward_name_mapping['long_sentence'] = curr_entity_sentiment_reward

        # Combine the non-ROUGE rewards (ROUGE is mixed in at the end).
        if individual_reward_weights is None:
            all_rewards = [score for name, score in reward_name_mapping.items() if 'rouge' not in name]
            # weight each reward equally; 0.0 (not NaN) when no reward was computed
            all_rewards_score = np.mean(all_rewards) if all_rewards else 0.0
        else:
            # Weighted by the curriculum stage. A weighted reward that was not computed
            # raises KeyError.
            all_rewards_score = 0
            for given_reward_name, given_reward_weight in individual_reward_weights.items():
                if 'rouge' in given_reward_name:
                    continue
                all_rewards_score += reward_name_mapping[given_reward_name] * given_reward_weight

        # Final reward. The last value is always ROUGE with weight self.rouge_weight; the
        # other values split (1 - rouge_weight) equally.
        if 'analysis_reward' in scores:
            values = [all_rewards_score, scores['analysis_reward'], reward_name_mapping['rouge']]
        else:
            # no purity reward
            values = [all_rewards_score, reward_name_mapping['rouge']]
        weight_for_each_reward = (1 - self.rouge_weight) / (len(values) - 1)
        # One weight per value, so ROUGE always gets rouge_weight.
        weights = [weight_for_each_reward] * (len(values) - 1) + [self.rouge_weight]
        total_weight = sum(weights)
        normalized_weights = [wgt / total_weight for wgt in weights]
        weighted_sum = sum(val * wgt for val, wgt in zip(values, normalized_weights))
        # Cap the weighted sum at a maximum value of 1
        final_reward = min(weighted_sum, 1)
        scores['analysis_reward'] = final_reward

        if self.penalize_doing_worse:
            # Penalize the model if it does WORSE than prompting without a focus area;
            # the focus areas should be helpful.
            # This path queries the LLM even when not_run_gpt_evaluation is set.
            # The baseline prompt matches the run_no_focus_area prompt.
            input_wfeed_wo_focus_area = [
                llm_question_text + '\n' + input_text.replace(perspective_question, '')
                for input_text, _ in zip(inputs, generated_texts)
            ]
            communities_gotten_wo_focus_area = llm_prompts.prompt_lm_determine_communities_given_input(input_wfeed_wo_focus_area, openai_url=self.openai_url)
            # The user inputs are the same with or without a focus area; only the LLM
            # prompt differs, so score against the same gpt_input.
            scores_wo_focus_area, _ = llm_reward_fns.gpt_reward(gpt_output=communities_gotten_wo_focus_area, gpt_input=gpt_input, gold_comm_1=gold_comm_1, gold_comm_2=gold_comm_2, entities_discussed=entities_discussed, disable_tqdm=True, evaluate_bias_reward_accuracy=False, just_eval_purity=only_consider_purity, only_consider_all_purity=only_consider_all_purity, return_predictions_dict=False, running_evaluation=running_evaluation)

            focus_area_purity_reward = scores['analysis_reward']
            # keep track of the original reward without changing it
            scores['original_reward'] = scores['analysis_reward']
            purity_without_focus_area = scores_wo_focus_area['analysis_reward']
            purity_difference = purity_without_focus_area - focus_area_purity_reward
            # TODO: tune the margin
            if purity_difference > 0.3:
                # Prompting without a focus area did better by a significant margin.
                # Penalize twice the gap, floored at -1.
                scores['analysis_reward'] = max(scores['analysis_reward'] - 2 * purity_difference, -1.0)
                scores['scores_wo_focus_area_purity'] = purity_without_focus_area

        if self.save_path != "" and split_name in ["test", "val"] and not self.not_run_gpt_evaluation:
            save_path = os.path.join(
                self.save_path, f"{split_name}community_metric_{epoch}.json"
            )
            with open(save_path, "w", encoding="utf-8") as f:
                json.dump(communities_gotten, f)

        metric_dict = {
            f"custom_metrics/community_metric_{k}": (None, score)
            for k, score in scores.items()
        }

        if return_success_failure:
            return metric_dict, success_predictions_dict, failure_predictions_dict

        return metric_dict

def build_tokenizer(tokenizer_config: Dict[str, Any]):
    """Create a Hugging Face tokenizer from ``tokenizer_config``.

    Config keys:

    - ``model_name`` (required): passed to ``AutoTokenizer.from_pretrained``.
    - ``pad_token_as_eos_token`` (default True): reuse the EOS token as the pad token
      when the tokenizer has no pad token.
    - ``padding_side`` (default "left").
    - ``truncation_side`` (default "left").
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_config["model_name"])
    if tok.pad_token is None:
        if tokenizer_config.get("pad_token_as_eos_token", True):
            tok.pad_token = tok.eos_token
    tok.padding_side = tokenizer_config.get("padding_side", "left")
    tok.truncation_side = tokenizer_config.get("truncation_side", "left")
    return tok

   
class CommunityRewardFunction(RewardFunction):
    """Episode-level RL reward computed with ``CommunityAnalysisMetric``.

    Expected kwargs:

    - ``tokenizer``: config for ``build_tokenizer``.
    - ``metric``: kwargs for ``CommunityAnalysisMetric``.
    - ``only_consider_purity`` (optional): defaults to True.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        self.tokenizer_config = kwargs["tokenizer"]
        self.tokenizer = build_tokenizer(self.tokenizer_config)
        self.metric = CommunityAnalysisMetric(**kwargs["metric"])
        self.only_consider_purity = kwargs.get("only_consider_purity", True)

    def __call__(self, prev_observation: Observation,
                 action: int,
                 current_observation: Observation,
                 done: bool,
                 meta_info=None,
                 curr_curriculum=None, curr_epoch=None) -> float:
        """Return the community-analysis reward when the episode is done, else 0.

        The generated focus area is ``current_observation.context_text``. It is
        scored by ``self.metric.compute`` against the observation's reference texts,
        gold communities and entity.
        """
        if not done:
            return 0

        prompt_or_input_text = prev_observation.prompt_or_input_text
        # the generated focus area
        feedback_pred = current_observation.context_text

        # Remove a leading "Critique: " prefix. str.lstrip would strip any of these
        # characters, not the prefix.
        critique_prefix = "Critique: "
        if prompt_or_input_text.startswith(critique_prefix):
            prompt_or_input_text = prompt_or_input_text[len(critique_prefix):]
        edit_gold = current_observation.target_or_reference_texts

        gold_comm_1 = current_observation.gold_comm_1
        gold_comm_2 = current_observation.gold_comm_2

        metric_dict = self.metric.compute(prompt_texts=[prompt_or_input_text], generated_texts=[feedback_pred], reference_texts=[edit_gold], gold_comm_1=[gold_comm_1], gold_comm_2=[gold_comm_2], entities_discussed=[current_observation.entity], only_consider_purity=self.only_consider_purity, curr_curriculum=curr_curriculum, curr_epoch=curr_epoch)
        reward = metric_dict["custom_metrics/community_metric_analysis_reward"][-1]

        sys.stdout.flush()

        return reward
   