import json
import operator
from functools import reduce

class KnowledgeGraph:
    """Undirected weighted graph stored as adjacency lists keyed by node name.

    Instance attributes:
      data    -- dict mapping a node to the list of its neighbours, in the
                 order the edges were first seen.
      weights -- dict mapping the key built by ``get_name(src, dst)`` to the
                 edge weight as a float; every edge is stored in both
                 directions.
    """

    def __init__(self, triples):
        """Build the graph from an iterable of edge records.

        Each record is an indexable sequence where element ``0`` is one
        endpoint, element ``-2`` is the other endpoint and element ``-1`` is
        the weight (anything accepted by ``float``). Records with an endpoint
        rejected by ``eng_word`` are skipped.
        """
        print("contruct graph")
        self.data = {}
        self.weights = {}
        for item in triples:
            src, dst = item[0], item[-2]
            if not (self.eng_word(src) and self.eng_word(dst)):
                continue
            # Edges are undirected: register both directions.
            self.add(src, dst)
            self.add(dst, src)
            weight = float(item[-1])
            self.weights[self.get_name(src, dst)] = weight
            self.weights[self.get_name(dst, src)] = weight

    def filter_points(self, points):
        """Return the set of ``points`` that are nodes of the graph."""
        return {pt for pt in points if pt in self.data}

    def get_name(self, src, dst):
        """Return the key under which the ``src -> dst`` weight is stored."""
        return src + "___" + dst

    def get_weight(self, src, dst):
        """Return the weight of edge ``src -> dst``, or ``None`` if absent."""
        return self.weights.get(self.get_name(src, dst))

    def eng_word(self, word):
        """Return ``True`` when ``word`` contains no underscore."""
        return '_' not in word

    def get_avg_deg(self):
        """Return the mean number of neighbours per node.

        An empty graph has no nodes to average over; ``0.0`` is returned in
        that case instead of raising ``ZeroDivisionError``.
        """
        if not self.data:
            return 0.0
        total = sum(len(neighbors) for neighbors in self.data.values())
        return total / len(self.data)

    def get_node_num(self):
        """Return the number of nodes in the graph."""
        return len(self.data)

    def add(self, src, dst):
        """Record ``dst`` as a neighbour of ``src`` (duplicates are ignored)."""
        neighbors = self.data.setdefault(src, [])
        if dst not in neighbors:
            neighbors.append(dst)

    def get_neighbors(self, pt):
        """Return the neighbour list of ``pt``; empty list for unknown nodes."""
        return self.data.get(pt, [])

    def get_hops_set(self, srcs, hop):
        """Return all nodes within ``hop`` hops of any node in ``srcs``.

        The result always includes the members of ``srcs`` themselves, even
        those that are not nodes of the graph.
        """
        reached = set(srcs)
        frontier = set(srcs)
        step = 0
        while step < hop:
            step += 1
            frontier = {
                n
                for pt in frontier
                for n in self.get_neighbors(pt)
                if n not in reached
            }
            if not frontier:
                # Nothing new was discovered, so further hops cannot add nodes.
                break
            reached |= frontier
        return reached

    def find_neigh_in_set(self, src, points):
        """Return the set of ``points`` that are direct neighbours of ``src``.

        Always returns a set; it is empty when ``src`` is not in the graph.
        """
        if src not in self.data:
            return set()
        # Build the lookup once so each membership test is O(1).
        neighbors = set(self.data[src])
        return {pt for pt in points if pt in neighbors}

    def find_paths(self, srcs, dsts):
        """Find nodes that link ``srcs`` to ``dsts`` through one intermediate.

        Candidates are the sources plus their 1-hop neighbours. For every
        candidate ``w`` adjacent to at least one source and at least one
        destination, the result holds ``[sources_adjacent, w, dsts_adjacent]``
        where the first and last entries are sets.
        """
        candidates = self.get_hops_set(srcs, 1)
        res = []
        for w in candidates:
            x = self.find_neigh_in_set(w, srcs)
            y = self.find_neigh_in_set(w, dsts)
            if x and y:
                res.append([x, w, y])
        return res

    def show_paths(self, srcs, dsts):
        """Print every entry returned by ``find_paths(srcs, dsts)``."""
        for path in self.find_paths(srcs, dsts):
            print(path)

    def get_dis(self, dst, srcs, max_hop=3):
        """Return the hop distance from ``dst`` to the nearest node in ``srcs``.

        Returns ``0`` when ``dst`` itself is in ``srcs``. The search expands
        breadth-first from ``dst`` for at most ``max_hop`` hops. When no
        source is reached, the final hop counter is returned (``max_hop`` for
        a positive integer ``max_hop``), so that value does not distinguish
        "found at exactly max_hop" from "not found".
        """
        step = 0
        if dst in srcs:
            return step
        vis = {dst}
        points = [dst]
        while step < max_hop:
            step += 1
            temp_points = []
            for pt in points:
                for n in self.get_neighbors(pt):
                    if n in srcs:
                        return step
                    if n in vis:
                        continue
                    vis.add(n)
                    temp_points.append(n)
            points = temp_points
        return step


def get_conceptnet(path='../ConceptNet/conceptnet_cleaned_final.txt'):
    """Load a ``|||``-separated edge file and return it as a KnowledgeGraph.

    Args:
        path: UTF-8 text file with one edge record per line, fields separated
            by ``|||``. Defaults to the location used by the original script.

    Returns:
        A ``KnowledgeGraph`` built from the parsed records.

    Blank lines are skipped: they would otherwise produce a one-field record
    that the graph constructor cannot index.
    """
    triples = []
    with open(path, encoding='utf-8') as f:
        # Stream the file instead of materialising every line up front.
        for line in f:
            line = line.strip()
            if not line:
                continue
            triples.append(line.split('|||'))
    print('graph loaded')
    return KnowledgeGraph(triples)


def extract_outline(text):
    """Return the outline entries found in whitespace-tokenised ``text``.

    Scanning starts after each ``<|endoftarget|>`` token. Tokens are joined
    with single spaces into an entry every time ``<|sepofoutline|>`` is met;
    the first ``<|beginofbedding|>`` seen during a scan ends the function
    immediately. Tokens after the last separator are discarded. If no
    ``<|beginofbedding|>`` follows a marker, the scan runs to the end of the
    text and then resumes looking for a later ``<|endoftarget|>`` marker.
    """
    tokens = text.split()
    collected = []
    for start, token in enumerate(tokens):
        if token != '<|endoftarget|>':
            continue
        pending = []
        for follower in tokens[start + 1:]:
            if follower == '<|beginofbedding|>':
                return collected
            if follower == '<|sepofoutline|>':
                collected.append(' '.join(pending))
                pending = []
            else:
                pending.append(follower)
    return collected


def outline_graph_check(data):
    """Write the outline entries that are not graph nodes to a JSON file.

    Args:
        data: iterable of records; each record's ``'generated'`` text is
            parsed with ``extract_outline``.

    The distinct outline entries are checked against the graph returned by
    ``get_conceptnet()``. Entries missing from the graph are written, sorted,
    to ``outline_check_result.json``. Both counts are printed.
    """
    # Collect distinct entries directly; this also copes with empty ``data``,
    # where reducing with ``operator.add`` would raise TypeError.
    outlines = {
        entry
        for record in data
        for entry in extract_outline(record['generated'])
    }
    graph = get_conceptnet()
    filter_outlines = graph.filter_points(outlines)
    print('outlines ori len = ', len(outlines))
    print('filte len = ', len(filter_outlines))
    # Entries that are not in the graph; sorted so the output file is stable
    # from run to run.
    result = sorted(outlines - set(filter_outlines))
    with open("outline_check_result.json", 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=4, separators=(',', ':'))

def outline_intersect_check(generate_data, label):
    """Count outline entries that do / do not appear in the labelled nodes.

    Args:
        generate_data: iterable of records whose ``'generated'`` text is
            parsed with ``extract_outline``.
        label: iterable of records providing ``'intersect_nodes'``; it is
            paired positionally with ``generate_data`` (``zip`` stops at the
            shorter of the two).

    Prints the total number of outline entries, then ``t`` (entries found in
    the paired ``intersect_nodes``) and ``p`` (entries not found).
    """
    # Use the argument, not a module-level ``data`` global, so the function
    # also works when imported from another module.
    outlines = [extract_outline(i['generated']) for i in generate_data]
    # Summing lengths avoids building one big list and works for empty input.
    print('len = ', sum(len(outline) for outline in outlines))
    intersects = [i['intersect_nodes'] for i in label]
    t = 0
    p = 0
    for outline, intersect in zip(outlines, intersects):
        for entry in outline:
            if entry in intersect:
                t += 1
            else:
                p += 1
    print('t = {}, p = {}'.format(t, p))


if __name__ == "__main__":
    # Script entry point: load the generated records and the labelled
    # records, then print how many outline entries hit the labelled nodes.
    generated_path = "../result/gpt2_baseline_explicit_outline_onecard_truetarget.json"
    label_path = "../data/test_add_node.json"

    with open(generated_path, 'r', encoding='utf-8') as generated_file:
        data = json.load(generated_file)

    with open(label_path, 'r', encoding='utf-8') as label_file:
        label = json.load(label_file)

    outline_intersect_check(data, label)
