import pickle
from tqdm import tqdm
from collections import Counter
import numpy as np
import torch
from Utils.utils import *
import argparse

def _load_pickle(data_path, name):
    """Unpickle and return the object stored at ``<data_path>/<name>``."""
    # NOTE: pickle.load can execute arbitrary code while deserialising, so
    # data_path must only contain files written by a trusted preprocessing run.
    with open(data_path + '/' + name, 'rb') as f:
        return pickle.load(f)


def _dump_pickle(obj, data_path, name):
    """Pickle ``obj`` to ``<data_path>/<name>``, replacing any existing file."""
    with open(data_path + '/' + name, 'wb') as f:
        pickle.dump(obj, f)


def _find_hard_negatives(cid_list, doc2class_dict, num_class, num_hardneg):
    """Return ``{cid: [cid, ...]}`` with the documents sharing most classes.

    For every document in ``cid_list`` the (at most) ``num_hardneg`` documents
    with the largest number of shared class ids are kept, best first.  A
    document is compared against every entry of ``cid_list``, itself included.
    """
    ## binary document-by-class indicator matrix
    # One row per entry of cid_list: rows are filled and later read back by
    # position in cid_list, so sizing by any other container could leave
    # all-zero rows whose indices do not exist in cid_list.
    doc_class_mat = torch.zeros((len(cid_list), num_class))
    for idx, cid in enumerate(cid_list):
        for topic_id in doc2class_dict[cid]:
            doc_class_mat[idx, topic_id] = 1

    ## entry (i, j) = number of class ids documents i and j have in common
    doc_score_mat = torch.matmul(doc_class_mat, doc_class_mat.T)

    # Sort the negated scores ascending, i.e. highest overlap first.
    doc_hardneg_mat = torch.argsort(-doc_score_mat, dim=-1)[:, :num_hardneg]

    # .tolist() yields plain ints, so cid_list is indexed without going
    # through one 0-dim tensor per element.
    return {
        cid: [cid_list[j] for j in row]
        for cid, row in zip(cid_list, doc_hardneg_mat.tolist())
    }


def _select_doc_phrases(cid, doc2phrase_dic, hardneg_ids, phrase2integrity_dict):
    """Score the phrases of document ``cid`` and return the kept ``{phrase: score}``.

    Each phrase is scored as ``distinveness(...) * phrase2integrity_dict[phrase]``;
    the best ``min(len // 5, 15)`` phrases are passed through
    ``omit_substrings`` (both helpers come from ``Utils.utils``).
    """
    ## phrase frequency inside the document itself
    c = dict(Counter(doc2phrase_dic[cid]))

    ## number of hard-negative documents containing each phrase
    # Every negative document contributes at most 1 per phrase.  Updating in
    # place avoids rebuilding the whole counter for each negative document.
    hd_c = Counter()
    for negdoc_id in hardneg_ids:
        hd_c.update(dict.fromkeys(doc2phrase_dic[negdoc_id], 1))
    hd_c = dict(hd_c)

    # Scores go to their own dict so that `c` still holds raw counts for
    # every term when it is handed to distinveness().
    scores = {
        term: distinveness(term, c, hd_c) * phrase2integrity_dict[term]
        for term in c
    }

    ranked = sorted(scores, key=lambda x: -scores[x])
    ranked = ranked[:min(len(ranked) // 5, 15)]
    ranked = omit_substrings(ranked)

    return {term: scores[term] for term in ranked}


def _index_phrases(phrases):
    """Assign integer ids to ``phrases``, sharing an id between X and X+'s'/'es'.

    Returns ``(phrase -> id, id -> [phrases])``.
    """
    phrase2phrase_idx_dict = {}
    phrase_idx2phrase_dict = {}

    # Longest first: "X" + "s" / "es" is always longer than "X", so such a
    # form is already indexed by the time "X" itself is visited.
    for phrase in sorted(phrases, key=lambda x: -len(x)):
        for variant in (phrase + 's', phrase + 'es'):
            if variant in phrase2phrase_idx_dict:
                phrase_idx = phrase2phrase_idx_dict[variant]
                break
        else:
            # ids are handed out consecutively, so the current count is unused
            phrase_idx = len(phrase_idx2phrase_dict)

        phrase2phrase_idx_dict[phrase] = phrase_idx
        phrase_idx2phrase_dict.setdefault(phrase_idx, []).append(phrase)

    return phrase2phrase_idx_dict, phrase_idx2phrase_dict


def main(data_path, num_class=1164, num_hardneg=100):
    """Select per-document phrases and write the re-indexed phrase vocabulary.

    Reads ``CSFCube_phrase2integrity_dict``, ``CSFCube_doc2phrase_dic``,
    ``doc2class_dict`` and ``CSFCube_cid_list`` from ``data_path`` and writes
    ``phrase2phrase_idx_dict``, ``phrase_idx2phrase_dict`` and
    ``doc2phrase_dict`` back into the same directory.

    Args:
        data_path: directory holding the pickled inputs; outputs go there too.
        num_class: number of columns of the document-by-class matrix; every
            class id in ``doc2class_dict`` is used as a column index.
        num_hardneg: maximum number of hard-negative documents per document.
    """
    ## read preprocessed files
    phrase2integrity_dict = _load_pickle(data_path, 'CSFCube_phrase2integrity_dict')
    doc2phrase_dic = _load_pickle(data_path, 'CSFCube_doc2phrase_dic')
    doc2class_dict = _load_pickle(data_path, 'doc2class_dict')
    cid_list = _load_pickle(data_path, 'CSFCube_cid_list')

    ## documents with the largest class overlap act as hard negatives
    doc2hardnegdoc_dic = _find_hard_negatives(
        cid_list, doc2class_dict, num_class, num_hardneg)

    ## per-document phrase selection
    doc2finegrained = {}
    phrase_set = {}  # insertion-ordered set of every selected phrase

    for cid in tqdm(cid_list):
        doc2finegrained[cid] = _select_doc_phrases(
            cid, doc2phrase_dic, doc2hardnegdoc_dic[cid], phrase2integrity_dict)
        phrase_set.update(dict.fromkeys(doc2finegrained[cid]))

    ## re-indexing phrases
    phrase2phrase_idx_dict, phrase_idx2phrase_dict = _index_phrases(phrase_set)

    # Every selected phrase was indexed above, so the lookup cannot miss.
    doc2phrase_dict = {
        cid: [phrase2phrase_idx_dict[phrase] for phrase in selected]
        for cid, selected in doc2finegrained.items()
    }

    _dump_pickle(phrase2phrase_idx_dict, data_path, 'phrase2phrase_idx_dict')
    _dump_pickle(phrase_idx2phrase_dict, data_path, 'phrase_idx2phrase_dict')
    _dump_pickle(doc2phrase_dict, data_path, 'doc2phrase_dict')


if __name__ == '__main__':
    # Script entry point: take the dataset directory from --data_path
    # (defaulting to ./Dataset/CSFCube) and run the pipeline on it.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--data_path',
        type=str,
        default='./Dataset/CSFCube',
    )
    args = arg_parser.parse_args()
    main(args.data_path)