#!/usr/bin/env python
import json
from urllib.parse import unquote
import re
from SubgraphConstructor import Node, Edge, Graph

import requests
from SPARQLWrapper import SPARQLWrapper, JSON
from gensim.models import KeyedVectors


def sparql_query(sparql_endpoint, sparql_query_content):
    """Execute a SPARQL query and return the converted JSON response.

    Args:
        sparql_endpoint: a SPARQLWrapper-style endpoint exposing
            ``setQuery``, ``setReturnFormat`` and ``query().convert()``.
        sparql_query_content: the SPARQL query text to run.

    Returns:
        The converted response, or None if any step raised. The error is
        printed rather than propagated.
    """
    try:
        sparql_endpoint.setQuery(sparql_query_content)
        sparql_endpoint.setReturnFormat(JSON)
        raw_response = sparql_endpoint.query()
        decoded = raw_response.convert()
    except Exception as err:
        print('Error in SPARQL !')
        print(err)
        return None
    return decoded


class CandidateGenerator:
    """Propose candidate KG relation paths for parsed query triples.

    Each triple links two nodes (entity or class) through a relation mention.
    Candidates are gathered from 1-hop and 2-hop path patterns queried on
    ``self.sparql_endpoint``. When no graph path connects two entities, or
    both nodes are classes, relations from the paraphrase dictionary are
    ranked by Word Mover's Distance to the mention instead.
    """

    # Predicates whose URI contains one of these substrings are dropped from
    # path candidates (1-hop patterns with a class side only drop _XMLNS_NS).
    _W3_NS = 'http://www.w3.org/'
    _XMLNS_NS = 'http://xmlns.com'
    _RDF_TYPE = 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'
    # Number of best-ranked paraphrase entries kept as 1-hop fallbacks.
    _PARAPHRASE_TOPK = 50
    # Node types handled by the graph-query branches.
    _NODE_TYPES = ('entity', 'class')
    # (direction, WHERE pattern). {a} is the n1 term and {b} the n2 term; each
    # is a <URI> or, for a class-typed node, an unbound variable.
    _ONE_HOP_PATTERNS = (
        (1, '{a} ?p {b} .'),  # n1 -> n2
        (2, '{b} ?p {a} .'),  # n1 <- n2
    )
    # (shape id, WHERE pattern) for 2-hop paths through an intermediate ?x.
    _TWO_HOP_PATTERNS = (
        (3, '{a} ?p1 ?x . ?x ?p2 {b} .'),  # n1 -> x -> n2
        (4, '{a} ?p1 ?x . {b} ?p2 ?x .'),  # n1 -> x <- n2
        (5, '?x ?p1 {a} . {b} ?p2 ?x .'),  # n1 <- x <- n2
        (6, '?x ?p1 {a} . ?x ?p2 {b} .'),  # n1 <- x -> n2
    )

    def __init__(self):
        """Connect to the endpoint and load the relation vocabulary, paraphrases and embeddings."""
        self.sparql_endpoint = SPARQLWrapper("address_of_server")  # DBpedia
        # relation URI -> integer id, read from "uri\tid" lines.
        self.relation2id = {}
        with open('/DATA_PATH/ImRL/data/kge/relation2id.txt', 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.replace('\n', '').split('\t')
                if len(fields) < 2:
                    continue  # blank or malformed line
                self.relation2id[fields[0]] = int(fields[1])
        # Unique "surface form\trelation URI" entries in first-seen order; only
        # relations present in relation2id are kept.
        self.paraphrase = []
        seen = set()  # O(1) duplicate check instead of scanning the list
        with open('/DATA_PATH/ImRL/data/dictionary/Paraphrase.txt', encoding='utf-8') as f:
            for line in f:
                # Turn stand-alone "s" tokens back into the possessive "'s"
                # (presumably the apostrophe was stripped upstream).
                line = (line.replace('\ts ', '\t\'s ')
                        .replace(' s ', ' \'s ')
                        .replace(' s;', ' \'s;')
                        .replace(' s\t', ' \'s\t'))
                fields = line.split('\t')
                if len(fields) < 2:
                    continue  # no paraphrase column
                r = fields[0]
                localname = self.splitCamelCase(r)
                phrase = fields[1].rstrip()
                # The same local name may exist in the DBpedia property and
                # ontology namespaces; keep whichever the KGE vocabulary knows.
                for uri in ("http://dbpedia.org/property/" + r, "http://dbpedia.org/ontology/" + r):
                    if uri not in self.relation2id:
                        continue
                    for surface in (localname, phrase):
                        entry = surface + "\t" + uri
                        if entry not in seen:
                            seen.add(entry)
                            self.paraphrase.append(entry)
        # NOTE(review): raw GloVe text files have no "<vocab> <dim>" header,
        # which load_word2vec_format expects by default; confirm this file was
        # converted, or pass no_header=True (gensim >= 4).
        word2vec_model_path = '/DATA_PATH/ImRL/data/glove.6B.50d.txt'
        self.word2vec_model = KeyedVectors.load_word2vec_format(word2vec_model_path, binary=False,
                                                                unicode_errors='ignore')

    def splitCamelCase(self, string):
        """Split a camelCase relation name into lower-case words.

        E.g. ``'birthPlace'`` -> ``'birth place'``; a run of digits is also
        separated from the preceding character by a space.
        """
        string = re.sub('(.)([A-Z][a-z]+)', r'\1 \2', string)
        string = re.sub('(.)([0-9]+)', r'\1 \2', string)
        return re.sub('([a-z0-9])([A-Z])', r'\1 \2', string).lower()

    def candidataGenerate(self, triples):
        """Build candidate relation paths for each query triple.

        Args:
            triples: sequences indexed as ``[n1, r_mention, r, n2]``. ``n1`` and
                ``n2`` are nodes exposing ``.type`` and ``.uri`` (a URI string
                or a list of URIs), or None. ``r_mention`` is the relation text
                matched against the paraphrase dictionary. ``r`` is copied to
                the output unchanged.

        Returns:
            One dict per triple: ``{'relation': r, 'candidate_path': {'1hop':
            [(p, direction)], '2hop': [(p1, p2, shape)], '3hop': []}}``.
            ``direction`` is 1 (n1 -> n2) or 2 (n1 <- n2), and ``shape`` 3-6 is
            as listed in ``_TWO_HOP_PATTERNS``. Lists are de-duplicated and
            unordered.
        """
        ans = []
        for triple in triples:
            n1, r, n2 = triple[0], triple[2], triple[3]
            candidate_n1s = self._as_uri_list(n1)
            candidate_n2s = self._as_uri_list(n2)
            print(candidate_n1s, candidate_n2s)
            one_hop, two_hop = [], []
            if n1 is not None and n2 is not None:
                r_mention = triple[1]
                if n1.type == 'class' and n2.type == 'class':
                    one_hop.extend(self._paraphrase_candidates(r_mention))
                elif n1.type in self._NODE_TYPES and n2.type in self._NODE_TYPES:
                    one_hop, two_hop = self._graph_candidates(n1.type, candidate_n1s,
                                                              n2.type, candidate_n2s)
                    # Two entities with no connecting path: fall back to the
                    # paraphrase dictionary.
                    if n1.type == 'entity' and n2.type == 'entity' and not one_hop and not two_hop:
                        one_hop.extend(self._paraphrase_candidates(r_mention))
            ans.append({'relation': r,
                        'candidate_path': {'1hop': list(set(one_hop)),
                                           '2hop': list(set(two_hop)),
                                           # No 3-hop search is implemented.
                                           '3hop': []}})
        return ans

    @staticmethod
    def _as_uri_list(node):
        """Return ``node.uri`` as a list of URIs ([] when the node or URI is missing).

        A single URI string is wrapped so it is not iterated per character.
        """
        if node is None or not node.uri:
            return []
        if isinstance(node.uri, str):
            return [node.uri]
        return list(node.uri)

    @staticmethod
    def _term(uri, variable):
        """Render ``uri`` as a SPARQL IRI term, or return ``variable`` when it is None."""
        # NOTE(review): the URI is spliced in verbatim; a '>' or whitespace
        # inside it would break the query.
        return variable if uri is None else '<' + uri + '>'

    @classmethod
    def _is_filtered(cls, predicate):
        """Return True if ``predicate`` lies in a namespace excluded from path candidates."""
        return cls._W3_NS in predicate or cls._XMLNS_NS in predicate

    def _select(self, projection, where):
        """Run ``SELECT DISTINCT <projection>`` over ``where`` and return the bindings.

        ``sparql_query`` returns None when the request fails. That case is
        treated as "no rows" instead of crashing on the subscript.
        """
        res = sparql_query(self.sparql_endpoint,
                           'Select distinct ' + projection + ' where { ' + where + ' }')
        if res is None:
            return []
        return res['results']['bindings']

    def _path_candidates(self, n1_uri, n2_uri):
        """Query the six 1-/2-hop path shapes between one pair of endpoints.

        ``n1_uri``/``n2_uri`` is a URI, or None for a class-typed node, which
        is then left as an unbound variable. Returns ``(one_hop, two_hop)``
        lists of ``(p, direction)`` and ``(p1, p2, shape)`` tuples.
        """
        # Between two concrete entities, 1-hop predicates in either filtered
        # namespace are dropped. With an unbound class side only xmlns.com is
        # dropped, so e.g. rdf:type can still surface.
        both_bound = n1_uri is not None and n2_uri is not None
        one_hop = []
        for direction, pattern in self._ONE_HOP_PATTERNS:
            where = pattern.format(a=self._term(n1_uri, '?x'), b=self._term(n2_uri, '?x'))
            for row in self._select('?p', where):
                p = row['p']['value']
                dropped = self._is_filtered(p) if both_bound else self._XMLNS_NS in p
                if not dropped:
                    one_hop.append((p, direction))
        two_hop = []
        for shape, pattern in self._TWO_HOP_PATTERNS:
            where = pattern.format(a=self._term(n1_uri, '?y'), b=self._term(n2_uri, '?y'))
            for row in self._select('?p1 ?p2', where):
                p1 = row['p1']['value']
                p2 = row['p2']['value']
                if not (self._is_filtered(p1) or self._is_filtered(p2)):
                    two_hop.append((p1, p2, shape))
        return one_hop, two_hop

    def _graph_candidates(self, n1_type, n1_uris, n2_type, n2_uris):
        """Collect path candidates over every pair of candidate URIs.

        A class-typed side contributes a single unbound term. It is still
        skipped entirely when its URI list is empty.
        """
        # NOTE(review): the class URI itself is never used, so the unbound
        # side is not restricted to instances of that class; confirm intended.
        one_hop, two_hop = [], []
        if not n1_uris or not n2_uris:
            return one_hop, two_hop
        n1_terms = n1_uris if n1_type == 'entity' else [None]
        n2_terms = n2_uris if n2_type == 'entity' else [None]
        for n1_uri in n1_terms:
            for n2_uri in n2_terms:
                hop1, hop2 = self._path_candidates(n1_uri, n2_uri)
                one_hop.extend(hop1)
                two_hop.extend(hop2)
        return one_hop, two_hop

    def _paraphrase_candidates(self, r_mention):
        """Return 1-hop fallbacks ``(relation, 1)`` from the paraphrase dictionary.

        Entries are ranked by Word Mover's Distance between the mention and the
        entry's surface form (smaller is closer), and the best
        ``_PARAPHRASE_TOPK`` are kept. An empty mention also yields rdf:type.
        """
        candidates = []
        if r_mention == '':
            candidates.append((self._RDF_TYPE, 1))
        mention_tokens = r_mention.split()
        scored = []
        for entry in self.paraphrase:
            fields = entry.split('\t')
            pattern, relation = fields[0], fields[1]
            distance = self.word2vec_model.wmdistance(mention_tokens, pattern.split())
            scored.append((distance, relation))
        # Rank on the distance only, ascending; ties keep dictionary order.
        scored.sort(key=lambda item: item[0])
        candidates.extend((relation, 1) for _, relation in scored[:self._PARAPHRASE_TOPK])
        return candidates


if __name__ == '__main__':
    # Manual smoke run: requires the data files under /DATA_PATH and a
    # reachable SPARQL endpoint (see CandidateGenerator.__init__).
    generator = CandidateGenerator()
    company = Node(1, 'company', 'class', 'http://dbpedia.org/ontology/Company')
    suburb = Node(2, 'Broadmeadows', 'entity',
                  'http://dbpedia.org/resource/Broadmeadows,_Victoria')
    query_triples = [[company, 'assembles in', [], suburb]]
    print(generator.candidataGenerate(query_triples))
