import spacy
from spacy.tokens import Doc
from spacy.language import Language

import numpy as np
import json
import pickle
import time

# Custom spaCy tokenizer: split the text on single space characters
class WhitespaceTokenizer:
    """Tokenizer that cuts text on single space characters only.

    Instances are callable: given raw text they return a ``Doc`` built on
    the vocabulary supplied at construction time.
    """

    def __init__(self, vocab):
        # Vocabulary handed to every Doc this tokenizer creates.
        self.vocab = vocab

    def __call__(self, text):
        """Return a ``Doc`` whose words are the space-separated pieces of *text*.

        Surrounding whitespace is stripped first. The split uses the literal
        ``" "`` separator, so consecutive spaces yield empty-string pieces.
        """
        trimmed = text.strip()
        pieces = trimmed.split(" ")
        return Doc(self.vocab, words=pieces)

# Explicitly set sentence start, to tell the parser this is a complete sentence
@Language.component("custom_sentencizer")
def custom_sentencizer(doc):
    """Mark the whole ``doc`` as one sentence.

    Registered as the pipeline component ``"custom_sentencizer"``. The first
    token is flagged as a sentence start and every other token is flagged as
    not starting a sentence, so later components see a single sentence.

    Args:
        doc: The document to annotate; modified in place.

    Returns:
        The same ``doc`` object.
    """
    # An empty document has no first token to flag; pass it through
    # untouched instead of failing on doc[0].
    if len(doc) == 0:
        return doc

    for position, token in enumerate(doc):
        token.is_sent_start = position == 0
    return doc

# Copy the edges of the (pruned) tree into the adjacency matrix
def tree_to_adj(doc, root, nodes, adj):
    """Write the head->child edges of a pruned tree into ``adj``.

    Starting at ``root``, the tree is walked downwards through
    ``token.children``. A child is followed, and the edge to it recorded,
    only when its index is listed in ``nodes``; anything below an excluded
    child is therefore never reached.

    Args:
        doc: Indexable sequence of tokens exposing ``.children`` and ``.i``.
        root: Index of the token the walk starts from.
        nodes: Iterable of token indices that are kept.
        adj: Square matrix indexable as ``adj[row][col]``; modified in place.

    Returns:
        The same ``adj`` object, with ``adj[head][child]`` set to 1 for every
        edge that was visited.
    """
    # Set membership keeps the per-child test O(1) however many nodes are kept.
    kept = set(nodes)

    # Explicit worklist; the visiting order does not affect which cells are set.
    pending = [root]
    while pending:
        head = pending.pop()
        for child in doc[head].children:
            if child.i in kept:
                adj[head][child.i] = 1
                pending.append(child.i)

    return adj

def _ancestor_chain(doc, start):
    """Return the token indices from ``start`` up to and including the ROOT token."""
    chain = []
    h = start
    while doc[h].dep_ != "ROOT":
        chain.append(h)
        h = doc[h].head.i
    chain.append(h)  # the root itself
    return chain


def get_pruned_tree(doc, sent_len, entities, prune=-1):
    """Select the tokens lying on or near the dependency paths between entities.

    The lowest common ancestor (LCA) of all entity tokens is located first.
    The "path nodes" are every ancestor of an entity token that is not shared
    by all entity tokens, plus the LCA. Each token is then given a distance:
    0 for a path node, the number of head steps needed to reach a path node
    otherwise, or a large sentinel when climbing reaches the ROOT token
    without meeting one. Tokens whose distance is at most ``prune`` are kept.

    Args:
        doc: Indexable sequence of tokens exposing ``.dep_``, ``.head.i``,
            ``.children`` and ``.i``; the root token has ``dep_ == "ROOT"``.
        sent_len: Number of tokens to consider (indices ``0..sent_len-1``).
        entities: Iterable of entities, each an iterable of token indices.
        prune: Maximum distance from the path nodes for a token to be kept.
            Distances are never negative, so a negative value (including the
            default) keeps no token at all.

    Returns:
        ``(root, nodes)`` where ``root`` is the index of the LCA and ``nodes``
        is the ascending list of kept token indices.

    Raises:
        AssertionError: If no entity token is given, the entity tokens share
            no ancestor, or no LCA can be identified.
    """
    INF = 1e6
    common_ancestors = None
    entity_ancestors = set()
    for entity in entities:
        for i in entity:
            ancestors = _ancestor_chain(doc, i)
            entity_ancestors.update(ancestors)
            if common_ancestors is None:
                common_ancestors = set(ancestors)
            else:
                common_ancestors.intersection_update(ancestors)

    # Covers both "no entity token supplied" (still None) and an empty
    # intersection. Raised explicitly so the check survives ``python -O``.
    if not common_ancestors:
        raise AssertionError("Entities must have common ancestors.")

    # The LCA is the common ancestor that has no child among the common ancestors.
    lca = next(
        (
            idx
            for idx in common_ancestors
            if not any(child.i in common_ancestors for child in doc[idx].children)
        ),
        -1,
    )
    if lca < 0:
        raise AssertionError("lca computes in a wrong way.")

    # Nodes on the dependency paths between the entities, joined at the LCA.
    path_nodes = entity_ancestors.difference(common_ancestors)
    path_nodes.add(lca)

    # Distance of every token to the nearest path node above it.
    dist = [-1] * sent_len
    for i in range(sent_len):
        if dist[i] >= 0:  # already filled in while climbing from another token
            continue
        if i in path_nodes:
            dist[i] = 0
            continue

        trail = []
        h = i
        while h not in path_nodes:
            trail.append(h)
            if doc[h].dep_ == "ROOT":
                break
            h = doc[h].head.i

        if h in path_nodes:
            # The last token on the trail is one step away from the path.
            for steps, idx in enumerate(reversed(trail), start=1):
                dist[idx] = steps
        else:
            # Reached the root without meeting the path: outside the LCA subtree.
            for idx in trail:
                dist[idx] = INF

    # Keep only the tokens close enough to the path.
    nodes = [i for i in range(sent_len) if dist[i] <= prune]
    return lca, nodes


# prune=-1: no pruning; prune=0: shortest dependency path (SDP); prune>0: SDP plus the nodes at most ${prune} steps away from it
# method=1: prune based on triples; method=2: prune based on entities
def input_to_adj(a_line, doc, prune=-1, method=2, directed=True, self_loop=True):
    """Build the (optionally pruned) dependency adjacency matrix of one sentence.

    Args:
        a_line: Record for the sentence. ``a_line['entityMentions']`` is a
            list of dicts with ``'text'`` and ``'offset'``; the first two
            items of ``'offset'`` are used as a half-open token index range.
            With ``method=1``, ``a_line['relationMentions']`` is also read;
            each item names its two entities via ``'em1Text'`` and
            ``'em2Text'``. Entities are looked up by text, so of several
            mentions with identical text only the last offset is used.
        doc: Parsed sentence; must contain exactly one sentence.
        prune: Negative for the full tree; otherwise the maximum distance
            from the entity dependency paths for a token to be kept.
        method: 1 prunes once per relation triple and merges the results;
            2 prunes once around all entities. Any other value adds no
            tree edges when pruning is active.
        directed: If False, the matrix is made symmetric.
        self_loop: If True, the diagonal is set to 1.

    Returns:
        The adjacency matrix as a nested list of floats, where entry
        ``[head][child]`` is 1 for every kept dependency edge.

    Raises:
        AssertionError: If ``doc`` does not consist of exactly one sentence.
    """
    sent_num = sum(1 for _ in doc.sents)
    # Raised explicitly so the check survives ``python -O``.
    if sent_num != 1:
        raise AssertionError("A piece of data is a complete sentence.")

    sent_len = len(doc)
    adj = np.zeros((sent_len, sent_len), dtype=np.float64)

    entityMentions = a_line['entityMentions']
    entities = {entity['text']: entity['offset'] for entity in entityMentions}

    if prune < 0 or len(entityMentions) == 0:  # no pruning
        for token in doc:
            for child in token.children:
                adj[token.i][child.i] = 1
    elif method == 1:
        # Prune based on triples: one pruned tree per relation, merged into adj.
        for triple in a_line['relationMentions']:
            em1_pos = entities[triple['em1Text']]
            em2_pos = entities[triple['em2Text']]
            entities_pos = [
                list(range(em1_pos[0], em1_pos[1])),
                list(range(em2_pos[0], em2_pos[1])),
            ]
            root, nodes = get_pruned_tree(doc, sent_len, entities_pos, prune)
            adj = tree_to_adj(doc, root, nodes, adj)
    elif method == 2:
        # Prune based on entities: a single pruned tree around all of them.
        # Pruning once after collecting every span is sufficient; a tree built
        # from only some of the entities is contained in the tree built from all.
        entities_pos = [
            list(range(entity_pos[0], entity_pos[1]))
            for entity_pos in entities.values()
        ]
        root, nodes = get_pruned_tree(doc, sent_len, entities_pos, prune)
        adj = tree_to_adj(doc, root, nodes, adj)

    if not directed:
        # Element-wise maximum keeps the entries binary; adding the transpose
        # would double any cell set in both directions, e.g. the diagonal.
        adj = np.maximum(adj, adj.T)

    if self_loop:
        for i in range(sent_len):
            adj[i][i] = 1

    adj_fw = adj.tolist()
    return adj_fw

def get_dep_adj_file(in_file, out_file, prune=-1, method=2):
    """Parse every sentence of ``in_file`` and pickle its adjacency matrix.

    ``in_file`` is read line by line; each non-blank line is a JSON object
    whose ``'sentText'`` is parsed with the ``en_core_web_lg`` pipeline,
    using the whitespace tokenizer and the single-sentence component defined
    in this module. The result written to ``out_file`` is a pickled list of
    ``{'id': <running index>, 'adj_fw': <matrix>}`` dicts in input order.

    Args:
        in_file: Path of the JSON-lines input file.
        out_file: Path of the pickle file to write.
        prune: Negative to store the full dependency tree; otherwise passed
            on to ``input_to_adj``.
        method: Pruning method passed on to ``input_to_adj``.
    """
    nlp = spacy.load("en_core_web_lg")
    nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)
    nlp.add_pipe("custom_sentencizer", before="parser")

    out = []
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(in_file, "rt", encoding="utf-8") as f:
        sent_id = 0
        for line in f:
            record = line.strip()
            if not record:
                # A blank line (e.g. at the end of the file) is not a sentence.
                continue

            print("Reading line[{0}]......".format(sent_id))
            a_line = json.loads(record)
            doc = nlp(a_line['sentText'])
            if prune < 0:
                adj_fw = get_full_dep_tree(doc)
            else:
                adj_fw = input_to_adj(a_line, doc, prune, method)

            out.append({'id': sent_id, 'adj_fw': adj_fw})
            sent_id += 1

    with open(out_file, "wb") as f:
        pickle.dump(out, f)

def get_full_dep_tree(doc):
    """Return the unpruned dependency adjacency matrix of ``doc``.

    The result is a square nested list of ints sized ``len(doc)``. Row ``h``
    holds a 1 in column ``h`` (self loop) and in the column of every child
    of token ``h``; all other cells are 0.
    """
    size = len(doc)
    matrix = [[0 for _ in range(size)] for _ in range(size)]

    for head in doc:
        row = matrix[head.i]
        row[head.i] = 1  # self loop
        for dependent in head.children:
            row[dependent.i] = 1

    return matrix

if __name__ == '__main__':
    # Timestamp taken before processing starts; nothing below reads it.
    time_start = time.time()

    # Input sentences and the two output locations, all under the dataset root.
    dataset = "../ACE05"
    in_file = "{}/data/train.json".format(dataset)
    dep_file = "{}/adj/dep/train.adj".format(dataset)
    gp_dep_file = "{}/adj/gp-dep/train.adj".format(dataset)

    # Full dependency graph (no pruning).
    get_dep_adj_file(in_file, dep_file, prune=-1)

    # Pruned graph: entity-based pruning with prune=1.
    get_dep_adj_file(in_file, gp_dep_file, prune=1, method=2)
