############ GPT-4 Assign Keyword ##################

import os
import openai
import csv
import json, pickle
import numpy as np
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
from datasets import concatenate_datasets, load_dataset
from tqdm import tqdm
import re


'''
Updates 2/3:
- Switched to the gpt-4-1106-preview model.

'''

# Read the key from the environment so real credentials never have to live in
# source; 'abcde' is only the placeholder that used to be hard-coded here.
openai.api_key = os.getenv('OPENAI_API_KEY', 'abcde')

data_name = 'vitaminc'
cluster_size = 100
seed = 3407  # tags the semantic-id file loaded below and the output files written later

save_dir = f'./{data_name}/cluster_subset_q100_newid_clustersize{cluster_size}/'

# Corpus texts; rows are later picked out by index via abstracts_combined.select(...).
abstracts_combined = load_dataset('csv', data_files=f'./data/{data_name}_corpus_10k.csv', split='train')

print(f"Loading semantic dict 'semantic_id_dict_seed{seed}.list'...")
# The pickle holds a 2-tuple (semantic_ids, semantic_dict). semantic_dict is indexed
# by cluster id and its values are passed to abstracts_combined.select, so presumably
# they are lists of row indices -- confirm against the script that produced the file.
# NOTE: pickle.load can execute code embedded in the file; only load files this
# project produced itself.
with open(save_dir + f'/semantic_id_dict_seed{seed}.list', 'rb') as fh:
    semantic_ids, semantic_dict = pickle.load(fh)




##########################################################################################

def extract_keywords(completion_content):
    """Return the keyword list found in a model reply, normalised to ``'a, b, c'``.

    The first ``#keywords:`` marker is located, the text that follows it is split
    on commas, each piece is stripped, and empty pieces are dropped.

    Args:
        completion_content: raw reply text from the chat model.

    Returns:
        The cleaned, comma-joined keywords, or ``''`` (after printing an error)
        when the reply contains no ``#keywords:`` marker.
    """
    found = re.search(r'#keywords:\s*(.+)', completion_content)
    if found is None:
        print("error: No keywords found.")
        return ''

    pieces = [piece.strip() for piece in found.group(1).split(',')]
    return ', '.join(piece for piece in pieces if piece)
    
def estimate_tokens(text):
    """Cheaply approximate the size of *text*.

    Counts whitespace-delimited words. This is a rough proxy and not the count
    the model's real tokenizer would produce.
    """
    words = text.split()
    return len(words)

def dynamic_sampling(cluster, max_token_limit=2048, rng=None):
    """Shuffle a cluster's sentences and keep a prefix that fits a size budget.

    The default budget follows the original author's note that the effective
    context range is roughly 2048 to 4096 tokens (an answer obtained from GPT-4).

    Args:
        cluster: iterable of sentence strings. It is copied before shuffling, so
            the caller's sequence is left untouched.
        max_token_limit: budget measured with ``estimate_tokens`` (a whitespace
            word count, not real model tokens).
        rng: optional object with a ``shuffle`` method, such as a
            ``numpy.random.Generator`` or ``numpy.random.RandomState``. Defaults
            to NumPy's global RNG, so ``np.random.seed`` controls reproducibility.

    Returns:
        List of sentences in shuffled order. Sampling stops at the first sentence
        that would exceed the budget; later sentences are not considered even if
        they would fit. The result is therefore empty when the first shuffled
        sentence alone exceeds the limit.
    """
    if rng is None:
        rng = np.random
    # Shuffle a private copy: the original shuffled the caller's list in place.
    pool = list(cluster)
    rng.shuffle(pool)

    sampled_sentences = []
    current_token_count = 0
    for sentence in pool:
        sentence_token_count = estimate_tokens(sentence)
        if current_token_count + sentence_token_count > max_token_limit:
            break  # stop adding sentences once the budget is reached
        sampled_sentences.append(sentence)
        current_token_count += sentence_token_count

    return sampled_sentences





save_dir2 = save_dir + '/formatted_longer/'
os.makedirs(save_dir2, exist_ok=True)
print("Keywords will be saved to: ", save_dir2)
print("--------------------------\n")

# Seed NumPy's global RNG so the shuffles inside dynamic_sampling (and therefore the
# prompts sent to the model) are reproducible; the output files are tagged with this seed.
np.random.seed(seed)

# Get keyword for each cluster
keyword_dict = {}  # keyword string -> list of cluster ids that were assigned it
n_samples = 20  # NOTE(review): unused below; sample size is governed by the budget passed to dynamic_sampling

output_file = save_dir2 + f'/cluster_keyword_seed{seed}.txt'


def _request_completion(prompt):
    """Send *prompt* to gpt-4-1106-preview and return the reply text.

    Retries indefinitely, sleeping 10 seconds after any exception raised by the
    request or by reading the reply. NOTE(review): a permanent failure (e.g. an
    invalid API key) will therefore loop forever.
    """
    while True:
        try:
            completion = openai.ChatCompletion.create(
                model='gpt-4-1106-preview',
                messages=[
                    {"role": "user", "content": prompt}
                ],
                temperature=0,
                max_tokens=256,  # caps the length of the generated reply only
                top_p=0,
                frequency_penalty=0,
                presence_penalty=0)
            return completion["choices"][0]["message"]['content']
        except Exception as err:
            print('Exception occurs when querying GPT-4 for cluster keywords:', err)
            print('Will sleep for ten seconds before retry...')
            time.sleep(10)


start_time = time.time()
# Explicit utf-8: keywords derived from Wikipedia text may contain non-ASCII characters.
with open(output_file, 'w', encoding='utf-8') as file:
    for cluster_id in semantic_dict:
        cluster = abstracts_combined.select(semantic_dict[cluster_id])['text']

        cluster_use = dynamic_sampling(cluster, 3072)

        # NOTE: the leading whitespace on the continuation lines below is part of the
        # prompt text; it is kept unchanged so prompts match earlier runs.
        prompt = f"""Given the following texts from Wikipedia articles, please analyze and identify 5 to 10 of keywords that succinctly capture the main topics and key entities. 

            Group of sentences: 
            {cluster_use}

            Output the keywords in the following format:
            #keywords: your keywords here.
            [note: Please extract 5 to 10 keywords only, ensuring they are separated by commas. Focus on extracting terms that would serve as effective search queries or index terms for someone researching this topic.]
            """

        # Only the API call is retried; parsing and bookkeeping below run exactly once.
        # extract_keywords returns '' when the reply has no marker, so such clusters
        # all end up grouped under the '' key.
        keywords = extract_keywords(_request_completion(prompt))
        print(f'Keywords of cluster {cluster_id}:', keywords)

        if keywords in keyword_dict:  # keywords already assigned to another cluster
            keyword_dict[keywords].append(cluster_id)
            file.write("Warning: Duplicated keywords detected.\n")
        else:
            keyword_dict[keywords] = [cluster_id]

        file.write(f"Cluster {cluster_id} keyword '{keywords}'\n")

# Close the pickle deterministically so it is fully written before it is read back.
with open(save_dir2 + f'/keyword2clusterid_seed{seed}.dict', 'wb') as fh:
    pickle.dump(keyword_dict, fh)

print("--------------------------\n")
print("Time costs: ", time.time()-start_time)
print("Keywords is saved to: ", save_dir2)


# Reload the keyword -> cluster-id mapping saved earlier in this script.
with open(save_dir2 + f'/keyword2clusterid_seed{seed}.dict', 'rb') as fh:
    keyword_dict = pickle.load(fh)
print(f"number of keywords: {len(keyword_dict)}")

# Write every cluster's full text, one section per cluster (presumably for
# manual inspection alongside the keyword file).
output_file = save_dir2 + f'/cluster_text_seed{seed}.txt'
with open(output_file, 'w', encoding='utf-8') as file:
    for cluster_id in semantic_dict:
        cluster_texts = "\n".join(abstracts_combined.select(semantic_dict[cluster_id])['text'])
        file.write(f"Cluster {cluster_id}:\n {cluster_texts}\n\n")


print(keyword_dict)
