import numpy as np
import pandas as pd

import sys 
sys.path.append('..')
sys.path.append('../../')
import llm_prompts
from convokit import Corpus, download
import random
import GNN_model_commented
import jsonlines
from flair.data import Sentence
from flair.models import SequenceTagger
import datetime
import pdb
import time
from tqdm import tqdm
from collections import defaultdict
import gc
import os 

def generate_focus_area_given_gold_comms(comm_1, comm_2, entity_discussed, all_usernames_summaries, text_input, use_llama_for_gold_focus_areas=False, specific_prompt_question_to_ask=None, openai_url="http://128.10.12.202:47526/v1"):
    """Ask an LLM which topic separates the users of two gold communities.

    The prompt is built in two parts:

    * the summaries of the kept users of ``comm_1``, then those of ``comm_2``;
    * a question about what sets "the first N users" apart from the rest.

    Args:
        comm_1, comm_2: usernames of the two communities. With probability 1/2
            (``np.random.randint(0, 2)``) the two are swapped, so that both
            sides get asked about as "the first N users".
        entity_discussed: not used by this function; kept for interface
            compatibility with existing callers.
        all_usernames_summaries: dict of username -> summary text. Users
            without a summary are skipped.
        text_input: only printed for debugging when the communities cannot be
            balanced.
        use_llama_for_gold_focus_areas: forwarded to
            ``llm_prompts.prompt_openai_and_get_response`` as
            ``prompt_local_model``.
        specific_prompt_question_to_ask: if given, used as the question on
            every attempt. A response that mentions the user-count word then
            makes the function return None instead of retrying.
        openai_url: endpoint forwarded to
            ``llm_prompts.prompt_openai_and_get_response``.

    Returns:
        The LLM response, or None in any of these cases:

        * the communities cannot be balanced to 2-4 users each;
        * a custom question got a response that mentions the user count.

        At most three LLM calls are made. If none yields a clean response, the
        last response is returned; it may be None or may still mention the
        user count.
    """
    # Group sizes the question can be phrased for; anything else is rejected.
    count_words = {2: 'two', 3: 'three', 4: 'four'}

    if np.random.randint(0, 2):
        # half the time, swap it around so we learn both republican and liberal
        comm_1, comm_2 = comm_2, comm_1

    comm_1_valid_users = [given_user for given_user in comm_1 if given_user in all_usernames_summaries]
    comm_2_valid_users = [given_user for given_user in comm_2 if given_user in all_usernames_summaries]

    if len(comm_1_valid_users) != len(comm_2_valid_users):
        if len(comm_1_valid_users) > 2 and len(comm_2_valid_users) > 2:
            # both sides are big enough: trim the larger one so they match
            max_user_len = min(len(comm_1_valid_users), len(comm_2_valid_users))
            comm_1_valid_users = comm_1_valid_users[:max_user_len]
            comm_2_valid_users = comm_2_valid_users[:max_user_len]
        if len(comm_1_valid_users) != len(comm_2_valid_users):
            # some user wasn't found, move on
            print("len(comm_1_valid_users) is " + str(len(comm_1_valid_users)) + " and len(comm_2_valid_users) is " + str(len(comm_2_valid_users)))
            print("Users don't match length for " + str(comm_1) + " and " + str(comm_2) + " and here are summaries")
            print(all_usernames_summaries)
            print("text_input " + str(text_input))
            return None

    num_users = count_words.get(len(comm_1_valid_users))
    if num_users is None:
        # Previously exit(1); an unusable group size now skips this example
        # instead of killing the whole run.
        print("We have these many users! " + str(len(comm_1_valid_users)))
        return None

    # Build the prompt from the kept users only, so that "the first N users"
    # is exactly comm_1's (possibly trimmed) group.
    text_prompt = ''.join(
        'Username: ' + all_usernames_summaries[given_user].replace('\n', '') + '\n'
        for given_user in comm_1_valid_users + comm_2_valid_users
    )

    base_question = "What topics should we focus on to determine that the first " + num_users + " users are in the same community while others are not? Consider the topic or entity that everyone is discussing and has different perspectives on, as that will allow us to separate the users. Also respond with what aspects of that topic or entity are important. Only respond in a SINGLE complete sentence with the topics. Do not respond in a list, and include a maximum of one topic. Do NOT include any mention to the 'first user' or 'first two users' or 'first three users' or 'first four users' or any of the users in general! Your response should not include information that reveals that I asked you about the first " + num_users + " users. Your response should start with 'Focus on'"
    # Last-resort question that does not mention user groups at all.
    fallback_question = "What topic are all of these users focusing on? Provide a a single topic, with a sentence and explain it. DO NOT PROVIDE MULTIPLE TOPICS. Be as specific as possible. If there is no topic or it is unclear, respond with 'UNKNOWN'"
    # Appended to base_question after a response leaks the user count.
    reminder = ''

    openai_response = None
    for attempt in range(3):
        if specific_prompt_question_to_ask is not None:
            prompt_question_to_ask = specific_prompt_question_to_ask
        elif attempt >= 2:
            # the base question failed twice; switch to the generic one
            prompt_question_to_ask = fallback_question
        else:
            prompt_question_to_ask = base_question + reminder

        message_list = [{"role": "user", "content": text_prompt + prompt_question_to_ask}]
        openai_response = llm_prompts.prompt_openai_and_get_response(message_list=message_list, prompt_local_model=use_llama_for_gold_focus_areas, openai_url=openai_url)

        if openai_response is None:
            continue

        # Mentioning the count word (e.g. 'two') reveals the question; this
        # check also covers the full phrase 'first <N> users'.
        if num_users not in openai_response:
            return openai_response

        if specific_prompt_question_to_ask is not None:
            return None

        reminder = '\n Remember, do not include any mention to the first ' + num_users + ' users. Your response should not include information that reveals that I asked you about the first ' + num_users + ' users.'

    return openai_response
            
def get_existing_entities_users(filename_with_existing_data, existing_entities, existing_users):
    """Register the entities and comm_1 users already stored in a JSONL file.

    Every record in ``filename_with_existing_data`` is read with ``jsonlines``.

    * Its ``'entity'`` value is added to ``existing_entities``.
    * Each username in its ``'comm_1'`` list is added to ``existing_users``.

    Both dicts act as sets: new keys get the value 1, and keys that are already
    present keep their current value. Only ``comm_1`` users are recorded. The
    two dicts are updated in place and also returned, as a tuple.
    """
    # Load the whole file before touching the dicts, so a read error leaves
    # them unchanged.
    with jsonlines.open(filename_with_existing_data, 'r') as reader:
        records = list(reader)

    for record in records:
        existing_entities.setdefault(record['entity'], 1)
        for username in record['comm_1']:
            existing_users.setdefault(username, 1)

    return existing_entities, existing_users
            
def get_top_users_given_corpus(given_corpus, top_users):
    '''Add to ``top_users`` every user with at least two well-received posts.

    An utterance counts as well-received when ``meta['score'] >= 1``. Every
    conversation of ``given_corpus`` is scanned. Qualifying speaker ids are
    inserted into ``top_users`` with value 1. ``top_users`` is a dict used as a
    set; it is updated in place and returned.
    '''
    positive_posts = {}
    for conversation_id in given_corpus.get_conversation_ids():
        conversation = given_corpus.get_conversation(conversation_id)
        for utterance in conversation.iter_utterances():
            speaker_id = utterance.get_speaker().id
            # Every speaker gets an entry (possibly 0), so the order of users
            # follows their first appearance in the corpus.
            well_received = 1 if utterance.meta['score'] >= 1 else 0
            positive_posts[speaker_id] = positive_posts.get(speaker_id, 0) + well_received

    top_users.update({speaker_id: 1 for speaker_id, n_posts in positive_posts.items() if n_posts >= 2})
    return top_users

def previous_weekday(d, weekday):
    """Return the most recent day on or before ``d`` that falls on ``weekday``.

    Args:
        d: a date-like object supporting ``weekday()`` and subtraction of a
            ``datetime.timedelta`` (e.g. ``datetime.date`` / ``datetime.datetime``).
        weekday: target day in ``date.weekday()`` numbering (0 = Monday ...
            6 = Sunday).

    Returns:
        ``d`` itself if it already falls on ``weekday``; otherwise the last such
        day before it. Returns None, after printing the error, if ``d`` cannot
        be handled (e.g. it is None).
    """
    try:
        # Modulo 7 so we always step backwards. Without it, a target later in
        # the week than d gave a negative offset and moved d *forward*.
        days_behind = (d.weekday() - weekday) % 7
        return d - datetime.timedelta(days=days_behind)
    except Exception as e:
        # best-effort: unusable dates are reported and mapped to None
        print("Some problem converting the date " + str(e))
        return None

def extract_entities_from_text(text, flair_tagger=None, ner_tagger=None, working_with_reddit=False):
    '''Extract lower-cased entity strings from ``text``.

    Two back-ends are supported:

    * ``ner_tagger`` (used when given): a callable, presumably a spaCy
      pipeline, whose output is iterated as tokens with ``.text`` and
      ``.pos_``. Every NOUN / PROPN token becomes an entity.
    * otherwise ``flair_tagger``: a flair tagger. The text of every predicted
      ``'ner'`` span of a ``flair.data.Sentence`` becomes an entity.

    Generic words (e.g. 'user', 'topic', 'issue') are removed from the result.

    Args:
        text: the text to tag.
        flair_tagger: flair tagger, used only when ``ner_tagger`` is None.
        ner_tagger: token-level tagger (see above); takes precedence.
        working_with_reddit: when False, two things are moved to CPU before
            tagging: flair's global ``flair.device`` (always), and
            ``flair_tagger`` (only when the flair back-end is used).

    Returns:
        dict mapping each lower-cased entity string to 1 (used as a set).

    Raises:
        ValueError: if neither ``flair_tagger`` nor ``ner_tagger`` is given.
    '''
    if ner_tagger is None and flair_tagger is None:
        raise ValueError("extract_entities_from_text needs either flair_tagger or ner_tagger")

    import flair
    import torch
    if not working_with_reddit:
        flair.device = torch.device('cpu')

    # generic words that show up in user summaries but are not useful entities
    common_entities = ('user', 'perspective', 'issues', 'discussing', 'topic', 'users', 'issue')

    curr_entities_dict = {}
    if ner_tagger is not None:
        doc = ner_tagger(str(text))
        for token in doc:
            if token.pos_ in ("NOUN", "PROPN"):
                curr_entities_dict[token.text.lower()] = 1
    else:
        if not working_with_reddit:
            flair_tagger = flair_tagger.to('cpu')
        sentence_to_tag = Sentence(text)
        flair_tagger.predict(sentence_to_tag)
        for entity in sentence_to_tag.get_spans('ner'):
            curr_entities_dict[entity.text.lower()] = 1

    for given_entity in common_entities:
        curr_entities_dict.pop(given_entity, None)

    # NOTE(review): a full GC on every call, presumably to release tagging
    # intermediates; costly in tight loops — confirm it is still needed.
    gc.collect()

    return curr_entities_dict

def get_dates_corpus(given_corpus):
    """Return the timestamp of every conversation in ``given_corpus`` as a datetime.

    For each id from ``get_conversation_ids()``, in that order, the
    conversation's ``meta['timestamp']`` is read. It is interpreted as a POSIX
    timestamp and converted to a naive ``datetime.datetime`` expressed in UTC.
    """
    dates_corpus = []
    for given_conversation_id in given_corpus.get_conversation_ids():
        given_conversation = given_corpus.get_conversation(given_conversation_id)
        current_timestamp = given_conversation.meta['timestamp']
        # datetime.utcfromtimestamp is deprecated since Python 3.12. Build an
        # aware UTC datetime instead, then drop tzinfo so results stay naive
        # UTC exactly as before.
        current_date_object = datetime.datetime.fromtimestamp(current_timestamp, tz=datetime.timezone.utc).replace(tzinfo=None)
        dates_corpus.append(current_date_object)

    return dates_corpus


def get_focus_area(num_users, pos_users, neg_users, random_entity, all_user_summaries, curr_text, use_llama_for_gold_focus_areas=False):
    """Request a one-sentence "Focus on the topics ..." description for one example.

    Builds a fixed question asking the LLM for three topics/entities in a
    single short sentence starting with 'Focus on the topics'. The question is
    passed as ``specific_prompt_question_to_ask`` to
    ``generate_focus_area_given_gold_comms`` together with ``pos_users`` /
    ``neg_users`` as the two communities. The entity argument is the part of
    ``random_entity`` before '__date__', with underscores turned into spaces.
    The response (a string, or None on failure) is stored in
    ``curr_focus_area``.
    """
    # an example of how to get the focus areas to train the supervised learning model
    # input: num_users is total number of positive users, pos_users is users from comm_1, neg_users is users from comm_2, random_entity is the entity they have in common, all_user_summaries is a dict of users and their summaries as value, curr_text is the full text input that will be provided to LLM_task

    # NOTE(review): the question below contains the typo 'Remmber' and a stray
    # quote after "users." — it is runtime prompt text, so it is left as-is here.
    prompt_question_to_ask = "What three topics/entities should we focus on to determine that the first " + str(num_users) + " users are in the same community while others are not, and in your response only mention the three topics/entities in a SINGLE sentence, with no explanation of why you chose those topics and no mention of `first " + str(num_users) + " users' and the word community? Only respond in a SINGLE complete sentence that is no longer than 20 words with the topics and the perspectives, no other explanation. Do not respond in a list. Remmber, do not include any mention to the first " + str(num_users) + " users, or any of the usernames. Your response should not include information that reveals that I asked you about the first " + str(num_users) + " users.' Your response should start with 'Focus on the topics' and not include include any reasoning or explanation, such as 'to determine' or 'to understand' or the word 'users' or the word 'community' or `first " + str(num_users) + " users' or 'first one user' or 'first two users'." 
    # random_entity presumably has the form '<entity_name>__date__<...>' — confirm against the caller
    curr_focus_area = generate_focus_area_given_gold_comms(comm_1=pos_users, comm_2=neg_users, entity_discussed=random_entity.split('__date__')[0].replace('_', ' '), all_usernames_summaries=all_user_summaries, text_input=curr_text, use_llama_for_gold_focus_areas=use_llama_for_gold_focus_areas, specific_prompt_question_to_ask=prompt_question_to_ask)