
import glob
import os
import random
import sys

import joblib
import jsonlines
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm

sys.path.append('..')
import llm_prompts

def get_complex_focus_area_given_simple(simple_focus_area):
    '''Get several more complex/informative rewrites of a simpler focus area.

    Sends five independent single-message prompts to the LLM through
    ``llm_prompts.prompt_openai_and_get_response`` (ChatGPT, not the local model).
    Each prompt asks for a rewrite of ``simple_focus_area`` that starts with its
    first three words and adds detail in a different way:
    more complexity, more entities, more events/topics, one additional
    entity/topic, and more detail at roughly the same length.

    Args:
        simple_focus_area: the sentence to rewrite.

    Returns:
        list with the five LLM responses, in the order listed above.
    '''
    sentence_starting = ' '.join(simple_focus_area.split(' ')[:3])

    # (what the rewrite should be, extra constraint appended after the
    # "should start with '...'" clause)
    prompt_variants = [
        ("has the same meaning but is more complex and detailed, but not too much",
         " and should be at least double the length of the original sentence"),
        ("has the same meaning but has a few more entities",
         ", should be around the same length as the original sentence, and mention at least three entities"),
        ("has the same meaning but has a few more news events and topics",
         ", should be at least double the length of the original sentence, and should mention at least three events or topics"),
        ("has the same meaning but mentions an additional entity or topic",
         ""),
        ("is more detailed about the topic or entities mentioned",
         " and be the same length as the original sentence, but you can add 2 words"),
    ]

    output_sentences = []
    for rewrite_goal, constraint in prompt_variants:
        prompt_to_run_question = (
            "Given this sentence, give me another sentence that " + rewrite_goal
            + ". The new sentence should start with '" + sentence_starting + "'"
            + constraint
            + ".\n\n. Here is the sentence: " + str(simple_focus_area)
        )
        message_list = [{"role": "user", "content": prompt_to_run_question}]
        output_sentences.append(
            llm_prompts.prompt_openai_and_get_response(message_list, prompt_chat_gpt=True, prompt_local_model=False)
        )
    return output_sentences


def get_complex_focus_area_given_simple_comm_difference(simple_focus_area, comm_1, comm_2, all_usernames_summaries):
    '''Ask the LLM to expand a focus area with details that separate two user communities.

    With probability 1/2 (drawn from NumPy's global RNG) the two communities are
    swapped, so across calls either side can act as the "discussing" community.
    Only users that have an entry in ``all_usernames_summaries`` are mentioned.
    Their summaries are included in the prompt so the LLM can see what each user
    actually talks about.

    Args:
        simple_focus_area: the original (gold) focus-area sentence.
        comm_1: iterable of usernames of the community whose topics should be added.
        comm_2: iterable of usernames of the contrasting community.
        all_usernames_summaries: dict mapping username -> summary text.

    Returns:
        list of three LLM responses, one per prompt variant (more details, three
        topics, at least five topics/entities), in that order.
    '''
    if np.random.randint(0, 2):
        # half the time, swap it around so we learn both republican and liberal
        comm_1, comm_2 = comm_2, comm_1

    comm_1_valid_users = [user for user in comm_1 if user in all_usernames_summaries]
    comm_2_valid_users = [user for user in comm_2 if user in all_usernames_summaries]
    all_valid_users = comm_1_valid_users + comm_2_valid_users

    # One line per mentioned user. Without these summaries the LLM would only see
    # usernames and could not know what either community discusses.
    user_summaries = ''.join(
        'Username: ' + all_usernames_summaries[user].replace('\n', '') + '\n'
        for user in all_valid_users
    )
    all_users_str = ', '.join(all_valid_users)
    comm_1_str = ', '.join(comm_1_valid_users)
    comm_2_str = ', '.join(comm_2_valid_users)
    sentence_starting = ' '.join(simple_focus_area.split(' ')[:3])

    # What each prompt variant asks the LLM to add to the sentence.
    additions = (
        'more details',
        'three topics',
        'at least 5 more topics or entities',
    )

    output_sentences = []
    for addition in additions:
        prompt_to_run_question = (
            "Here are summaries of what some users discuss:\n" + user_summaries
            + "\nGiven this sentence and users " + all_users_str
            + ", write another sentence that expands the sentence and adds " + addition
            + " that these users discuss: " + comm_1_str
            + ", but these users do not: " + comm_2_str
            + ". Only respond in a SINGLE complete sentence, making the original sentence more detailed."
            " Your response should NOT include the word 'user' or 'users', or reference to the fact we are discussing groups of users."
            " MOST OF ALL, DO NOT include any of these users in your response: " + all_users_str
            + "! Your response should not include information that reveals that I asked you about users."
            " The new sentence should start with '" + sentence_starting + "'.\n\n"
            "Here is the original sentence: " + str(simple_focus_area)
        )
        message_list = [{"role": "user", "content": prompt_to_run_question}]
        output_sentences.append(
            llm_prompts.prompt_openai_and_get_response(message_list, prompt_chat_gpt=True, prompt_local_model=False)
        )
    return output_sentences



def create_dataset_from_list_of_summaries_focus_areas(gold_focus_area, rewritten_focus_area):
    '''Build a shuffled binary-classification dataset from two collections of texts.

    Args:
        gold_focus_area: iterable of texts that receive label 1.
        rewritten_focus_area: iterable of texts that receive label 0.

    Returns:
        list of ``{'text': ..., 'label': ...}`` dicts. The gold items come first and
        the rewritten items after them, then the list is shuffled in place with the
        module-level ``random`` generator.
    '''
    labelled_pairs = [(text, 1) for text in gold_focus_area]
    labelled_pairs += [(text, 0) for text in rewritten_focus_area]

    examples = [{'text': text, 'label': label} for text, label in labelled_pairs]
    random.shuffle(examples)
    return examples

def split_data(train_ratio, dev_ratio, test_ratio, list1, list2):
    '''Shuffle paired items and split them into train/dev/test classification datasets.

    ``list1[i]`` and ``list2[i]`` stay together as a pair through the shuffle. Each
    split is then turned into a labelled dataset by
    ``create_dataset_from_list_of_summaries_focus_areas``, which labels items from
    ``list1`` as 1 and items from ``list2`` as 0.

    Args:
        train_ratio: fraction of pairs for the train split (size is rounded down).
        dev_ratio: fraction of pairs for the dev split (size is rounded down).
        test_ratio: accepted for symmetry but not used; the test split receives
            every pair left after the train and dev splits.
        list1: first sequence of the pairs.
        list2: second sequence of the pairs; must have the same length as ``list1``.

    Returns:
        (train_data, val_data, test_data). Any split may be empty, e.g. when the
        ratios consume every pair.
    '''
    assert len(list1) == len(list2), "Input lists must have the same length"

    data = list(zip(list1, list2))
    random.shuffle(data)

    train_end = int(len(data) * train_ratio)
    dev_end = train_end + int(len(data) * dev_ratio)

    def _unzip(pairs):
        # zip(*[]) yields nothing to unpack, so an empty split needs an explicit fallback.
        return tuple(zip(*pairs)) if pairs else ([], [])

    splits = (data[:train_end], data[train_end:dev_end], data[dev_end:])
    return tuple(create_dataset_from_list_of_summaries_focus_areas(*_unzip(split)) for split in splits)

def _parse_usernames_summaries(text):
    '''Map each username found in a record's ``text`` to the line that mentions it.

    A line containing ``username:``/``Username:`` yields the first word after that
    marker, after removing any ``summary:``/``Summary:`` tail. Any other non-blank
    line not starting with the word "which" yields its own first word. The whole
    original line is stored as that user's summary, and later lines for the same
    username overwrite earlier ones.
    '''
    usernames_summaries = {}
    for given_message in text.split('\n'):
        if not given_message.strip():
            continue
        marker = next((m for m in ('username:', 'Username:') if m in given_message), None)
        if marker is not None:
            candidate = given_message.split(marker, 1)[1]
            candidate = candidate.split('summary:')[0].split('Summary:')[0]
        elif given_message.split()[0].lower() != 'which':
            candidate = given_message
        else:
            continue
        words = candidate.split()
        if not words:
            # e.g. "username:" with nothing after it
            continue
        usernames_summaries[words[0]] = given_message
    return usernames_summaries


def create_dataset(path_to_save_data):
    '''Generate ``simple_complex_dataset_*`` JSONL files from the source JSONL files.

    Files matched by the glob ``path_to_save_data`` are skipped if their path
    contains 'test' or their name already carries the ``simple_complex_dataset_``
    prefix. For every remaining file:
    - each record's gold focus area is rewritten by the LLM into more complex
      variants, using that record's own ``comm_1``/``comm_2`` and the user summaries
      parsed from that record's ``text``;
    - the gold focus areas (label 1) and the rewrites (label 0) are shuffled
      together;
    - the result is written next to the input as
      ``simple_complex_dataset_<original name>``.

    Args:
        path_to_save_data: glob pattern for the input JSONL files. Each record must
            provide the keys 'gold', 'text', 'comm_1' and 'comm_2'.
    '''
    output_prefix = 'simple_complex_dataset_'
    for file_to_read in glob.glob(path_to_save_data):
        original_file_name = os.path.basename(file_to_read)
        # Generated outputs match the same glob; re-processing them would fail ('gold' is absent).
        if 'test' in file_to_read or original_file_name.startswith(output_prefix):
            continue
        output_file_path = os.path.join(os.path.dirname(file_to_read), output_prefix + original_file_name)

        with jsonlines.open(file_to_read, 'r') as fp:
            all_data = list(fp)

        gold_focus_areas = [item['gold'] for item in all_data]

        # Each focus area is paired with its OWN record's communities and user summaries.
        focus_areas_complex = []
        for item in tqdm(all_data):
            usernames_summaries = _parse_usernames_summaries(item['text'])
            focus_areas_complex.extend(
                get_complex_focus_area_given_simple_comm_difference(
                    item['gold'],
                    comm_1=item['comm_1'],
                    comm_2=item['comm_2'],
                    all_usernames_summaries=usernames_summaries,
                )
            )

        new_dataset = create_dataset_from_list_of_summaries_focus_areas(gold_focus_areas, focus_areas_complex)
        print("Saving at file path " + str(output_file_path))
        with jsonlines.open(output_file_path, 'w') as fp:
            for example in new_dataset:
                fp.write(example)
    
def get_data(path_to_save_data='../Dataset_Release/*.jsonl'):
    '''Load the generated simple/complex datasets and TF-IDF-vectorize them.

    For every file matched by the glob ``path_to_save_data``, other than the
    generated ``simple_complex_dataset_*`` files themselves, the sibling file
    ``simple_complex_dataset_<name>`` is read. It is assigned to the test,
    validation or train split according to whether its path (with 'training'
    removed) contains 'test', 'val' or 'train', checked in that order. Generated
    files that do not exist are reported and skipped.

    Args:
        path_to_save_data: glob of the source JSONL files. The default equals the
            pattern used in this script's ``__main__`` section.

    Returns:
        (X_train, y_train, X_test, y_test, X_val, y_val,
         X_train_vec, X_test_vec, X_val_vec, vectorizer), where:
        - ``X_*`` are lists of texts;
        - ``y_*`` are lists of labels (1 = gold/simple, 0 = LLM-rewritten/complex);
        - ``X_*_vec`` are the TF-IDF matrices;
        - ``vectorizer`` is the TfidfVectorizer fitted on ``X_train``.

    Raises:
        ValueError: if no training examples were found.
    '''
    output_prefix = 'simple_complex_dataset_'
    train_data, validation_data, test_data = [], [], []
    for file_to_read in glob.glob(path_to_save_data):
        original_file_name = os.path.basename(file_to_read)
        if original_file_name.startswith(output_prefix):
            # Generated files match the same glob; they are reached through their source file.
            continue
        output_file_path = os.path.join(os.path.dirname(file_to_read), output_prefix + original_file_name)
        print("Working with file " + str(output_file_path))

        # 'training' is removed before matching so that it alone does not select the train split.
        split_key = output_file_path.replace('training', '')
        if 'test' in split_key:
            target_split = test_data
        elif 'val' in split_key:
            print("We have validation")
            target_split = validation_data
        elif 'train' in split_key:
            target_split = train_data
        else:
            continue

        if not os.path.exists(output_file_path):
            # create_dataset() skips test files, so their generated counterparts may be absent.
            print("Skipping missing file " + str(output_file_path))
            continue
        with jsonlines.open(output_file_path, 'r') as fp:
            target_split.extend(fp)

    print("We have this much train, test, validation data: " + str(len(train_data)) + " " + str(len(test_data)) + " " + str(len(validation_data)))

    X_train = [item['text'] for item in train_data]
    y_train = [item['label'] for item in train_data]

    # Label 1 marks gold (simple) focus areas and label 0 the LLM rewrites (complex),
    # as assigned by create_dataset_from_list_of_summaries_focus_areas.
    print("Samples from the training set:")
    simple_count = 0
    complex_count = 0
    for text, label in zip(X_train, y_train):
        if label == 1 and simple_count < 5:
            print(f"Simple: {text} (Label: {label})")
            simple_count += 1
        elif label == 0 and complex_count < 5:
            print(f"Complex: {text} (Label: {label})")
            complex_count += 1

        if simple_count == 5 and complex_count == 5:
            break

    X_test = [item['text'] for item in test_data]
    y_test = [item['label'] for item in test_data]

    X_val = [item['text'] for item in validation_data]
    y_val = [item['label'] for item in validation_data]

    if not X_train:
        raise ValueError("No training data found for pattern " + str(path_to_save_data))

    # Feature extraction using TF-IDF vectorization, fitted on the training texts only.
    vectorizer = TfidfVectorizer()
    X_train_vec = vectorizer.fit_transform(X_train)
    X_test_vec = vectorizer.transform(X_test)
    X_val_vec = vectorizer.transform(X_val)

    return X_train, y_train, X_test, y_test, X_val, y_val, X_train_vec, X_test_vec, X_val_vec, vectorizer
    
    
    
def train_regression_model(X_train, y_train, X_test, y_test, X_val, y_val, X_train_vec, X_test_vec, X_val_vec, vectorizer,
                           model_path='complexity_model.pkl', vectorizer_path='complexity_vectorizer.pkl'):
    '''Train a logistic-regression complexity classifier, save it, and report accuracy.

    The model is fitted on ``X_train_vec``/``y_train``. The model and ``vectorizer``
    are then written with ``joblib.dump``. Test and validation accuracy are printed;
    a split with no labels is reported and skipped, because ``predict`` cannot run
    on zero samples.

    Args:
        X_train, X_test, X_val: raw texts; accepted for interface compatibility and
            not used here.
        y_train, y_test, y_val: label lists.
        X_train_vec, X_test_vec, X_val_vec: vectorized features for each split.
        vectorizer: the fitted vectorizer, saved alongside the model.
        model_path: where to save the model (default 'complexity_model.pkl').
        vectorizer_path: where to save the vectorizer
            (default 'complexity_vectorizer.pkl').

    Returns:
        The fitted ``LogisticRegression`` model.
    '''
    model = LogisticRegression()
    model.fit(X_train_vec, y_train)

    # save the model where it can be loaded
    joblib.dump(model, model_path)
    joblib.dump(vectorizer, vectorizer_path)
    print("Saved the model")

    for split_name, X_vec, y_true in (("Test", X_test_vec, y_test), ("Validation", X_val_vec, y_val)):
        if len(y_true) == 0:
            print(f"No {split_name.lower()} data; skipping evaluation")
            continue
        accuracy = accuracy_score(y_true, model.predict(X_vec))
        print(f"{split_name} Accuracy: {accuracy}")

    return model

if __name__ == '__main__':
    # Glob of the source JSONL files. Bound at module scope as well, because
    # get_data() may look this name up as a global.
    path_to_save_data = '../Dataset_Release/*.jsonl'

    parse_files = True
    if parse_files:
        create_dataset(path_to_save_data=path_to_save_data)

    X_train, y_train, X_test, y_test, X_val, y_val, X_train_vec, X_test_vec, X_val_vec, vectorizer = get_data()

    # The flag must not be named train_regression_model: that would shadow the function
    # and make the call below fail with "'bool' object is not callable".
    run_regression_training = True
    if run_regression_training:
        train_regression_model(X_train, y_train, X_test, y_test, X_val, y_val, X_train_vec, X_test_vec, X_val_vec, vectorizer)