#!/usr/bin/env python
import copy
import math
import re
import stanza
from collections import deque


class Node:
    """A vertex of the question subgraph.

    Attributes (in creation order):
        text: surface text of the node.
        start_pos, end_pos: character offsets of the node's span; both start at 0
            and are filled in by the caller.
        id: key under which the node is stored in ``Graph.nodes``.
        type: type label such as ``'class'`` or ``'entity'`` (``''`` for plain words).
        uri: list of URIs, initialised with the single ``uri`` argument.
        fromNodes, toNodes: adjacency lists, initially empty.
    """

    def __init__(self, id, text, _type, uri):
        self.text = text
        # Chained assignment binds start_pos first, then end_pos.
        self.start_pos = self.end_pos = 0
        self.id = id
        self.type = _type
        self.uri = [uri]
        self.fromNodes, self.toNodes = [], []


class Edge:
    """A weighted directed edge ``fro -> to``."""

    def __init__(self, fro, to, weight):
        # Tuple assignment binds left to right: fro, to, weight.
        self.fro, self.to, self.weight = fro, to, weight


class Graph:
    """Directed graph stored as a dict-of-dicts weight matrix.

    ``nodes`` maps node id -> node object. ``distance[i][j]`` holds the weight
    of the edge ``i -> j``; :meth:`BFS` treats every finite off-diagonal entry
    as an edge and ``math.inf`` as "no edge". ``edges`` is kept for
    compatibility; it is not populated by the code shown here.
    """

    class parents_table:
        """Holder for a BFS parent map (node id -> parent id, or ``None``)."""

        def __init__(self):
            self.parents = {}

    def __init__(self):
        self.nodes = {}
        self.edges = []
        self.distance = {}

    def get_path(self, parents, start, end):
        """Return the node objects on the path ``start -> ... -> end``.

        ``parents`` must be the parent map produced by :meth:`BFS` from
        ``start``. The path is rebuilt by following parent links back from
        ``end`` until reaching a node whose parent is ``start``. If ``end`` has
        no recorded parent, the walk fails with ``KeyError``.
        """
        path = [self.nodes[end]]
        key = end
        while parents[key] != start:
            key = parents[key]
            path.append(self.nodes[key])
        path.append(self.nodes[start])
        path.reverse()
        return path

    def BFS(self, start, end):
        """Breadth-first search from ``start`` to ``end``.

        Every finite off-diagonal entry of ``self.distance`` (including
        zero-weight ones) is an edge. Weights are otherwise ignored, so the
        search minimises the number of hops.

        Returns:
            ``(end, parents)``, where ``parents`` maps each node id to the id it
            was first discovered from (``None`` if never reached), or ``False``
            if ``end`` is unreachable.
        """
        parents = self.parents_table().parents
        for node in self.nodes:
            parents[node] = None

        # Adjacency lists: keep only real edges (finite weight, not a self-loop).
        adjacency = {
            src: [dst for dst, weight in row.items() if weight != math.inf and dst != src]
            for src, row in self.distance.items()
        }

        search_queue = deque()
        for node in adjacency[start]:
            if parents[node] is None:
                parents[node] = start
            search_queue.append(node)

        searched = set()
        while search_queue:
            node = search_queue.popleft()
            if node in searched:
                continue
            if node == end:
                return node, parents
            searched.add(node)
            for nxt in adjacency[node]:
                # Compare with None rather than truthiness: 0 is a valid node id,
                # and overwriting a parent of 0 breaks shortest paths (or, when
                # start == 0, can create parent cycles that make get_path spin).
                if parents[nxt] is None:
                    parents[nxt] = node
                search_queue.append(nxt)
        return False


class SubgraphConstructor:
    """Build a question subgraph from a stanza dependency parse.

    The steps are:
    1. Match pre-linked chunks (classes/entities with URIs and character
       offsets) against the dependency tree.
    2. Merge word pairs that fall inside one chunk or are joined by
       ``compound``.
    3. Re-attach ``case`` markers.
    4. Report pairs of linked chunks whose shortest weighted distance is
       exactly 2, together with the words on the connecting path.
    """

    # Character offsets are parsed out of each stanza word's ``misc`` string.
    _START_CHAR = re.compile('start_char=([0-9]+)')
    _END_CHAR = re.compile('end_char=([0-9]+)')

    def __init__(self, stanza_dir='/DATA_PATH/stanza_resources', use_gpu=True):
        """Create the English stanza pipeline.

        Args:
            stanza_dir: directory holding the stanza model resources.
            use_gpu: forwarded to ``stanza.Pipeline``.
        """
        self.stanza = stanza.Pipeline('en', dir=stanza_dir, use_gpu=use_gpu)

    @staticmethod
    def _within(word, node):
        """True if ``word``'s character span lies inside ``node``'s span."""
        return word['start_pos'] >= node['start_pos'] and word['end_pos'] <= node['end_pos']

    @staticmethod
    def _redirect_heads(depends, old_id, new_id):
        """Rewrite every ``head`` equal to ``old_id`` (either side of a triple) to ``new_id``."""
        for governor, _, dependent in depends:
            if governor['head'] == old_id:
                governor['head'] = new_id
            if dependent['head'] == old_id:
                dependent['head'] = new_id

    def subgraphConstruct(self, sentence, nodes):
        """Return the relations linking pairs of pre-linked chunks in ``sentence``.

        Args:
            sentence: question text. Only the first sentence stanza finds is used.
            nodes: linked chunks, each a dict with ``'chunk'``, ``'type'``,
                ``'uri'``, ``'start_pos'`` and ``'end_pos'`` (character offsets
                into ``sentence``).

        Returns:
            A list of ``(head_node, path_text, path_nodes, tail_node)`` tuples.
            ``head_node`` and ``tail_node`` are :class:`Node` objects,
            ``path_nodes`` are the intermediate nodes on the BFS path, and
            ``path_text`` is their text joined by spaces. Each result is also
            printed.

        Id scheme: linked chunks get ids ``1..len(nodes)``. Parsed words get
        their stanza id + 1000, so ``id < 1000`` marks a linked chunk.
        """
        sentence = self.stanza(sentence).sentences[0]

        # --- 1. Linked chunks, ordered by position and numbered from 1. ---
        new_nodes = [
            {'id': cnt, 'chunk': node['chunk'], 'type': node['type'], 'uri': node['uri'], 'pos': 'NOUN',
             'start_pos': node['start_pos'], 'end_pos': node['end_pos']}
            for cnt, node in enumerate(sorted(nodes, key=lambda x: x['start_pos']), start=1)
        ]

        # --- 2. Dependency triples (governor dict, relation, dependent dict). ---
        depends = []
        for d in sentence.dependencies:
            # Skip the artificial ROOT governor. str() keeps the check working
            # whether stanza exposes word ids as strings or as ints.
            if str(d[0].id) == "0" and d[0].text == "ROOT":
                continue
            n1 = d[0]
            n2 = d[2]
            n1_start_pos = int(self._START_CHAR.findall(n1.misc)[0])
            n1_end_pos = int(self._END_CHAR.findall(n1.misc)[0])
            n2_start_pos = int(self._START_CHAR.findall(n2.misc)[0])
            n2_end_pos = int(self._END_CHAR.findall(n2.misc)[0])
            # A leading question word (or "Name"/"List") is normalised to its lemma.
            if str(n1.id) == "1" and (n1.lemma in ("when", "who", "where", "whose", "what")
                                      or n1.text in ("Name", "List")):
                n1.text = n1.lemma
            depends.append((
                {'id': int(n1.id), 'text': str(n1.text), 'type': '', 'uri': '', 'pos': n1.upos,
                 'head': int(n1.head), 'start_pos': n1_start_pos, 'end_pos': n1_end_pos},
                str(d[1]),
                {'id': int(n2.id), 'text': str(n2.text), 'type': '', 'uri': '', 'pos': n2.upos,
                 'head': int(n2.head), 'start_pos': n2_start_pos, 'end_pos': n2_end_pos},
            ))

        # Shift word ids/heads by 1000 so they cannot collide with chunk ids.
        for governor, _, dependent in depends:
            for word in (governor, dependent):
                word['id'] += 1000
                word['head'] += 1000

        # --- 3. Merge a dependent into its governor when both lie inside one
        # linked chunk, or when they are joined by a 'compound' relation. ---
        delete_id = []
        for i in range(len(depends)):
            n1 = depends[i][0]
            n2 = depends[i][2]
            flag = False
            for node in new_nodes:
                if self._within(n1, node) and self._within(n2, node):
                    flag = True
                    start_pos = node['start_pos']
                    end_pos = node['end_pos']
                    name_str = node['chunk']
            if depends[i][1] == 'compound':
                flag = True
                start_pos = min(n1['start_pos'], n2['start_pos'])
                end_pos = max(n1['end_pos'], n2['end_pos'])
                if start_pos == n1['start_pos']:
                    name_str = n1['text'] + ' ' + n2['text']
                else:
                    name_str = n2['text'] + ' ' + n1['text']
            if not flag:
                continue

            delete_id.append(i)
            merged_id = n1['id']
            absorbed_id = n2['id']
            # Children of the absorbed word now hang off the merged node.
            self._redirect_heads(depends, absorbed_id, merged_id)
            # Every occurrence of the merged node gets the new text/span, and
            # the absorbed word (as governor) becomes the merged node.
            for governor, _, dependent in depends:
                if governor['id'] == merged_id:
                    governor['text'] = name_str
                    governor['start_pos'] = start_pos
                    governor['end_pos'] = end_pos
                if dependent['id'] == merged_id:
                    dependent['text'] = name_str
                    dependent['start_pos'] = start_pos
                    dependent['end_pos'] = end_pos
                if governor['id'] == absorbed_id:
                    # Previously written as a no-op '==' comparison, which
                    # left the absorbed word detached from the merged node.
                    governor['id'] = merged_id
                    governor['text'] = name_str
                    governor['start_pos'] = start_pos
                    governor['end_pos'] = end_pos
        # Delete by recorded position; list.remove() would match by value and
        # could drop an equal-valued triple somewhere else in the list.
        deleted = set(delete_id)
        depends = [dep for k, dep in enumerate(depends) if k not in deleted]

        # --- 4. Re-attach 'case' markers between the governor and its parent. ---
        delete_id = []
        # range() is fixed up front: triples appended below are not revisited.
        for i in range(len(depends)):
            if depends[i][1] != 'case':
                continue
            delete_id.append(i)
            governor = depends[i][0]
            marker = depends[i][2]
            _pre = None
            _r = None
            for j in range(len(depends)):
                if depends[j][2]['id'] == governor['id']:
                    delete_id.append(j)
                    _pre = depends[j][0]
                    _r = depends[j][1]
                    break
            for gov, _, dep in depends:
                if gov['id'] == governor['id']:
                    gov['head'] = marker['id']
                if dep['id'] == governor['id']:
                    dep['head'] = marker['id']
            if _pre is None:
                marker['head'] = 1000
            else:
                marker['head'] = _pre['id']
                depends.append((_pre, _r, marker))
            depends.append((marker, 'add-case', governor))
        deleted = set(delete_id)
        depends = [dep for k, dep in enumerate(depends) if k not in deleted]

        # --- 5. Words inside a linked chunk take the chunk's id; heads that
        # pointed at the old word id are redirected as well. ---
        for node in new_nodes:
            for i in range(len(depends)):
                n1 = depends[i][0]
                n2 = depends[i][2]
                if self._within(n2, node):
                    old_id = n2['id']
                    n2['id'] = node['id']
                    self._redirect_heads(depends, old_id, node['id'])
                if self._within(n1, node):
                    old_id = n1['id']
                    n1['id'] = node['id']
                    self._redirect_heads(depends, old_id, node['id'])

        # Copy the chunk's type, URI and span onto every word inside it.
        for governor, _, dependent in depends:
            for node in new_nodes:
                for word in (governor, dependent):
                    if self._within(word, node):
                        word['type'] = node['type']
                        word['uri'] = node['uri']
                        word['start_pos'] = node['start_pos']
                        word['end_pos'] = node['end_pos']

        # --- 6. Build the graph: one Node per distinct id. ---
        graph = Graph()
        for governor, _, dependent in depends:
            for word in (governor, dependent):
                if word['id'] not in graph.nodes:
                    graph_node = Node(word['id'], word['text'], word['type'], word['uri'])
                    graph_node.start_pos = word['start_pos']
                    graph_node.end_pos = word['end_pos']
                    graph.nodes[word['id']] = graph_node
        # Weight matrix: 0 on the diagonal, inf until an edge is added. Built
        # once here; it used to be rebuilt on every dependency iteration.
        for i in graph.nodes:
            graph.distance[i] = {j: 0 if i == j else math.inf for j in graph.nodes}

        for governor, rel, dependent in depends:
            # Weight = number of endpoints that are linked chunks (id < 1000).
            weight = int(governor['id'] < 1000) + int(dependent['id'] < 1000)
            if 'acl' in rel:
                if dependent['pos'] == 'VERB':
                    graph.distance[governor['id']][dependent['id']] = weight
                else:
                    graph.distance[dependent['id']][governor['id']] = weight
            elif 'poss' in rel:
                graph.distance[dependent['id']][governor['id']] = weight
            elif 'ob' in rel or 'mod' in rel:
                graph.distance[governor['id']][dependent['id']] = weight
            else:
                graph.distance[governor['id']][dependent['id']] = weight
                graph.distance[dependent['id']][governor['id']] = weight

        # All-pairs shortest weighted distances (Floyd-Warshall).
        distance = copy.deepcopy(graph.distance)
        for k in graph.nodes:
            for i in graph.nodes:
                for j in graph.nodes:
                    if distance[i][j] > distance[i][k] + distance[k][j]:
                        distance[i][j] = distance[i][k] + distance[k][j]

        # --- 7. Chunk pairs at weighted distance exactly 2, plus the words between them. ---
        edges = [(i, j) for i in graph.nodes for j in graph.nodes
                 if i < 1000 and j < 1000 and distance[i][j] == 2]
        temp_result = []
        for head_id, tail_id in edges:
            found = graph.BFS(head_id, tail_id)
            if not found:
                # Defensive: BFS returns False for unreachable pairs, which a
                # finite distance should rule out.
                continue
            _, parents = found
            path_nodes = graph.get_path(parents, head_id, tail_id)
            path_nodes.remove(graph.nodes[head_id])
            path_nodes.remove(graph.nodes[tail_id])
            path = " ".join(no.text for no in path_nodes)
            temp_result.append((head_id, path, path_nodes, tail_id))

        # Keep one result per unordered pair, and skip class-class pairs.
        result = []
        seen_pairs = set()
        for head_id, path, path_nodes, tail_id in temp_result:
            n1 = graph.nodes[head_id]
            n2 = graph.nodes[tail_id]
            if (n1, n2) in seen_pairs or (n2, n1) in seen_pairs or n1 == n2:
                continue
            if n1.type == 'class' and n2.type == 'class':
                continue
            seen_pairs.add((n1, n2))
            result.append((n1, path, path_nodes, n2))
        for n1, path, _, n2 in result:
            print(n1.text + " ---> '" + path + "' ---> " + n2.text)
        return result


if __name__ == '__main__':
    # Smoke run: parse one sample question with two pre-linked chunks
    # (a class and an entity) and print the constructed relations.
    constructor = SubgraphConstructor()
    question = "Which company which assembles its cars in Broadmeadows , Victoria ?"
    linked_chunks = [
        {'chunk': 'company', 'type': 'class', 'uri': 'http://dbpedia.org/ontology/Company',
         'start_pos': 6, 'end_pos': 13},
        {'chunk': 'Broadmeadows', 'type': 'entity',
         'uri': 'http://dbpedia.org/resource/Broadmeadows,_Victoria',
         'start_pos': 42, 'end_pos': 54},
    ]
    relations = constructor.subgraphConstruct(question, linked_chunks)
    print(relations)
