def deregister_torch_ipc():
    '''Remove torch's IPC reducers from ``multiprocessing``'s ``ForkingPickler``.

    torch registers custom reducers on ``ForkingPickler`` so tensors, storages,
    ``torch.cuda.Event`` and ``Parameter`` objects are passed between processes
    via shared memory. Popping them makes those objects fall back to ordinary
    pickling. Entries that are not registered are ignored, so calling this more
    than once is safe.
    '''
    from multiprocessing.reduction import ForkingPickler
    import torch

    reducers = ForkingPickler._extra_reducers
    classes_to_remove = [
        torch.cuda.Event,
        *torch._storage_classes,
        *torch._tensor_classes,
        torch.Tensor,
        torch.nn.parameter.Parameter,
    ]
    # Newer torch versions also register a reducer for the TypedStorage base class.
    typed_storage = getattr(getattr(torch, 'storage', None), 'TypedStorage', None)
    if typed_storage is not None:
        classes_to_remove.append(typed_storage)

    for cls in classes_to_remove:
        # default of None: tolerate classes that were never registered / already removed
        reducers.pop(cls, None)
import torch
import time
import torch.nn.functional as F
import warnings
from flair.data import Sentence
from flair.models import SequenceTagger
import math  
import sys 
sys.path.append("..")
torch.manual_seed(0)
import logging

logger = logging.getLogger("spacy")
logger.setLevel(logging.ERROR)
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

import joblib


import requests
import datetime
from sentence_transformers import SentenceTransformer, util

import gc

# Construct a two-layer GNN model
import torch.nn.functional as F
import time
import numpy as np
import torch.nn.functional as F
from collections import defaultdict
import sklearn.metrics as skm
import json
import os
from tqdm.auto import tqdm
from collections import Counter
import warnings
warnings.filterwarnings('once')
from allennlp.models import * 


import torch
torch.manual_seed(0)
sk_learn_seed = 5
import numpy as np

np.random.seed(0)
torch.multiprocessing.set_sharing_strategy('file_system')
import analysis_helper_functions
import llm_prompts
sys.path.append("..")
import reddit_data

import openai
logging.getLogger("urllib3").setLevel(logging.WARNING)
openai.util.logging.getLogger().setLevel(logging.WARNING)

import warnings
warnings.filterwarnings("ignore", message="Mean of empty slice.", category=RuntimeWarning)
warnings.filterwarnings("ignore", message="invalid value encountered in double_scalars", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="spacy.util")

warnings.filterwarnings("ignore", category=ResourceWarning, message="unclosed file")


def load_dict(dict_path):
    '''Load a dict stored with ``np.save`` and return it as a ``defaultdict(list)``.

    Keys missing from the stored mapping default to an empty list.
    NOTE: ``allow_pickle=True`` unpickles the file, so only load trusted paths.
    '''
    stored_mapping = np.load(dict_path, allow_pickle=True).item()
    result = defaultdict(list)
    result.update(stored_mapping)
    return result


def determine_purity_for_community_all_match(given_community, gold_community, user_id_labels_dict_pred=None, user_id_labels_dict_pred_path=None, given_community_pred_labels=None, given_community_gold_labels=None, user_id_labels_dict_gold=None, evaluate_reddit=False):
    '''Score a predicted community against its gold community. This is the primary evaluation metric we used!

    Purity is ``skm.accuracy_score`` of predicted vs. gold labels over every user
    in the predicted community plus every gold user it missed. Membership
    precision / recall / F1 of the predicted community against the gold
    community are returned as additional metrics.

    Args:
        given_community: usernames predicted to be in the community (lowercased here).
        gold_community: usernames of the gold community. They are compared as-is,
            so the caller should pass them lowercased.
        user_id_labels_dict_pred: user -> raw predicted label, converted with
            ``analysis_helper_functions.determine_label_given_label``. Used when
            ``given_community_pred_labels`` is None or ``evaluate_reddit`` is False.
        user_id_labels_dict_pred_path: loaded with ``load_dict`` when
            ``user_id_labels_dict_pred`` is not supplied.
        given_community_pred_labels: user -> predicted label (used when ``evaluate_reddit``).
        given_community_gold_labels: user -> gold label; looked up for every scored user.
        user_id_labels_dict_gold: user -> raw gold label; only read when
            ``given_community_pred_labels`` is None.
        evaluate_reddit: read predictions from ``given_community_pred_labels``.

    Users with no available prediction are skipped both for purity and for the
    tp/fp counts. fn counts every gold user missing from ``given_community``.

    Returns:
        ``(purity, precision, recall, f1)``. All four are 0.0 when none of the
        predicted users has a known label.
    '''
    if user_id_labels_dict_pred_path is not None and user_id_labels_dict_pred is None:
        # if we don't have the user labels dict, load it
        user_id_labels_dict_pred = load_dict(user_id_labels_dict_pred_path)

    given_community = [x.lower() for x in given_community]
    # sets for O(1) membership tests; the lists keep their order/duplicates for iteration
    given_community_set = set(given_community)
    gold_community_set = set(gold_community)
    missing = object()  # sentinel: no prediction is available for a user

    def _predicted_label(user):
        '''Return the predicted label for ``user``, or ``missing`` when none is available.'''
        if given_community_pred_labels is None or not evaluate_reddit:
            try:
                return analysis_helper_functions.determine_label_given_label(user_id_labels_dict_pred[user])
            except Exception:
                # best-effort: a user absent from the pred dict (or with an unusable label) is skipped
                return missing
        if user.lower() not in given_community_pred_labels:
            # we don't have the user label, LLM_task made a mistake, move on
            return missing
        return given_community_pred_labels[user.lower()]

    # Bail out early when none of the predicted users has a known label.
    all_community_labels = []
    for given_user in given_community:
        if given_community_pred_labels is not None:
            # NOTE(review): gated on the *pred* dict but reads the *gold* dict; presumably both are passed together
            if given_user in given_community_gold_labels:
                all_community_labels.append(given_community_gold_labels[given_user])
        else:
            all_community_labels.append(analysis_helper_functions.determine_label_given_label(user_id_labels_dict_gold[given_user]))

    if not all_community_labels:
        print("We don't have any labels for these users why " + str(given_community))
        print("Problem, community label")
        print(all_community_labels)
        return 0.0, 0.0, 0.0, 0.0

    # Predicted users first, then gold users not already present (each added once).
    all_users_to_consider = list(given_community)
    seen_users = set(given_community)
    for user in gold_community:
        if user not in seen_users:
            seen_users.add(user)
            all_users_to_consider.append(user)

    y_true = []
    y_pred = []
    for given_user in all_users_to_consider:
        y_pred_label = _predicted_label(given_user)
        if y_pred_label is missing:
            continue
        y_true.append(given_community_gold_labels[given_user.lower()])
        y_pred.append(y_pred_label)

    purity_score_for_community = skm.accuracy_score(y_true, y_pred)

    # optional, also compute membership precision, recall, and f1 as additional metrics
    tp = 0
    fp = 0
    for given_user in given_community:
        if _predicted_label(given_user) is missing:
            continue
        if given_user in gold_community_set:
            tp += 1
        else:
            fp += 1
    fn = sum(1 for user in gold_community if user not in given_community_set)

    precision = tp / (tp + fp) if (tp + fp) != 0 else 0
    recall = tp / (tp + fn) if (tp + fn) != 0 else 0
    f1_score = 2 * ((precision * recall) / (precision + recall)) if (precision + recall) != 0 else 0

    return purity_score_for_community, precision, recall, f1_score
    

def community_reward_focus_area_length_based(predicted_focus_area):
    '''Length reward for a focus area, based on its word count.

    Words are counted with ``split(' ')``, so consecutive spaces count as extra
    (empty) words. At most 10 words earns 0.5 and at least 40 words earns 1.0.
    Counts in between are interpolated linearly.
    '''
    min_words, max_words = 10, 40
    min_reward, max_reward = 0.5, 1.0

    word_count = len(predicted_focus_area.split(' '))
    if word_count <= min_words:
        return min_reward
    if word_count >= max_words:
        return max_reward

    fraction = (word_count - min_words) / (max_words - min_words)
    return fraction * (max_reward - min_reward) + min_reward
    
    

def community_reward_focus_area_complexity(predicted_focus_area, sim_score_threshold=0.3, loaded_model=None, loaded_vectorizer=None, train_keras_model=False):
    '''Complexity reward for a focus area: ``1 - mean(positive-class probability)``.

    Needs the classifier and vectorizer trained in
    generate_complex_sentences_given_simple.py. ``new_rewards`` loads them from
    ``complexity_model.pkl`` / ``complexity_vectorizer.pkl``.

    Args:
        predicted_focus_area: focus-area text to score.
        sim_score_threshold: unused; kept for backward compatibility.
        loaded_model: classifier exposing ``predict_proba``, or ``predict`` when
            ``train_keras_model`` is True.
        loaded_vectorizer: fitted vectorizer exposing ``transform``.
        train_keras_model: score with ``loaded_model.predict`` (Keras-style
            probabilities). The last output column is taken as the positive class.

    Returns:
        ``1 - mean`` of the positive-class probabilities. Presumably class 1 means
        "complex", so simpler text earns a higher reward (confirm with the training script).
    '''

    def predict_complexity(model, vectorizer, new_sentences, train_keras_model=False):
        new_sentences_vec = vectorizer.transform(new_sentences)
        if train_keras_model:
            # Keras returns probabilities from predict(); take the last column as the
            # positive class so this matches the predict_proba(...)[:, 1] branch.
            class_probabilities = np.asarray(model.predict(new_sentences_vec))
            if class_probabilities.ndim == 2:
                class_probabilities = class_probabilities[:, -1]
            return class_probabilities
        return model.predict_proba(new_sentences_vec)[:, 1]

    # Inference on the single focus-area sentence
    complexity_scores = predict_complexity(loaded_model, loaded_vectorizer, [predicted_focus_area], train_keras_model=train_keras_model)
    return 1 - np.mean(complexity_scores)


def extract_entities_from_text(text, flair_tagger=None, ner_tagger=None, working_with_reddit=False):
    '''Extract entity strings from ``text``.

    When ``ner_tagger`` is given, it is called on ``str(text)``, and every token
    whose ``pos_`` is NOUN or PROPN is kept (presumably a spaCy pipeline; confirm
    with callers). Otherwise ``flair_tagger.predict`` runs on a flair ``Sentence``
    and its 'ner' spans are kept. Generic prompt words ('user', 'topic', ...) are
    dropped.

    Side effects: when ``working_with_reddit`` is False, sets ``flair.device`` to
    CPU and moves ``flair_tagger`` to CPU. Always calls ``gc.collect()``.

    Returns:
        dict mapping each lowercased entity string to 1.
    '''
    import flair
    import torch
    if not working_with_reddit:
        flair.device = torch.device('cpu')

    generic_words = ['user', 'perspective', 'issues', 'discussing', 'topic', 'users', 'issue']
    found_entities = {}

    if ner_tagger is not None:
        parsed = ner_tagger(str(text))
        for token in parsed:
            if token.pos_ in ["NOUN", "PROPN"]:
                found_entities[token.text.lower()] = 1
    else:
        if not working_with_reddit:
            flair_tagger = flair_tagger.to('cpu')
        tagged_sentence = Sentence(text)
        flair_tagger.predict(tagged_sentence)
        for span in tagged_sentence.get_spans('ner'):
            found_entities[span.text.lower()] = 1

    for word in generic_words:
        found_entities.pop(word, None)

    gc.collect()
    return found_entities
    

def community_reward_entities_sentiment(flair_tagger, predicted_focus_area, all_usernames, summaries_dict=None, entity_discussed=None, given_gold_comm_1=None, given_gold_comm_2=None, sentiment_based=False, frequency_based=False, ner_tagger=None, entailment_tokenizer=None, entailment_model=None, num_entities_for_max_reward=3, curr_curriculum=None):
    '''Entity reward: does the focus area mention the entities that separate the two gold communities?

    ``all_usernames`` are gold users (not model predicted), so this reward needs no Chat-GPT call.

    1. Extract entities from each user's summary and record which users mention each entity.
    2. An entity is "impactful" when the number of gold-community-1 users and
       gold-community-2 users mentioning it differ by at least 2 (frequency based).
    3. Reward how many impactful entities appear among the focus area's entities,
       capped at 1.0. The reward is ``min(1, hits / num_entities_for_max_reward)``
       when there are more than ``num_entities_for_max_reward`` impactful entities,
       and ``min(1, hits / num_impactful)`` otherwise. A focus area with no hits
       (including when there are no impactful entities) gets 0.0.

    Args:
        flair_tagger: NER tagger (moved to CPU here) forwarded to entity extraction.
        predicted_focus_area: focus-area text being rewarded.
        all_usernames: usernames appearing in the prompt.
        summaries_dict: username -> summary text; users without a summary are skipped.
        given_gold_comm_1, given_gold_comm_2: gold usernames of the two communities.
        ner_tagger: optional tagger used instead of ``flair_tagger``.
        num_entities_for_max_reward: hits needed for the full reward; overridden
            by ``curr_curriculum`` (1 -> 4, 2 -> 6, >=3 -> 7).
        entity_discussed, sentiment_based, frequency_based, entailment_tokenizer,
        entailment_model: accepted for interface compatibility; unused here.

    Returns:
        ``(reward, return_score, return_score_all_entities)``. ``reward`` is a float
        in [0, 1]. ``return_score`` has 1.0/0.0 per impactful entity (hit/miss), and
        ``return_score_all_entities`` has the same for every extracted entity.
    '''
    all_usernames = [x.lower() for x in all_usernames]
    gold_comm_1 = {x.lower() for x in (given_gold_comm_1 or [])}
    gold_comm_2 = {x.lower() for x in (given_gold_comm_2 or [])}
    if summaries_dict is None:
        summaries_dict = {}

    if flair_tagger is not None:
        flair_tagger = flair_tagger.to('cpu')

    # entity -> users whose summary mentions it
    entity_user_mapping_dict = {}
    for given_user in all_usernames:
        if given_user not in summaries_dict:
            continue
        curr_entities_dict = extract_entities_from_text(summaries_dict[given_user], flair_tagger, ner_tagger)
        for given_entity in curr_entities_dict:
            entity_user_mapping_dict.setdefault(given_entity, []).append(given_user)

    # how differently the two gold comms have to discuss an entity to make it impactful
    threshold_for_difference = 2
    impactful_entities = {}
    for given_entity, users_discussing_it in entity_user_mapping_dict.items():
        comm_1_users_discussing = 0
        comm_2_users_discussing = 0
        for user in users_discussing_it:
            if user in gold_comm_1:
                comm_1_users_discussing += 1
            elif user in gold_comm_2:
                comm_2_users_discussing += 1
        if abs(comm_2_users_discussing - comm_1_users_discussing) >= threshold_for_difference:
            # discussed more by one comm than the other; remember the larger count
            impactful_entities[given_entity] = max(comm_1_users_discussing, comm_2_users_discussing)

    # NOTE(review): the focus area goes through reddit_data's extractor while user summaries
    # use the local extract_entities_from_text; confirm both normalize entities the same way.
    focus_area_entities = reddit_data.extract_entities_from_text(predicted_focus_area, flair_tagger, ner_tagger)

    if curr_curriculum is not None:
        if curr_curriculum == 1:
            num_entities_for_max_reward = 4
        elif curr_curriculum == 2:
            num_entities_for_max_reward = 6
        elif curr_curriculum >= 3:
            num_entities_for_max_reward = 7

    return_score = [1.0 if entity in focus_area_entities else 0.0 for entity in impactful_entities]
    num_hits = sum(1 for entity in impactful_entities if entity in focus_area_entities)

    if num_hits == 0:
        # no impactful entity mentioned (or none exist): no reward
        reward_to_return = 0.0
    elif len(impactful_entities) > num_entities_for_max_reward:
        # full reward once num_entities_for_max_reward impactful entities are mentioned
        reward_to_return = min(1.0, num_hits / num_entities_for_max_reward)
    else:
        # fewer impactful entities exist: reward the fraction that was mentioned
        reward_to_return = min(1.0, num_hits / len(impactful_entities))

    return_score_all_entities = [1.0 if entity in focus_area_entities else 0.0 for entity in entity_user_mapping_dict]

    return float(reward_to_return), return_score, return_score_all_entities
       

def new_rewards(flair_tagger, input_wfeed, gold_comm_1=None, gold_comm_2=None, entities_discussed=None, ref_texts_to_evaluate=None, frequency_based=False,ner_tagger=None, evaluating_test=False, complexity_based=False, length_based=False, curr_curriculum=None):
    '''Compute the focus-area rewards for a batch of prompts.

    Handles all reward functions except the purity / downstream-metric one
    (``gpt_reward``). Every enabled reward appends its per-prompt value to one
    shared pool, which is averaged at the end.

    Args:
        flair_tagger: NER tagger forwarded to the entity reward.
        input_wfeed: list of prompt texts. Each line is parsed for a username
            (``username: <name> summary: ...`` or ``<name> ...``) and its summary.
        gold_comm_1, gold_comm_2: per-prompt gold usernames of the two communities.
            ``None`` means "not available".
        entities_discussed: per-prompt entity, forwarded to the entity reward.
        ref_texts_to_evaluate: per-prompt predicted focus area being rewarded.
        frequency_based: enable the entity reward.
        ner_tagger: optional tagger used instead of ``flair_tagger``.
        evaluating_test: show the tqdm progress bar.
        complexity_based: enable the complexity reward. The model is loaded from
            ``complexity_model.pkl`` / ``complexity_vectorizer.pkl`` in the CWD.
        length_based: enable the length reward.
        curr_curriculum: curriculum stage forwarded to the entity reward.

    Returns:
        ``(mean_reward, entity_match_score, entity_match_score_all_entities,
        total_similarity, mean_entity_reward)``. ``mean_reward`` is None when no
        reward was produced, ``mean_entity_reward`` is None when the entity reward
        produced nothing, and ``total_similarity`` is always an empty list.
    '''
    disable_tqdm = not evaluating_test

    # Missing per-prompt metadata becomes a list of None (the same convention as
    # gpt_reward) so the zip below does not fail on the default arguments.
    num_inputs = len(input_wfeed)
    if gold_comm_1 is None:
        gold_comm_1 = [None] * num_inputs
    if gold_comm_2 is None:
        gold_comm_2 = [None] * num_inputs
    if entities_discussed is None:
        entities_discussed = [None] * num_inputs
    if ref_texts_to_evaluate is None:
        ref_texts_to_evaluate = [None] * num_inputs

    reward_total = []
    reward_total_entity_only = []
    entity_match_score = []
    entity_match_score_all_entities = []
    total_similarity = []  # never filled here; kept for the return signature

    if complexity_based:
        # NOTE: joblib.load unpickles these files, so they must come from a trusted source.
        loaded_model = joblib.load('complexity_model.pkl')
        loaded_vectorizer = joblib.load('complexity_vectorizer.pkl')

    for given_input, given_gold_comm_1, given_gold_comm_2, entity_discussed, curr_focus_area in tqdm(zip(input_wfeed, gold_comm_1, gold_comm_2, entities_discussed, ref_texts_to_evaluate), disable=disable_tqdm):
        given_input = given_input.lower()

        # parse the usernames and summaries from the prompt text. this will help us with the rewards
        all_usernames = []
        all_usernames_summaries = {}
        for given_message in given_input.split('\n'):
            if len(given_message.strip()) == 0:
                continue
            if 'username:' in given_message:
                try:
                    after_tag = given_message.split('username: ')[1]
                except IndexError:
                    # 'username:' without the trailing space; we can't parse this line
                    print("In llm_reward_fns, can't get usnermae from given_message " + str(given_message))
                    continue
                if 'summary' in given_message:
                    curr_username = after_tag.split('summary:')[0]
                else:
                    curr_username = after_tag.split(' ')[0]
            elif given_message.split(' ')[0] != 'which':
                curr_username = given_message.split(' ')[0]
            else:
                # the trailing question line of the prompt
                continue

            # only consider the first word
            curr_username = curr_username.split(' ')[0]
            curr_username_summary = given_message.replace('username:', '').replace(curr_username, '').replace('summary:', '').strip()

            all_usernames.append(curr_username)
            all_usernames_summaries[curr_username] = curr_username_summary

        if frequency_based:
            # this is our entity reward
            curr_sentiment_reward, curr_entity_match_score, curr_entity_match_score_all_entities = community_reward_entities_sentiment(flair_tagger, predicted_focus_area=curr_focus_area, all_usernames=all_usernames, summaries_dict=all_usernames_summaries, entity_discussed=entity_discussed, given_gold_comm_1=given_gold_comm_1, given_gold_comm_2=given_gold_comm_2, sentiment_based=False, frequency_based=True, ner_tagger=ner_tagger, entailment_tokenizer=None, entailment_model=None, num_entities_for_max_reward=3, curr_curriculum=curr_curriculum)
            entity_match_score.extend(curr_entity_match_score)
            entity_match_score_all_entities.extend(curr_entity_match_score_all_entities)
            if curr_sentiment_reward is not None:
                reward_total_entity_only.append(curr_sentiment_reward)
                reward_total.append(curr_sentiment_reward)

        if complexity_based:
            curr_reward = community_reward_focus_area_complexity(predicted_focus_area=curr_focus_area, loaded_model=loaded_model, loaded_vectorizer=loaded_vectorizer)
            if curr_reward is not None:
                reward_total.append(curr_reward)

        if length_based:
            curr_reward = community_reward_focus_area_length_based(predicted_focus_area=curr_focus_area)
            if curr_reward is not None:
                reward_total.append(curr_reward)

    reward_total_entity_only_to_return = np.mean(reward_total_entity_only) if reward_total_entity_only else None
    if not reward_total:
        return None, entity_match_score, entity_match_score_all_entities, total_similarity, reward_total_entity_only_to_return
    try:
        mean_reward = np.mean(reward_total)
    except Exception:
        print("Problem returning reward " + str(reward_total) + " and entity_match_score " + str(entity_match_score) + " and entity_match_score_all_entities " + str(entity_match_score_all_entities) + " and total_similarity " + str(total_similarity))
        # re-raise instead of exit(1): the caller gets the real traceback and can handle it
        raise
    return mean_reward, entity_match_score, entity_match_score_all_entities, total_similarity, reward_total_entity_only_to_return
    

def gpt_reward(gpt_output, gpt_input, gold_comm_1=None, gold_comm_2=None, entities_discussed=None, reddit_data=False, return_rewards=False, run_bias=True, evaluate_reddit=True, user_id_labels_dict_pred=None, user_id_labels_dict_gold=None, user_id_labels_dict_bias=None, return_predictions_dict=False, input_wfeed=None, ref_texts_to_evaluate=None, disable_tqdm=False, just_eval_purity=False, evaluate_bias_reward_accuracy=True, only_consider_all_purity=False, running_evaluation=False):   
    '''this function computes our downstream metric reward, i.e. how well we can classify the communities and our purity''' 

    purity_weight = 0
    all_match_weight = 1
    
    
    if gold_comm_1 is None:
        # we don't have, so make a list of None
        gold_comm_1 = [None for _ in range(len(gpt_input))]
        gold_comm_2 = [None for _ in range(len(gpt_input))]
        entities_discussed = [None for _ in range(len(gpt_input))]
    if input_wfeed is None:
        input_wfeed = [None for _ in range(len(gpt_input))]
    if ref_texts_to_evaluate is None:
        ref_texts_to_evaluate = [None for _ in range(len(gpt_input))]
    
    all_usernames = []
    all_rewards = []
    purity_rewards, purity_rewards_by_num_users = [], []
    precision_rewards_all_match, recall_rewards_all_match, f1_rewards_all_match = [], [], []
    purity_rewards_all_match = []
    perspective_rewards = []
    gpt_bias_rewards = []
    gold_bias_rewards = []
    gpt_bias_accuracy = []
    num_users_predicted = []
    
    if return_predictions_dict:
        # we want to return the predictions that we made
        # success and failure is based on purity
        success_predictions_dict = {}
        failure_predictions_dict = {}
    
    assert len(gpt_output) == len(gpt_input) == len(gold_comm_1) == len(gold_comm_2) == len(entities_discussed) == len(input_wfeed) == len(ref_texts_to_evaluate)
    
    if running_evaluation:
        print("Running evaluation, so we will print out the tqdm progress bar so that we can track our progress easier in case the dataset is big")
        disable_tqdm = False
    for given_comm, given_input, given_gold_comm_1, given_gold_comm_2, entity_discussed, curr_text_input, curr_focus_area in tqdm(zip(gpt_output, gpt_input, gold_comm_1, gold_comm_2, entities_discussed, input_wfeed, ref_texts_to_evaluate), total=len(gpt_output), disable=disable_tqdm):
        # print("Working with given_comm " + str(given_comm))
        
        if '__date__' in entity_discussed:
            entity_discussed = entity_discussed.split('__date__')[0]
        
        if evaluate_reddit:
            curr_gold_bias_labels = {}
            for u_1 in given_gold_comm_1:
                curr_gold_bias_labels[u_1.lower()] = 'left'
            for u_2 in given_gold_comm_2:
                curr_gold_bias_labels[u_2.lower()] = 'right'
        
        if ';;;;' in given_comm:
            predicted_comm_1 = given_comm.split(';;;;')[0].split(',')
            predicted_comm_2 = given_comm.split(';;;;')[1].split(',')
            predicted_comm_1 = [x.lower() for x in predicted_comm_1]
            predicted_comm_2 = [x.lower() for x in predicted_comm_2]
            
        original_given_gold_comm_1 = given_gold_comm_1.copy()
            
        if given_gold_comm_1 != None:
            given_gold_comm_1 = [x.lower() for x in given_gold_comm_1]
            given_gold_comm_2 = [x.lower() for x in given_gold_comm_2]
            
            
        predicted_comm_1 = [x.replace('\n', '').lstrip().rstrip().strip() for x in predicted_comm_1]
        predicted_comm_1 = [x.lower() for x in predicted_comm_1 if len(x.strip().rstrip()) != 0]
        predicted_comm_2 = [x.replace('\n', '').lstrip().rstrip().strip() for x in predicted_comm_2]
        predicted_comm_2 = [x.lower() for x in predicted_comm_2 if len(x.strip().rstrip()) != 0]
        
        
        
        # remove duplicates
        predicted_comm_1 = list(set(predicted_comm_1))
        predicted_comm_2 = list(set(predicted_comm_2))
                 
        all_usernames = []
        all_usernames_summaries = {}   
        for given_message in given_input.split('\n'):
            given_message = given_message.lower()
            if len(given_message.strip().rstrip()) == 0:
                continue
            if 'username:' in given_message.lower():
                if 'summary' in given_message.lower():
                    curr_username = given_message.split('username: ')[1].split('summary:')[0]  
                else:
                    curr_username = given_message.split('username: ')[1].split(' ')[0]
            elif given_message.split(' ')[0].lower() != 'which':
                curr_username = given_message.split(' ')[0]
            else:
                continue
                
            if len(curr_username.split(' ')) > 0:
                # only consider the first word
                curr_username = curr_username.split(' ')[0]
                
            curr_username_summary = given_message.replace('username:', '').replace(curr_username, '').replace('summary:', '').rstrip().lstrip().strip()
                
            all_usernames.append(curr_username)
            all_usernames_summaries[curr_username] = curr_username_summary
            
        # print("We got these usernames " + str(all_usernames) + " from these input " + str(given_input))
                
        if len(predicted_comm_1) == 0 and len(predicted_comm_2) == 0:
            print("No users were predicted for " + str(given_comm))
            # no users, move on 
            continue 
    
        curr_accurate_users = []
    
        given_community_pred_labels = {}
        given_community_gold_labels = {}
        if evaluate_reddit:
            
            # determine how many users match each gold comm
            # swap the gold comms if necessary
            # this is fine because we generated data from just Democratic and Republican and that's the gold comms
            users_matching_gold_1 = 0
            users_matching_gold_2 = 0
            for given_user in all_usernames:
                if len(given_user.strip().rstrip().lstrip()) == 0:
                    continue
                if given_user in predicted_comm_1:
                    # we predicted this user
                    if given_user in given_gold_comm_1:
                        users_matching_gold_1 += 1
                    elif given_user in given_gold_comm_2:
                        users_matching_gold_2 += 1
            if users_matching_gold_2 > users_matching_gold_1:
                tmp = given_gold_comm_1.copy()
                given_gold_comm_1 = given_gold_comm_2.copy()
                given_gold_comm_2 = tmp.copy()
            
            # for reddit, get the community labels from the gold
            for given_user in all_usernames:
                if len(given_user.strip().rstrip().lstrip()) == 0:
                    continue
                if given_user in predicted_comm_1:
                    # print("given_user " + str(given_user) + " is in predicted_comm_1")
                    # it was predicted in the first community 
                    given_community_pred_labels[given_user] = 1
                elif given_user in predicted_comm_2:
                    # it was predicted in the second community
                    given_community_pred_labels[given_user] = 2
                else:
                    print("We don't have a label for this user " + str(given_user))
                    given_community_pred_labels[given_user] = 3
                    
                if given_user in given_gold_comm_1:
                    # it was gold in the first community 
                    given_community_gold_labels[given_user] = 1
                    if given_user in predicted_comm_1:
                        curr_accurate_users.append(1.0)
                    else:
                        curr_accurate_users.append(0.0)
                elif given_gold_comm_2:
                    # it was gold in the second community
                    given_community_gold_labels[given_user] = 2
                else:
                    # TODO: this shouldn't happen!
                    print("We don't have a label for this user " + str(given_user))
                    given_community_gold_labels[given_user] = 3
                    
        else:
            # fo twitter, get it from what has been saved
            given_community_pred_labels = user_id_labels_dict_pred
            given_community_gold_labels = user_id_labels_dict_gold
        if len(predicted_comm_1) == 0:
            continue
        
        purity_for_community_all_match, precision_all_match, recall_all_match, f1_all_match = determine_purity_for_community_all_match(predicted_comm_1, gold_community=given_gold_comm_1, given_community_pred_labels=given_community_pred_labels, given_community_gold_labels=given_community_gold_labels, evaluate_reddit=evaluate_reddit, user_id_labels_dict_pred=user_id_labels_dict_pred)
        
                
        if return_predictions_dict and purity_for_community is not None and purity_for_community > 0.8:
            # consider this a success
            # for this input (including the focus area), we predicted this community and we used this focus area
            # do it by original_given_gold_comm_1 so we can identify it later in case it gets swapped around
            success_predictions_dict[','.join(original_given_gold_comm_1)] = (curr_text_input, predicted_comm_1, curr_focus_area, purity_for_community, given_gold_comm_1, given_gold_comm_2, all_usernames, [], [], purity_for_community_all_match)
        elif return_predictions_dict:
            failure_predictions_dict[','.join(original_given_gold_comm_1)] = (curr_text_input, predicted_comm_1, curr_focus_area, purity_for_community, given_gold_comm_1, given_gold_comm_2, all_usernames, [], [], purity_for_community_all_match)
            
        if purity_for_community_all_match is None and len(given_gold_comm_1) >= 1:
            purity_for_community_all_match = 0.0
        purity_rewards_all_match.append(purity_for_community_all_match)
        precision_rewards_all_match.append(precision_all_match)
        recall_rewards_all_match.append(recall_all_match)
        f1_rewards_all_match.append(f1_all_match)
        num_users_predicted.append(len(predicted_comm_1))
        
        if only_consider_all_purity or just_eval_purity:
            continue
        
        
    if return_rewards:
        # returning all as a list!
        if return_predictions_dict:
            return purity_rewards, purity_rewards_all_match, gpt_bias_rewards, perspective_rewards, gold_bias_rewards, purity_rewards_by_num_users, success_predictions_dict, failure_predictions_dict
        else:
            return purity_rewards, purity_rewards_all_match, gpt_bias_rewards, perspective_rewards, gold_bias_rewards, purity_rewards_by_num_users
        
    if len(all_rewards) == 0:
        if len(all_usernames) > 0:
            print("Why is there no reward when we had usernames " + str(all_usernames))
        if len(gpt_output) > 0:
            print("Why is there no reward when we had gpt_output " + str(gpt_output))
        reward_to_return = 0
    else:
        reward_to_return = np.mean(all_rewards)
        
    def compute_normalized_weighted_sum(value1, value2, value3, value4, value5, weight1, weight2, weight3, weight4, weight5):
        """Blend five scores into one weighted average, skipping NaN scores.

        Each ``valueN`` is paired with ``weightN``. Pairs whose value is NaN
        are discarded, the weights of the remaining pairs are rescaled so they
        sum to 1, and the weighted sum of the remaining values is returned,
        capped at 1.

        Args:
            value1..value5: the scores to combine; NaN scores are ignored.
            weight1..weight5: the relative weight of the matching score.

        Returns:
            ``min(weighted_sum, 1)``. If every score is NaN there is nothing
            to sum and the result is ``0``.

        Raises:
            ZeroDivisionError: if the weights of the non-NaN scores sum to 0.
            TypeError: if a score is not accepted by ``math.isnan`` (e.g. None).
        """
        scores = (value1, value2, value3, value4, value5)
        raw_weights = (weight1, weight2, weight3, weight4, weight5)

        # Keep only the (score, weight) pairs whose score is a real number.
        kept = [(score, wgt) for score, wgt in zip(scores, raw_weights) if not math.isnan(score)]
        kept_scores = [score for score, _ in kept]
        kept_weights = [wgt for _, wgt in kept]

        # Rescale the surviving weights so they sum to 1. This happens before
        # any multiplication, so an all-zero weight set fails here.
        weight_total = sum(kept_weights)
        unit_weights = [wgt / weight_total for wgt in kept_weights]

        blended = sum(score * wgt for score, wgt in zip(kept_scores, unit_weights))

        # Never report a combined score above 1.
        return min(blended, 1)
    
    # for now everything has equal weight  
    
    reward_to_return = np.mean(purity_rewards_all_match)
    # if you had multiple reward values, you can normalize them, but we don't use that here
    
    # reward_to_return = np.mean(f1_rewards_all_match)
    # if np.random.randint(0, 10) == 0:
    #     print("Using reward from f1 score " + str(reward_to_return))

    # Some of these will be NaN: they are custom metrics we used earlier to evaluate the model but did not include in the final paper.
    if return_predictions_dict:
        return {'analysis_reward': reward_to_return, 'purity_reward': np.mean(purity_rewards), 'bias_reward': np.mean(gpt_bias_rewards), 'perspective_reward': np.mean(perspective_rewards), 'purity_rewards_all_match': np.mean(purity_rewards_all_match), 'purity_rewards_by_num_users': np.mean(purity_rewards_by_num_users), 'precision_f1': np.mean(precision_rewards_all_match), 'recall_f1': np.mean(recall_rewards_all_match), 'f1_f1': np.mean(f1_rewards_all_match)}, success_predictions_dict, failure_predictions_dict, num_users_predicted
    return {'analysis_reward': reward_to_return, 'purity_reward': np.mean(purity_rewards), 'bias_reward': np.mean(gpt_bias_rewards), 'perspective_reward': np.mean(perspective_rewards), 'purity_rewards_all_match': np.mean(purity_rewards_all_match), 'purity_rewards_by_num_users': np.mean(purity_rewards_by_num_users), 'precision_f1': np.mean(precision_rewards_all_match), 'recall_f1': np.mean(recall_rewards_all_match), 'f1_f1': np.mean(f1_rewards_all_match)}, num_users_predicted