import pickle
import torch
import os
import numpy as np
from tqdm import tqdm
import copy 
import argparse

from Utils.utils import *

def _load_pickle(data_path, name):
    """Load and return the pickled object stored in file ``name`` under ``data_path``.

    NOTE: ``pickle.load`` can execute arbitrary code while deserialising, so
    ``data_path`` must only contain files produced by a trusted preprocessing run.
    """
    with open(data_path + '/' + name, 'rb') as f:
        return pickle.load(f)


def _dump_pickle(obj, data_path, name):
    """Pickle ``obj`` into file ``name`` under ``data_path`` (overwrites the file)."""
    with open(data_path + '/' + name, 'wb') as f:
        pickle.dump(obj, f)


def _append_unique(target, items):
    """Append every element of ``items`` to the list ``target`` unless already present."""
    for item in items:
        if item not in target:
            target.append(item)


def _parse_topic_response(response):
    """Split one LLM answer into a list of topic names.

    Two layouts are recognised, checked in this order:

    * a numbered list (detected by a newline directly followed by "2"), with
      one "N. topic" entry per line;
    * a single line of topics separated by ", ".

    Every topic is stripped and has its hyphens replaced by spaces; empty
    entries are discarded.

    Returns:
        The list of topic names, or ``None`` when neither layout is detected
        (for instance the empty string produced after repeated timeouts).
    """
    if '\n2' in response:
        raw_topics = []
        for line in response.split('\n'):
            fields = line.split('. ')
            # Lines without an "N. " prefix (blank lines, trailing remarks)
            # carry no topic; indexing them blindly would raise IndexError.
            if len(fields) > 1:
                raw_topics.append(fields[1])
    elif ', ' in response:
        raw_topics = response.split(', ')
    else:
        return None

    cleaned = (topic.strip().replace('-', ' ') for topic in raw_topics)
    return [topic for topic in cleaned if topic]


def main(data_path):
    """Build document-topic assignments and a pruned taxonomy for CSFCube.

    The pipeline has three stages:

    1. Candidate generation: starting from 'computer science', each document
       is pushed down the taxonomy with ``single_level_assignment``; topics
       whose name appears among the document's phrases are added afterwards.
    2. Core topic selection: every document and its candidate topics are sent
       to an OpenAI chat model, and the raw answers are stored.
    3. Post-processing: topics named fewer than 3 times over all answers are
       pruned, the remaining taxonomy is re-indexed and written out.

    Files read from ``data_path``: CSFCube_class_dict,
    CSFCube_merged_class_dict, CSFCube_topicid2topic_dict,
    CSFCube_topic2topicid_dict, CSFCube_cid_list, corpus,
    CSFCube_doc2phrase_dic, CSFCube_corpus_emb.pt, CSFCube_topic_emb.pt.
    The system prompt is read from 'topic_selection_prompt.txt' in the
    current working directory.

    Files written to ``data_path``: CSFCube_docid2topicids,
    CSFCube_topicids2docids, core_ranking_GPT, class2class_idx_dict,
    class_idx2class_dict, doc2class_dict, directed_label_graph,
    re_child2parent_dict.

    Args:
        data_path: Directory holding the preprocessed inputs; all outputs are
            written to the same directory.
    """
    from collections import Counter, deque

    ## read preprocessed files
    class_dict = _load_pickle(data_path, 'CSFCube_class_dict')
    merged_class_dict = _load_pickle(data_path, 'CSFCube_merged_class_dict')
    topicid2topic_dict = _load_pickle(data_path, 'CSFCube_topicid2topic_dict')
    topic2topicid_dict = _load_pickle(data_path, 'CSFCube_topic2topicid_dict')
    cid_list = _load_pickle(data_path, 'CSFCube_cid_list')
    corpus = _load_pickle(data_path, 'corpus')
    doc2phrase_dic = _load_pickle(data_path, 'CSFCube_doc2phrase_dic')

    ### BERT embeddings for documents and topics should be located in the following directory.
    corpus_emb = torch.load(data_path + '/CSFCube_corpus_emb.pt')
    topic_embs = torch.load(data_path + '/CSFCube_topic_emb.pt')

    # Candidate generation
    docid2topicids = {docid: {} for docid in cid_list}
    topicids2docids = {topicid: {} for topicid in topicid2topic_dict}

    for docnumber, docid in enumerate(tqdm(cid_list)):
        # Breadth-first descent through the taxonomy; a deque gives O(1) pops
        # from the front of the queue.
        parent_topics = deque(['computer science'])
        visited = set()

        while parent_topics:
            parent_topic = parent_topics.popleft()
            if parent_topic in visited:
                continue

            is_continue, result = single_level_assignment(docnumber, parent_topic, cid_list, class_dict, merged_class_dict, topic2topicid_dict, doc2phrase_dic, corpus_emb, topic_embs)

            if is_continue:
                # Keep at most the first two entries of `result`. Each entry is
                # indexed as [0] -> topic name, [1] -> value stored for the
                # (document, topic) pair (presumably a score; see
                # single_level_assignment).
                for candidate in result[:2]:
                    next_topic_id = topic2topicid_dict[candidate[0]]
                    docid2topicids[docid][next_topic_id] = candidate[1]
                    topicids2docids[next_topic_id][docid] = candidate[1]
                    parent_topics.append(topicid2topic_dict[next_topic_id])
            visited.add(parent_topic)

    # Add every taxonomy topic whose name literally occurs among the
    # document's phrases. The pre-expansion mappings are not needed again, so
    # they are extended in place.
    class_topics = set(class_dict)
    for docid in tqdm(cid_list):
        for common_topic in class_topics.intersection(doc2phrase_dic[docid]):
            topicid = topic2topicid_dict[common_topic]
            if topicid not in docid2topicids[docid]:
                docid2topicids[docid][topicid] = (1., -1)
                topicids2docids[topicid][docid] = 1.

    _dump_pickle(docid2topicids, data_path, 'CSFCube_docid2topicids')
    _dump_pickle(topicids2docids, data_path, 'CSFCube_topicids2docids')

    # Core topic selection using LLMs
    with open('topic_selection_prompt.txt', 'r') as f:
        instruction = f.read()

    prompts = {}
    for cid in corpus:
        quoted_topics = ', '.join('"' + topicid2topic_dict[x] + '"' for x in docid2topicids[cid])
        prompts[cid] = ('Paper: "' + corpus[cid]['text']
                        + '", Candidate topic set: [' + quoted_topics + ']')

    import openai
    from openai import OpenAI
    import time

    client = OpenAI()

    def api_call(doc, instruction, demos=None, temperature=0.2, model='gpt-3.5-turbo-0125'):
        '''
        Send one chat-completion request and return the text of the answer.

        doc str: query
        instruction str: system instruction
        demos List((str, str)): demonstrations, if any; they are inserted in
            reverse order before the query
        temperature: sampling temperature passed to the API
        model str: name of the chat model

        Returns '' after three timeouts or when the answer has no text.
        Rate-limit errors wait 60 seconds and retry without counting as timeouts.
        '''
        messages = [{"role": "system", "content": instruction}]

        # `demos` defaults to None rather than [] to avoid a mutable default.
        for demo_doc, demo_label in (demos or [])[::-1]:
            messages.append({"role": "user", "content": demo_doc})
            messages.append({"role": "assistant", "content": demo_label})

        messages.append({"role": "user", "content": doc})

        timeout_num = 0
        while timeout_num < 3:
            try:
                completion = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                )
            except openai.APITimeoutError:
                print('Timeout!')
                timeout_num += 1
            except openai.RateLimitError:
                print('RateLimitError')
                time.sleep(60)
            else:
                # `content` may be None; normalise so callers always get a str.
                return completion.choices[0].message.content or ''

        return ''

    api_result = {}
    for idx, cid in enumerate(prompts):
        api_result[cid] = api_call(prompts[cid], instruction)

        if idx % 100 == 0:
            print(idx / len(prompts))

    _dump_pickle(api_result, data_path, 'core_ranking_GPT')

    # Post-processing (taxonomy pruning)
    min_topic_count = 3
    topic_counts = Counter()
    core_ranking_GPT = {}
    for cid, response in api_result.items():
        cid_topic_list = _parse_topic_response(response)
        if cid_topic_list is None:
            # Unparseable answer: report the document and give it no topics
            # instead of silently reusing the previous document's list.
            print(cid)
            cid_topic_list = []

        core_ranking_GPT[cid] = cid_topic_list
        topic_counts.update(cid_topic_list)

    topic_set = {topic for topic, count in topic_counts.items() if count >= min_topic_count}

    # Keep only the taxonomy edges whose two end points survived pruning.
    # Names are hyphen-normalised on both sides because `topic_set` holds
    # normalised names, while `class_dict` is still indexed by the raw name.
    filtered_class_dict = {}
    for raw_parent, raw_children in class_dict.items():
        parent = raw_parent.replace('-', ' ')
        if parent not in topic_set:
            continue
        kept_children = filtered_class_dict.setdefault(parent, [])
        _append_unique(
            kept_children,
            [child for child in (raw.replace('-', ' ') for raw in raw_children)
             if child in topic_set],
        )

    ## Pruned taxonomy relation update
    child2parent_dict = {}
    for parent, children in filtered_class_dict.items():
        for child in children:
            child2parent_dict.setdefault(child, []).append(parent)

    # Ancestor expansion: pass `idx` looks at the idx-th recorded parent of
    # every child and appends that parent's own parents. Only the first 10
    # positions of each list are expanded.
    for idx in range(10):
        for child, ancestors in child2parent_dict.items():
            if idx >= len(ancestors):
                continue
            target = ancestors[idx]
            if target in child2parent_dict:
                _append_unique(ancestors, child2parent_dict[target])

    ## Pruned taxonomy re-indexing
    class2class_idx_dict = {}
    class_idx2class_dict = {}

    # Longest names first, so that a plural form ("...s" / "...es") is indexed
    # before its singular form, which then reuses the same class index. The
    # name is a tie-breaker that keeps the indices reproducible across runs
    # (set iteration order is not).
    for topic in sorted(topic_set, key=lambda name: (-len(name), name)):
        if topic + 's' in class2class_idx_dict:
            class_idx = class2class_idx_dict[topic + 's']
        elif topic + 'es' in class2class_idx_dict:
            class_idx = class2class_idx_dict[topic + 'es']
        else:
            class_idx = len(class_idx2class_dict)

        class2class_idx_dict[topic] = class_idx
        class_idx2class_dict.setdefault(class_idx, []).append(topic)

    doc2class_dict = {
        cid: [class2class_idx_dict[t_class] for t_class in topics if t_class in class2class_idx_dict]
        for cid, topics in core_ranking_GPT.items()
    }

    # Several names can share one class index (singular/plural), so entries
    # are merged rather than overwritten.
    directed_label_graph = {}
    for parent, children in filtered_class_dict.items():
        _append_unique(
            directed_label_graph.setdefault(class2class_idx_dict[parent], []),
            [class2class_idx_dict[child] for child in children],
        )

    re_child2parent_dict = {}
    for child, parents in child2parent_dict.items():
        _append_unique(
            re_child2parent_dict.setdefault(class2class_idx_dict[child], []),
            [class2class_idx_dict[parent] for parent in parents],
        )

    _dump_pickle(class2class_idx_dict, data_path, 'class2class_idx_dict')
    _dump_pickle(class_idx2class_dict, data_path, 'class_idx2class_dict')
    _dump_pickle(doc2class_dict, data_path, 'doc2class_dict')
    _dump_pickle(directed_label_graph, data_path, 'directed_label_graph')
    _dump_pickle(re_child2parent_dict, data_path, 're_child2parent_dict')


if __name__ == '__main__':
    # Command-line entry point: the only option is the dataset directory,
    # which is handed to main() unchanged.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--data_path',
        type=str,
        default='./Dataset/CSFCube',
    )
    main(cli.parse_args().data_path)