# coding: utf-8

# This module preprocesses the dataset: it parses the raw file into
# structured data, computes similarity scores,
# and (optionally) builds the graph.

import os
import json
import pickle
import nltk
import numpy as np
import scipy.sparse as sp
from collections import Counter, defaultdict
from itertools import combinations


class preprocessor:
    def __init__(self):
        '''Initialize the preprocessor.

        Sets up the WordNet lemmatizer, the NLTK POS tagger and the English
        stopword list (extended with punctuation characters). When a required
        NLTK resource is not installed, the resources are downloaded and the
        attributes are created again.
        '''
        try:
            self.lemmatizer = nltk.stem.WordNetLemmatizer()
            self.pos_tag = nltk.pos_tag
            self.stopwords = nltk.corpus.stopwords.words('english')

            # Probe the lemmatizer and the tagger once so that a missing
            # resource surfaces here instead of in the middle of preprocessing.
            self.lemmatizer.lemmatize('cats')
            self.pos_tag(['i'])
        except LookupError:
            # NLTK signals a missing corpus/model with LookupError. Catching
            # only that keeps KeyboardInterrupt, SystemExit and genuine bugs
            # from being silently turned into a download attempt.
            nltk.download('wordnet')
            nltk.download('averaged_perceptron_tagger')
            nltk.download('stopwords')
            self.lemmatizer = nltk.stem.WordNetLemmatizer()
            self.pos_tag = nltk.pos_tag
            self.stopwords = nltk.corpus.stopwords.words('english')
        # Maps the first letter of a POS tag to the WordNet POS constant
        # (presumably Penn Treebank tags as produced by nltk.pos_tag).
        self.pos_map = {'J':nltk.corpus.wordnet.ADJ, 'V':nltk.corpus.wordnet.VERB, 
                    'N':nltk.corpus.wordnet.NOUN, 'R':nltk.corpus.wordnet.ADV}
        # Treat every single punctuation character below as a stopword too.
        self.stopwords += list(r'.,!:;?/()=-+_[]{}""#$%^&*~`')

    def processSent(self, sentence):
        ''' Lowercase, POS-tag and lemmatize one tokenized sentence.

        Args:
            sentence <list of str>: the words of the sentence.
        Returns:
            A list with one lemma per input word. The first letter of each
            POS tag selects the lemmatizer POS through `self.pos_map`; tags
            without a mapping are lemmatized as nouns ('n').
        '''
        lowered = [token.lower() for token in sentence]
        lemmas = []
        for token, tag in self.pos_tag(lowered):
            wordnet_pos = self.pos_map.get(tag[0], 'n')
            lemmas.append(self.lemmatizer.lemmatize(token, wordnet_pos))
        return lemmas


    def preprocess(self, filename, saveto=None, load_saved=True, lbl2id_file='data/ontology/onto_ontology.txt'):
        ''' Load a dataset stored as one JSON object per line and structure it.

        Every sample is lowercased and lemmatized with `processSent`; the
        vocabulary is collected on the way and the label mapping is read from
        `lbl2id_file`.

        Args:
            filename <str>: path of the JSON-lines file. Each object must hold
                'left_context_token' and 'right_context_token' (lists of words),
                'mention_span' (whitespace separated string) and 'y_str' (labels).
            saveto <str or None>: name of a cache file inside the 'tmp' directory.
                When given, the result is pickled there.
            load_saved <bool>: return the cached result if the cache file exists.
            lbl2id_file <str>: whitespace separated label file; a label's id is
                its position in that file.
        Returns:
            A tuple (data, lbl2id, wrd2id). `data` is a list of
            (sentence words, (mention start, mention end), labels) where the
            mention occupies sentence[start:end]. `wrd2id` numbers the
            vocabulary in sorted order so that ids are reproducible.
        '''
        if saveto:
            saveto = os.path.join('tmp', saveto)
            if load_saved and os.path.exists(saveto):
                # NOTE: pickle executes arbitrary code on load; only point this
                # at cache files produced by this method.
                with open(saveto, 'rb') as f:
                    data, lbl2id, wrd2id = pickle.load(f)
                # Same tuple shape as the freshly computed result below.
                return data, lbl2id, wrd2id

        data = []
        word_set = set()
        with open(filename, encoding='utf-8') as f:
            for i, line in enumerate(f):
                if i % 100 == 0:
                    print(i, end='\r')
                if not line.strip():
                    # tolerate blank lines instead of failing in json.loads
                    continue

                data_i = json.loads(line)
                left = self.processSent(data_i['left_context_token'])
                right = self.processSent(data_i['right_context_token'])
                entity_mention_list = self.processSent(data_i['mention_span'].split())
                y_str = data_i['y_str']
                sent = left + entity_mention_list + right

                # collect data
                mention_start = len(left)
                data.append((sent, (mention_start, mention_start + len(entity_mention_list)), y_str))

                # in-place update: building a new set per line is quadratic
                word_set.update(sent)
        print('Finished loading from file:', filename)

        # Sorting makes the ids independent of the set's iteration order,
        # which changes between interpreter runs (string hash randomization).
        wrd2id = {wrd: id_ for (id_, wrd) in enumerate(sorted(word_set))}
        with open(lbl2id_file) as f:
            lbl2id = {j:i for i,j in enumerate(f.read().strip().split())}

        if saveto:
            # the cache directory may not exist yet on a first run
            os.makedirs(os.path.dirname(saveto), exist_ok=True)
            with open(saveto, 'wb') as f:
                pickle.dump([data, lbl2id, wrd2id], f)

        return data, lbl2id, wrd2id


    def computeTFIDF(self, docs, wrd2id, use_stopword=True, weighted_tfidf=False):
        ''' Compute TFIDF values as a sparse matrix.

        Args:
            docs <list of list of str>: an example: [['i', 'love', 'you'], ['who', 'is', 'him', 'you', 'cousin', 'i', 'love'], ['i','hate','you']]
            wrd2id <dict>: a mapping from wrds to its ids.
            use_stopword <bool>: when True, words listed in `self.stopwords`
                get an idf of 0 and therefore never contribute.
            weighted_tfidf <False or list of (start, end)>: when a non-empty
                list of mention positions (one per document) is given, each
                term frequency is scaled by the weight from `compute_weight`.
        Returns:
            sparse tfidf matrix in the shape [n_docs, n_wrds]. The concrete
            sparse format is whatever scipy's `multiply` returns
            (NOTE(review): CSR in the scipy versions looked at - confirm
            before relying on it).
        '''

        n_doc, n_wrd = len(docs), len(wrd2id)
        # allocate memory
        df = np.zeros([n_wrd])

        # upper bound on the number of (doc, wrd) entries; only the first k
        # slots are filled because repeated words collapse into one entry
        totalwrds = sum(len(doc) for doc in docs)
        tf_ids = np.zeros([2, totalwrds], dtype=np.int32)
        tf_vals = np.zeros([totalwrds])

        if weighted_tfidf:
            doc_iter = zip(docs, weighted_tfidf)
        else:
            doc_iter = ((doc, None) for doc in docs)

        # calculate frequency
        k = 0
        for doc_id, (doc, position) in enumerate(doc_iter):
            count = Counter(doc)
            weight = self.compute_weight(doc, position) if weighted_tfidf else None
            for wrd, freq in count.items():
                wrd_id = wrd2id[wrd]

                df[wrd_id] += 1                 # doc freq
                tf_ids[:, k] = [doc_id, wrd_id] # term freq index
                if weight is None:
                    tf_vals[k] = freq           # term freq value
                else:
                    tf_vals[k] = freq * weight[wrd]

                k += 1

        # inverse doc freq: log((n_doc + n_wrd) / (df + 1))
        idf = np.log((n_doc + len(df)) / (df + 1))
        if use_stopword:    # if stopword is enabled, do not use corresponding words to calculate tfidf
            stopword_mask = np.ones_like(df)
            stopword_mask[[wrd2id[wrd] for wrd in self.stopwords if wrd in wrd2id]] = 0
            idf *= stopword_mask

        idf = sp.coo_matrix(idf)
        # term freq in coo matrix form. The shape is given explicitly:
        # inferring it from the largest index seen yields a too small matrix
        # (and a failing broadcast against idf) when the last document is
        # empty or the highest word id never occurs. Slicing to k drops the
        # unused pre-allocated slots instead of storing them as (0, 0) zeros.
        tf = sp.coo_matrix((tf_vals[:k], (tf_ids[0, :k], tf_ids[1, :k])), shape=(n_doc, n_wrd))
        # compute tfidf (idf is a single row and is broadcast over the documents)
        tfidf = tf.multiply(idf)

        return tfidf


    def compute_weight(self, doc, mention_pos):
        ''' Weight every distinct wrd of a document by its closeness to the mention.

        Args:
            doc <list of str>: the words of the document.
            mention_pos <(start, end)>: the mention covers doc[start:end].
        Returns:
            A dict wrd -> weight. A position inside the mention weighs 1.0, a
            position d words away from it weighs 0.5 / d (d = 1 for the direct
            neighbours). A wrd occurring several times keeps its largest weight.
        '''
        weights = {}
        for idx, token in enumerate(doc):
            # distance (in words) from this position to the mention span
            if idx < mention_pos[0]:
                gap = mention_pos[0] - idx
            elif idx >= mention_pos[1]:
                gap = idx - mention_pos[1] + 1
            else:
                gap = 0

            if gap > 0:
                score = 0.5 / gap
            else:
                score = 1.
            weights[token] = max(weights.get(token, 0), score)
        return weights


    def compute_sent_adj(self, docs, wrd2id, threshold=0, use_stopword=True, mask_prev_n=0):
        ''' Build the sparse sentence-to-sentence adjacency matrix.

        The adjacency is the product of the thresholded TFIDF matrix with its
        own transpose, i.e. entry (i, j) is the dot product of the TFIDF rows
        of documents i and j.

        Args:
            docs <list of list of str>: the documents (sentence + mention).
            wrd2id <dict>: a mapping from wrds to its ids.
            threshold: TFIDF values not greater than this are dropped.
            use_stopword <bool>: forwarded to `computeTFIDF`; keeps the
                adjacency matrix from becoming too dense.
            mask_prev_n <int>: when non-zero, every stored entry whose row AND
                column index are both smaller than this value is removed.
        Returns:
            The sparse adjacency matrix (converted to COO when masking is applied).
        '''
        weights = self.computeTFIDF(docs, wrd2id, use_stopword=use_stopword)
        keep_mask = weights > threshold
        weights = weights.multiply(keep_mask)
        sent_adj = weights.dot(weights.transpose())

        if not mask_prev_n:
            return sent_adj

        # drop the top-left mask_prev_n x mask_prev_n block
        sent_adj = sent_adj.todok()
        doomed = [cell for cell in sent_adj.keys()
                  if cell[0] < mask_prev_n and cell[1] < mask_prev_n]
        for cell in doomed:
            del sent_adj[cell]
        return sent_adj.tocoo()

    
    def compute_sent_adj_sampling_multiProc(self, docs, wrd2id, threshold=0, use_stopword=True, from_ids=None,\
                     to_ids=None, N=30000, topk=100, n_proc=16, mention_pos=False):
        ''' Compute sentence adjacency by sampling, one worker process per starting node.

        For every id in `from_ids` a random group of candidates is drawn from
        `to_ids`; the `topk` most similar candidates (TFIDF dot product) become
        the neighbours of that node, and a self edge with weight 1.0 is added.
        NOTE: when a node is itself among its chosen neighbours, the self edge
        and the similarity land on the same cell and are summed by the
        COO -> DOK conversion (behaviour kept from the original code).

        Args:
            docs <list of list of str>: the documents.
            wrd2id <dict>: a mapping from wrds to its ids.
            threshold: TFIDF values not greater than this are dropped.
            use_stopword <bool>: forwarded to `computeTFIDF`.
            from_ids: ids of the starting nodes (None means all documents).
            to_ids: ids of the candidate nodes (None means all documents).
            N <int>: candidates sampled per starting node; clipped to len(to_ids).
            topk <int>: neighbours kept per starting node.
            n_proc <int>: number of worker processes alive at the same time.
            mention_pos: False, or the list of mention positions forwarded to
                `computeTFIDF` as `weighted_tfidf`.
        Returns:
            A CSR matrix of shape (len(from_ids), len(from_ids) + len(to_ids)).
        Raises:
            RuntimeError: if a worker process exits abnormally.
        '''
        import multiprocessing as mp

        tfidf = self.computeTFIDF(docs, wrd2id, use_stopword=use_stopword, weighted_tfidf=mention_pos)
        # row indexing below needs CSR
        tfidf = tfidf.multiply((tfidf > threshold)).tocsr()

        n_docs = len(docs)
        # `is None` instead of truthiness: an empty selection must stay empty
        # and array arguments have no well-defined truth value.
        if from_ids is None:
            from_ids = np.arange(n_docs)
        else:
            from_ids = np.asarray(from_ids, dtype=np.int64)
        if to_ids is None:
            to_ids = np.arange(n_docs)
        else:
            to_ids = np.asarray(to_ids, dtype=np.int64)
        # sampling without replacement cannot draw more than the population
        n_sample = min(N, len(to_ids))

        def _sub_process(idx, seed, all_data):
            # each starting nodes, sample a group of size N from the to_ids, 
            # and retrieve the top k similar neighbors
            # A per-job generator is required: forked workers inherit the
            # parent's global RNG state and would all draw the same sample.
            rng = np.random.RandomState(seed)
            to_candidates = rng.choice(to_ids, size=n_sample, replace=False)
            # similarity
            sim = tfidf[idx].dot(tfidf[to_candidates].transpose()).tocoo()

            if sim.getnnz() <= topk:
                # get all of them
                top_inds = sim.col
                top_vals = sim.data
            else:
                tmp_topk = np.argpartition(sim.data, -1 * topk)[-1 * topk:]
                top_inds = sim.col[tmp_topk]
                top_vals = sim.data[tmp_topk]

            # add oneselfs
            dest = np.concatenate([[idx], to_candidates[top_inds]])
            top_vals = np.concatenate([[1.0], top_vals])
            from_ = np.full(dest.size, idx)

            sub_sent_adj = sp.coo_matrix((top_vals, (from_, dest)), \
                                shape=(n_docs, n_docs)).todok()
            all_data.update(dict(sub_sent_adj.items()))

        # The worker is a closure, which can only reach a child process via
        # fork; request that start method explicitly where it exists.
        try:
            ctx = mp.get_context('fork')
        except ValueError:
            ctx = mp.get_context()

        # the context manager shuts the manager process down afterwards
        with ctx.Manager() as manager:
            all_data = manager.dict()
            for st in range(0, len(from_ids), n_proc):
                print('Computing sample adjacency: {}/{}'.format(st+1, len(from_ids)), end='\r')
                ed = min(len(from_ids), st+n_proc)

                jobs = []
                for i in range(st, ed):
                    # drawn in the parent, so the parent RNG advances and a
                    # seeded run stays reproducible
                    seed = np.random.randint(0, 2 ** 31 - 1)
                    job = ctx.Process(target=_sub_process, args=(from_ids[i], seed, all_data))
                    job.start()
                    jobs.append(job)
                for job in jobs:
                    job.join()
                    if job.exitcode != 0:
                        # a crashed worker would otherwise silently leave its row empty
                        raise RuntimeError('adjacency worker exited with code {}'.format(job.exitcode))
            # one round trip for the whole result instead of one per key
            collected = all_data.copy()

        sent_adj = sp.dok_matrix((len(from_ids), len(from_ids) + len(to_ids)))
        print('Computing sample adjacency: inserting retrieved data to sparse matrix...')
        n_items = len(collected)
        for i, (key, val) in enumerate(collected.items()):
            print('Inserting: {}/{}'.format(i, n_items), end='\r')
            sent_adj[key] = val
        print('Done.')

        return sent_adj.tocsr()


    def compute_sent_adj_sampling(self, docs, wrd2id, threshold=0, use_stopword=True, from_ids=None,\
                     to_ids=None, N=10000, topk=100):
        ''' Compute sentence adjacency by sampling (single process).

        For every id in `from_ids` a random group of candidates is drawn from
        `to_ids` and the `topk` most similar ones (TFIDF dot product) are kept
        as neighbours. No self edge is added.

        Args:
            docs <list of list of str>: the documents.
            wrd2id <dict>: a mapping from wrds to its ids.
            threshold: TFIDF values not greater than this are dropped.
            use_stopword <bool>: forwarded to `computeTFIDF`.
            from_ids: ids of the starting nodes (None means all documents).
            to_ids: ids of the candidate nodes (None means all documents).
            N <int>: candidates sampled per starting node; clipped to len(to_ids).
            topk <int>: neighbours kept per starting node.
        Returns:
            A CSR matrix of shape (len(docs), len(docs)).
        '''
        tfidf = self.computeTFIDF(docs, wrd2id, use_stopword=use_stopword)
        # row indexing below needs CSR
        tfidf = tfidf.multiply((tfidf > threshold)).tocsr()

        n_docs = len(docs)
        sent_adj = sp.dok_matrix((n_docs, n_docs))

        # `is None` instead of truthiness: an empty selection must stay empty
        # and array arguments have no well-defined truth value.
        if from_ids is None:
            from_ids = np.arange(n_docs)
        else:
            from_ids = np.asarray(from_ids, dtype=np.int64)
        if to_ids is None:
            to_ids = np.arange(n_docs)
        else:
            to_ids = np.asarray(to_ids, dtype=np.int64)
        # sampling without replacement cannot draw more than the population
        n_sample = min(N, len(to_ids))

        for i, idx in enumerate(from_ids):
            if (i+1) % 100 == 0:
                print('Computing sample adjacency: {}/{}'.format(i+1, len(from_ids)), end='\r')
            # each starting nodes, sample a group of size N from the to_ids, 
            # and retrieve the top k similar neighbors
            to_candidates = np.random.choice(to_ids, size=n_sample, replace=False)
            # similarity
            sim = tfidf[idx].dot(tfidf[to_candidates].transpose()).tocoo()

            if sim.getnnz() <= topk:
                # get all of them
                top_inds = sim.col
                top_vals = sim.data
            else:
                tmp_topk = np.argpartition(sim.data, -1 * topk)[-1 * topk:]
                top_inds = sim.col[tmp_topk]
                top_vals = sim.data[tmp_topk]

            # sim.col indexes the sampled candidates; map back to document ids
            dest = to_candidates[top_inds]
            for col, val in zip(dest, top_vals):
                sent_adj[idx, col] = val

        return sent_adj.tocsr()

    
    def compute_lbl_coocur(self, lbls, lbl2id):
        ''' Compute the sparse label adjacency matrix based on co-occurrence.
        (No self-edge is added explicitly.)

        Args:
            lbls <list of list of str>: labels for each sample.
            lbl2id <dict>: a mapping from label str to its index.
        Returns:
            A coordinate matrix for the coocurrence between labels: entry
            (x, y) counts the samples in which labels x and y appear together.
            The matrix is symmetric.
        Raises:
            KeyError: if a label is unknown to `lbl2id`, even after the
                'geograpy' -> 'geography' typo correction.
        '''
        def _lbl_id(lbl):
            # Known labels are looked up unchanged. Unknown ones get the same
            # typo correction `compute_sent_lbl_adj` applies, so both matrices
            # accept the same label set.
            if lbl in lbl2id:
                return lbl2id[lbl]
            return lbl2id[lbl.replace('geograpy', 'geography')]

        n_lbl = len(lbl2id)
        # compute coocurrence
        # (sp.dok_matrix is the public name; the sp.dok namespace is deprecated)
        coocur = sp.dok_matrix((n_lbl, n_lbl))
        for lbls_per_sample in lbls:
            lbl_ids = [_lbl_id(lbl) for lbl in lbls_per_sample]

            # count every unordered pair once per sample, in both directions
            for x, y in combinations(lbl_ids, 2):
                coocur[x, y] += 1
                coocur[y, x] += 1
        return coocur.tocoo()


    def compute_lbl_similarity(self, lbl2id):
        ''' Placeholder for a label adjacency based on word-level cosine-similarity.

        Not implemented yet: the method does nothing and returns None.

        Args:
            lbl2id <dict>: a mapping from label str to its index (unused).
        Returns:
            None.
        '''
        return None


    def compute_lbl_adj(self):
        ''' Placeholder; not implemented yet. Does nothing and returns None. '''
        return None

    
    def compute_sent_lbl_adj(self, lbls, lbl2id):
        ''' Compute the sentence - label interaction matrix (sparse)

        Args:
            lbls <list of list of str>: labels for each sample.
            lbl2id <dict>: a mapping from label str to its index.
        Returns:
            A coordinate matrix of shape [n_docs, n_lbls] with a 1 at (i, j)
            when sample i carries label j.
        Raises:
            KeyError: if a (typo-corrected) label is unknown to `lbl2id`.
        '''
        n_doc, n_lbl = len(lbls), len(lbl2id)

        # sp.dok_matrix is the public name; the sp.dok namespace is deprecated
        interaction = sp.dok_matrix((n_doc, n_lbl))
        for i, lbls_per_sample in enumerate(lbls):
            for lbl in lbls_per_sample:
                # the misspelling 'geograpy' is corrected before the lookup
                interaction[i, lbl2id[lbl.replace('geograpy', 'geography')]] = 1
        return interaction.tocoo()
        
        
    def get_graph_info(self, read_file, lbl2id_file, save_to_file='', is_train=True, tfidf_threshold=0, weighted_tfidf=True):
        ''' Build the graph structures for a dataset, or load them from a cache.

        Args:
            read_file: path of the training file when `is_train` is True;
                otherwise a dict with the keys "train" and "test".
            lbl2id_file <str>: label file forwarded to `preprocess`.
            save_to_file <str>: pickle cache. If it exists it is loaded and
                returned right away; otherwise the result is written to it.
            is_train <bool>: build the training graph, or the test graph in
                which the test samples come first, followed by the training samples.
            tfidf_threshold: threshold forwarded to the adjacency sampling.
            weighted_tfidf <bool>: weight term frequencies by the distance to
                the mention.
        Returns:
            A dict with the keys 'lbl2id', 'sent_adj', 'lbl_adj' and
            'interaction', plus 'n_test' for the test graph.
        '''
        def _columns(samples):
            # split (sentence, mention position, labels) triples into three lists
            return ([sample[0] for sample in samples],
                    [sample[1] for sample in samples],
                    [sample[2] for sample in samples])

        if save_to_file and os.path.exists(save_to_file):
            with open(save_to_file, 'rb') as cache:
                return pickle.load(cache)

        if not is_train:
            # make test dataset, this time the "read_file" should be a dict with keys "train" and "test"
            assert isinstance(read_file, dict) and "train" in read_file and "test" in read_file
            train_data, lbl2id, wrd2id = self.preprocess(read_file["train"], load_saved=True,
                                                         saveto='prep_train.pkl', lbl2id_file=lbl2id_file)
            train_docs, train_positions, train_lbls = _columns(train_data)

            test_data, test_lbl2id, test_wrd2id = self.preprocess(read_file["test"], load_saved=True,
                                                                  saveto='prep_test.pkl', lbl2id_file=lbl2id_file)
            assert all(lbl in lbl2id for lbl in test_lbl2id)
            test_docs, test_positions, test_lbls = _columns(test_data)

            # test samples first, so that they occupy the ids [0, n_test)
            docs = test_docs + train_docs
            positions = test_positions + train_positions
            lbls = test_lbls + train_lbls
            n_test = len(test_docs)

            # give wrds that only occur in the test set the next free ids
            for wrd in test_wrd2id:
                wrd2id.setdefault(wrd, len(wrd2id))

            # edges run from the test samples to the training samples only
            sent_adj = self.compute_sent_adj_sampling_multiProc(
                docs, wrd2id, use_stopword=True,
                from_ids=range(n_test), to_ids=range(n_test, len(docs)),
                threshold=tfidf_threshold,
                mention_pos=weighted_tfidf and positions)
            lbl_adj = self.compute_lbl_coocur(lbls, lbl2id)
            interaction = self.compute_sent_lbl_adj(lbls, lbl2id)

            graph_info = {'lbl2id': lbl2id, 'sent_adj': sent_adj, 'lbl_adj': lbl_adj,
                          'interaction': interaction, 'n_test': n_test}
        else:
            data, lbl2id, wrd2id = self.preprocess(read_file, saveto='prep_train.pkl',
                                                   load_saved=True, lbl2id_file=lbl2id_file)
            docs, positions, lbls = _columns(data)

            sent_adj = self.compute_sent_adj_sampling_multiProc(
                docs, wrd2id, use_stopword=True, threshold=tfidf_threshold,
                from_ids=range(len(docs)), to_ids=range(len(docs)),
                mention_pos=weighted_tfidf and positions)
            lbl_adj = self.compute_lbl_coocur(lbls, lbl2id)
            interaction = self.compute_sent_lbl_adj(lbls, lbl2id)

            graph_info = {'lbl2id': lbl2id,
                          'sent_adj': sent_adj,
                          'lbl_adj': lbl_adj,
                          'interaction': interaction}

        if save_to_file:
            with open(save_to_file, 'wb') as cache:
                pickle.dump(graph_info, cache)

        return graph_info



if __name__ == "__main__":
    # Script entry point: builds and caches the training graph, then the
    # test graph, and prints the test adjacency matrix with the test set size.

    # Dataset to process. For OntoNotes use instead:
    # read_file_train = 'data/ontonotes/g_train.json'
    # read_file_test = 'data/ontonotes/g_test.json'
    read_file_train = 'data/BBN/g_train.json'
    read_file_test = 'data/BBN/g_test.json'
    lbl2id_file = 'data/BBN/hierarchy.txt'

    # pickle caches for the two graphs
    save_to_file_train = 'tmp/graph_info.pkl'
    save_to_file_test = 'tmp/graph_info_test.pkl'

    proc = preprocessor()
    # To only preprocess the sentences:
    # proc.preprocess(read_file_train, saveto='prep_train.pkl')
    # proc.preprocess(read_file_test, saveto='prep_test.pkl')

    # the caches live in ./tmp
    if not os.path.exists('tmp'):
        os.mkdir('tmp')

    # training graph (the result is only needed for its cache file)
    proc.get_graph_info(read_file=read_file_train,
                        lbl2id_file=lbl2id_file,
                        save_to_file=save_to_file_train,
                        is_train=True)

    # test graph, built on top of the training data
    both_files = {'train': read_file_train, 'test': read_file_test}
    graph_info = proc.get_graph_info(read_file=both_files,
                                     lbl2id_file=lbl2id_file,
                                     save_to_file=save_to_file_test,
                                     is_train=False)

    print(graph_info['sent_adj'], graph_info['n_test'])
