import codecs
import json
import pickle
from collections import Counter, deque
from tqdm import tqdm
import argparse
from Utils.utils import *

# Id of the node used as the root of the exported taxonomy subtree; it is
# looked up in Dataset/mag_field_of_studies.taxo / .terms.
# NOTE(review): presumably the MAG "Computer science" field -- confirm.
_TAXONOMY_ROOT_ID = '41008148'

# AutoPhrase entries scoring below this value end the phrase scan.
_MIN_PHRASE_SCORE = 0.4

# Phrases with fewer characters than this are dropped.
_MIN_PHRASE_LENGTH = 3

# Graded relevance at or above this value becomes binary label 1, else 0.
_RELEVANCE_CUTOFF = 2


def _dump_pickle(obj, path):
    """Pickle *obj* to *path*, overwriting any existing file."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


def _paper_text(paper):
    """Return "<title>. <abstract sentences joined by spaces>" for one paper record."""
    return paper['title'] + '. ' + ' '.join(paper['abstract'])


def _load_facet(path, pid2abstract):
    """Load one facet annotation file and return ``(queries, qrels)``.

    ``queries`` maps each query paper id to its full text and ``qrels`` maps
    each query paper id to ``{candidate_id: 0 or 1}``.
    """
    with open(path, 'r', encoding='utf-8') as fp:
        pool = json.load(fp)

    queries, qrels = {}, {}
    for qid, entry in pool.items():
        queries[qid] = _paper_text(pid2abstract[qid])
        candidates = entry['cands']
        relevances = entry['relevance_adju']
        # Index instead of zip() so that a relevance list shorter than the
        # candidate list still fails loudly rather than being truncated.
        qrels[qid] = {
            cand_id: 1 if relevances[idx] >= _RELEVANCE_CUTOFF else 0
            for idx, cand_id in enumerate(candidates)
        }
    return queries, qrels


def _merge_facet(queries, qrels, facet_queries, facet_qrels):
    """Fold one facet into the combined ``queries`` / ``qrels`` (in place).

    A candidate already labelled 1 by an earlier facet keeps that label;
    otherwise the label of the current facet is stored.
    """
    for qid, text in facet_queries.items():
        # NOTE(review): this keeps only the text before the first '.', which
        # looks intended to recover the title but also truncates titles that
        # themselves contain a period. Kept as-is so the generated query set
        # does not change.
        queries[qid] = text.split('.')[0]
        merged = qrels.setdefault(qid, {})
        for pid, label in facet_qrels[qid].items():
            if merged.get(pid) == 1:
                continue
            merged[pid] = label


def _prepare_dataset(data_path):
    """Build and pickle corpus, queries and qrels; return the corpus id list."""
    pid2abstract = {}
    with open(data_path + '/abstracts-csfcube-preds.jsonl', 'r', encoding='utf-8') as absfile:
        for line in absfile:
            line = line.strip()
            if not line:
                continue
            injson = json.loads(line)
            pid2abstract[injson['paper_id']] = injson

    corpus = {pid: {'text': _paper_text(paper)} for pid, paper in pid2abstract.items()}

    queries, qrels = {}, {}
    for facet in ('background', 'method', 'result'):
        facet_path = data_path + '/test-pid2anns-csfcube-' + facet + '.json'
        facet_queries, facet_qrels = _load_facet(facet_path, pid2abstract)
        _merge_facet(queries, qrels, facet_queries, facet_qrels)

    _dump_pickle(queries, data_path + '/queries')
    _dump_pickle(qrels, data_path + '/qrels')
    _dump_pickle(corpus, data_path + '/corpus')

    cid_list = list(corpus.keys())
    qid_list = list(queries.keys())
    _dump_pickle(cid_list, data_path + '/CSFCube_cid_list')
    _dump_pickle(qid_list, data_path + '/CSFCube_qid_list')
    return cid_list


def _load_phrases(path, stopwords):
    """Read "<score>\\t<phrase>" lines and return ``{phrase: score}``.

    Reading stops at the first score below ``_MIN_PHRASE_SCORE`` (this relies
    on the file being sorted by descending score). Stopwords and phrases
    shorter than ``_MIN_PHRASE_LENGTH`` characters are skipped.
    """
    phrase2integrity = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of failing on the unpack below.
                continue
            raw_score, phrase = line.split('\t')
            score = float(raw_score)
            if score < _MIN_PHRASE_SCORE:
                break
            if phrase in stopwords or len(phrase) < _MIN_PHRASE_LENGTH:
                continue
            phrase2integrity[phrase] = score
    return phrase2integrity


def _index_phrases(phrases, cid_list, documents):
    """Count phrase occurrences per document.

    ``documents[i]`` is taken to be the text of ``cid_list[i]``. Counting uses
    ``str.count``, i.e. non-overlapping substring matches.

    Returns ``(phrase2doc, doc2phrase)``: ``{phrase: {cid: freq}}`` and
    ``{cid: {phrase: freq}}`` with only non-zero frequencies stored.

    Raises:
        ValueError: if there are fewer documents than corpus ids.
    """
    if len(documents) < len(cid_list):
        raise ValueError(
            'raw_tokenized_train.txt has %d lines but the corpus has %d papers; '
            'expected one line per paper in corpus order'
            % (len(documents), len(cid_list))
        )

    phrase2doc, doc2phrase = {}, {}
    for phrase in tqdm(phrases):
        hits = {}
        for cid, doc in zip(cid_list, documents):
            freq = doc.count(phrase)
            if freq > 0:
                hits[cid] = freq
                doc2phrase.setdefault(cid, {})[phrase] = freq
        phrase2doc[phrase] = hits
    return phrase2doc, doc2phrase


def _prepare_phrases(data_path, cid_list):
    """Build and pickle the phrase score table and the phrase/document indexes."""
    # NOTE(review): pickle.load can execute arbitrary code from the file; only
    # run this with a trusted Dataset/stopwords file.
    with open('Dataset/stopwords', 'rb') as f:
        stopwords = pickle.load(f)

    ## The results of AutoPhrase should be located in the below directory (https://github.com/shangjingbo1226/AutoPhrase)
    phrase2integrity_dict = _load_phrases(data_path + '/AutoPhrase.txt', stopwords)

    with open(data_path + '/raw_tokenized_train.txt', 'r', encoding='utf-8') as f:
        documents = [line.strip().lower().replace(' -', '') for line in f]

    phrase2doc_dic, doc2phrase_dic = _index_phrases(phrase2integrity_dict, cid_list, documents)

    _dump_pickle(phrase2integrity_dict, data_path + '/CSFCube_phrase2integrity_dict')
    _dump_pickle(phrase2doc_dic, data_path + '/CSFCube_phrase2doc_dic')
    _dump_pickle(doc2phrase_dic, data_path + '/CSFCube_doc2phrase_dic')


def _bfs_subtree(root, parent_childs):
    """Breadth-first walk from *root*; return ``{node: [children]}``.

    A child is listed under a parent only if it had not been visited when the
    parent was expanded. Nodes without an entry in *parent_childs* (leaves)
    are appended after all expanded nodes, in visit order, with an empty child
    list, so the key order of the result is the same on every run.
    """
    visited = set()
    visit_order = []
    tree = {}
    queue = deque([root])

    while queue:
        node = queue.popleft()
        if node in visited:
            continue
        visited.add(node)
        visit_order.append(node)

        if node in parent_childs:
            children = []
            for child in parent_childs[node]:
                if child in visited:
                    continue
                children.append(child)
                queue.append(child)
            tree[node] = children

    for node in visit_order:
        if node not in tree:
            tree[node] = []
    return tree


def _collect_descendants(root, class_dict):
    """Return *root* followed by everything reachable from it, in DFS pre-order.

    Uses an explicit stack of child iterators instead of recursion; each node
    appears once even if it is reachable through several parents.
    """
    visited = {root}
    order = [root]
    stack = [iter(class_dict.get(root, ()))]

    while stack:
        for child in stack[-1]:
            if child in visited:
                continue
            visited.add(child)
            order.append(child)
            # Descend into the child before resuming the current iterator.
            stack.append(iter(class_dict.get(child, ())))
            break
        else:
            stack.pop()
    return order


def _prepare_taxonomy(data_path):
    """Export the taxonomy subtree and pickle the topic/class dictionaries."""
    ## read taxonomy file
    id2term_dict = {}
    with open('Dataset/mag_field_of_studies.terms', 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            term_id, term = line.split('\t')
            id2term_dict[term_id] = term

    parent_childs_dict = {}
    with open('Dataset/mag_field_of_studies.taxo', 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parent_id, child_id = line.split('\t')
            parent_childs_dict.setdefault(parent_id, []).append(child_id)

    taxo = _bfs_subtree(_TAXONOMY_ROOT_ID, parent_childs_dict)

    taxo_path = data_path + '/taxo.txt'
    with open(taxo_path, 'w', encoding='utf-8') as f:
        for node_id, child_ids in taxo.items():
            parent_topic = id2term_dict[node_id]
            f.write(parent_topic + "\t" + "\t".join([id2term_dict[x] for x in child_ids]) + "\n")

    ## reindexing (by term name, read back from the file that was just written)
    class_dict = {}
    with open(taxo_path, 'r', encoding='utf-8') as f:
        for line in f:
            classes = line.strip().split('\t')
            parent_class = classes[0]
            children = [c for c in classes[1:] if c != parent_class]
            # dict.fromkeys de-duplicates while keeping first-seen order; a
            # set here would make the child order differ from run to run.
            combined = class_dict.get(parent_class, []) + children
            class_dict[parent_class] = list(dict.fromkeys(combined))

    topicid2topic_dict = {}
    topic2topicid_dict = {}
    for topicid, topic in enumerate(class_dict):
        topicid2topic_dict[topicid] = topic
        topic2topicid_dict[topic] = topicid

    _dump_pickle(topicid2topic_dict, data_path + '/CSFCube_topicid2topic_dict')
    _dump_pickle(topic2topicid_dict, data_path + '/CSFCube_topic2topicid_dict')
    _dump_pickle(class_dict, data_path + '/CSFCube_class_dict')

    merged_class_dict = {
        t_class: _collect_descendants(t_class, class_dict) for t_class in class_dict
    }
    _dump_pickle(merged_class_dict, data_path + '/CSFCube_merged_class_dict')


def main(data_path):
    """Prepare the CSFCube dataset, phrase and taxonomy files under *data_path*.

    Reads from *data_path*: ``abstracts-csfcube-preds.jsonl``, the three
    ``test-pid2anns-csfcube-{background,method,result}.json`` files,
    ``AutoPhrase.txt`` and ``raw_tokenized_train.txt``. Reads from the current
    working directory: ``Dataset/stopwords``,
    ``Dataset/mag_field_of_studies.terms`` and
    ``Dataset/mag_field_of_studies.taxo``.

    Writes to *data_path* (pickles unless noted): ``queries``, ``qrels``,
    ``corpus``, ``CSFCube_cid_list``, ``CSFCube_qid_list``,
    ``CSFCube_phrase2integrity_dict``, ``CSFCube_phrase2doc_dic``,
    ``CSFCube_doc2phrase_dic``, ``taxo.txt`` (tab-separated text),
    ``CSFCube_topicid2topic_dict``, ``CSFCube_topic2topicid_dict``,
    ``CSFCube_class_dict`` and ``CSFCube_merged_class_dict``.

    Args:
        data_path: directory holding the input files; outputs are written
            next to them.
    """
    # Dataset preparation
    cid_list = _prepare_dataset(data_path)

    # Phrase preparation
    _prepare_phrases(data_path, cid_list)

    # Taxonomy preparation
    _prepare_taxonomy(data_path)


if __name__ == '__main__':
    # Command-line entry point: take the dataset directory from --data_path
    # (falling back to the bundled CSFCube location) and run the preparation.
    cli = argparse.ArgumentParser()
    cli.add_argument('--data_path', type=str, default='./Dataset/CSFCube')
    options = cli.parse_args()
    main(options.data_path)