class KnowledgeGraph:
    """Adjacency-list graph built from ``[head, relation, tail, weight]`` edges.

    Edges are stored as directed ``head -> tail`` links (optionally mirrored),
    with a per-edge weight looked up by a ``src___dst`` key.
    """

    def __init__(self, edges, bidir=False):
        """Build the graph from an iterable of edges.

        Args:
            edges: iterable of sequences laid out as
                ``[head, relation, tail, weight]``. The weight is taken from
                the last field and converted with ``float()``.
            bidir: when true, every accepted edge is also stored in the
                reverse direction with the same weight.

        Edges whose head or tail is rejected by ``eng_word`` are skipped.
        Prints the number of distinct relations once loading is done.
        """
        # node -> list of distinct neighbor nodes, in insertion order
        self.data = {}
        # get_name(src, dst) -> float weight; the last edge seen wins
        self.weights = {}
        # relation labels of every accepted edge
        self.relations = set()
        for item in edges:
            head, relation, tail = item[0], item[1], item[2]
            if not self.eng_word(head) or not self.eng_word(tail):
                continue
            # NOTE: sanity check only; asserts are stripped under ``python -O``.
            assert '/' not in head and '/' not in tail
            weight = float(item[-1])

            self.add(head, relation, tail)
            self.relations.add(relation)
            # Key the weight by the same head/tail pair the edge was added with.
            self.weights[self.get_name(head, tail)] = weight
            if bidir:
                self.add(tail, relation, head)
                self.weights[self.get_name(tail, head)] = weight

        print(f"relation nums:{len(self.relations)}")

    def filter_points(self, points):
        """Return the items of ``points`` that have outgoing edges, in order."""
        return [pt for pt in points if pt in self.data]

    def check(self, point):
        """Return True when ``point`` has at least one outgoing edge stored."""
        return point in self.data

    def get_name(self, src, dst):
        """Return the key used in ``self.weights`` for the edge src -> dst."""
        return src + "___" + dst

    def get_weight(self, src, dst):
        """Return the weight stored for src -> dst, or None when absent."""
        return self.weights.get(self.get_name(src, dst))

    def eng_word(self, word):
        """Return True when ``word`` contains no underscore.

        Presumably this rejects multi-word phrases joined with ``_``;
        confirm against the data source.
        """
        return '_' not in word

    def get_avg_deg(self):
        """Return the mean out-degree over stored nodes (0.0 for an empty graph)."""
        if not self.data:
            # Avoid ZeroDivisionError when no edge was accepted.
            return 0.0
        total = sum(len(neighbors) for neighbors in self.data.values())
        return total / len(self.data)

    def show_degs(self):
        """Print the last three adjacency entries, then every node's out-degree
        in ascending order of degree."""
        data = list(self.data.items())
        print(data[-3:])
        data.sort(key=lambda entry: len(entry[1]))
        for node, neighbors in data:
            print(f'{node}:{len(neighbors)}')

    def get_node_num(self):
        """Return the number of nodes that have at least one outgoing edge."""
        return len(self.data)

    def add(self, src, relation, dst):
        """Record the directed edge src -> dst, ignoring duplicates.

        ``relation`` is accepted but currently not stored in the adjacency
        list (folding it into the neighbor name is temporarily disabled).
        """
        neighbors = self.data.setdefault(src, [])
        if dst not in neighbors:
            neighbors.append(dst)

    def get_neighbors(self, pt):
        """Return the neighbor list of ``pt`` (an empty list when unknown)."""
        if pt not in self.data:
            return []
        return self.data[pt]

    def get_hops_set(self, srcs, hop):
        """Return every node reachable from ``srcs`` within ``hop`` steps.

        The result includes the members of ``srcs`` themselves.
        """
        reached = set(srcs)
        frontier = set(srcs)
        step = 0
        while step < hop:
            step += 1
            next_frontier = set()
            for pt in frontier:
                next_frontier.update(
                    n for n in self.get_neighbors(pt) if n not in reached
                )
            if not next_frontier:
                # Nothing new was discovered; further hops cannot add nodes.
                break
            frontier = next_frontier
            reached |= next_frontier
        return reached

    def get_intersect(self, srcs, dsts, hop=2):
        """Return nodes within ``hop`` steps of both ``srcs`` and ``dsts``."""
        src_neis = self.get_hops_set(srcs, hop)
        dst_neis = self.get_hops_set(dsts, hop)
        return src_neis & dst_neis

    def find_neigh_in_set(self, src, points):
        """Return the set of items in ``points`` that are neighbors of ``src``.

        Always returns a set; it is empty when ``src`` has no stored edges.
        """
        if src not in self.data:
            return set()
        # Build the lookup set once instead of scanning the list per point.
        neighbors = set(self.data[src])
        return {pt for pt in points if pt in neighbors}

    def find_paths(self, srcs, dsts):
        """Find connector nodes linking ``srcs`` and ``dsts``.

        For each node ``w`` within one hop of ``srcs`` (``srcs`` included),
        collect ``x`` = neighbors of ``w`` found in ``srcs`` and ``y`` =
        neighbors of ``w`` found in ``dsts``. Every ``w`` with both sets
        non-empty yields ``[x, w, y]``. The order of the returned list follows
        set iteration order and is therefore not guaranteed.
        """
        res = []
        for w in self.get_hops_set(srcs, 1):
            x = self.find_neigh_in_set(w, srcs)
            y = self.find_neigh_in_set(w, dsts)
            if x and y:
                res.append([x, w, y])
        return res

    def show_paths(self, srcs, dsts):
        """Print each entry returned by ``find_paths(srcs, dsts)``."""
        for path in self.find_paths(srcs, dsts):
            print(path)

    def get_dis(self, dst, srcs, max_hop=3):
        """Return the hop count from ``dst`` to the nearest member of ``srcs``.

        The search expands outgoing edges breadth-first starting at ``dst``.
        Returns 0 when ``dst`` is itself in ``srcs``. When no member of
        ``srcs`` is reached within ``max_hop`` steps, the number of steps
        performed (``max_hop`` for positive integers) is returned, so that
        value is indistinguishable from a hit at exactly ``max_hop``.
        """
        if dst in srcs:
            return 0
        visited = {dst}
        frontier = [dst]
        step = 0
        while step < max_hop:
            step += 1
            next_frontier = []
            for pt in frontier:
                for n in self.get_neighbors(pt):
                    if n in srcs:
                        return step
                    if n in visited:
                        continue
                    visited.add(n)
                    next_frontier.append(n)
            frontier = next_frontier
        return step


def get_conceptnet():
    """Load the cleaned ConceptNet dump and build a KnowledgeGraph from it.

    The file is read from
    ``<parent of this file's directory>/ConceptNet/conceptnet_cleaned_final.txt``
    as UTF-8, one edge per line with fields separated by ``|||``. Blank lines
    are skipped. Prints the number of lines read.

    Returns:
        A KnowledgeGraph built from the parsed edges with default settings.
    """
    import os

    def get_father_dir():
        # Absolute path of the directory one level above this file's directory.
        return os.path.abspath(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    with open(f'{get_father_dir()}/ConceptNet/conceptnet_cleaned_final.txt', encoding='utf-8') as f:
        lines = f.readlines()
    print(len(lines))

    edges = []
    for line in lines:
        line = line.strip()
        if not line:
            # A blank (e.g. trailing) line would otherwise become [''] and
            # fail when the graph indexes the relation/tail fields.
            continue
        edges.append(line.split('|||'))

    # Build the graph only after the file handle has been released.
    return KnowledgeGraph(edges)


if __name__ == '__main__':
    # Smoke run: load the ConceptNet graph and report its size statistics.
    graph = get_conceptnet()
    node_num = graph.get_node_num()
    avg_deg = graph.get_avg_deg()
    print(f"node num:{node_num}, avg deg:{avg_deg}")
    # Handy for debugging: graph.show_degs() or graph.get_hops_set(['people'], 1)
