import json
from collections import defaultdict
import re
from tqdm import tqdm
from nltk import sent_tokenize, word_tokenize
import numpy as np
from functools import partial
from multiprocessing import Pool
from os import cpu_count


class KnowledgeGraph:
    """Undirected, weighted graph stored as adjacency lists.

    ``data`` maps each node to the list of its neighbours (insertion order,
    no duplicates). ``weights`` maps the ``get_name(src, dst)`` key of an
    edge to its float weight.
    """

    def __init__(self, triples):
        """Build the graph from an iterable of records.

        For each record, the first field is one endpoint, the second-to-last
        field is the other endpoint, and the last field is parsed with
        ``float`` as the edge weight. Every edge is stored in both
        directions; a repeated edge keeps the weight of its last occurrence.

        Records are skipped when either endpoint is rejected by ``eng_word``
        or when they have fewer than three fields (for example the ``['']``
        produced by a blank input line), since such a record cannot hold two
        endpoints and a weight.

        Raises:
            ValueError: if the last field of a kept record is not a number.
        """
        print("contruct graph")
        self.data = {}
        self.weights = {}
        for item in triples:
            if len(item) < 3:
                continue
            head, tail = item[0], item[-2]
            if not self.eng_word(head) or not self.eng_word(tail):
                continue
            weight = float(item[-1])
            self.add(head, tail)
            self.add(tail, head)
            self.weights[self.get_name(head, tail)] = weight
            self.weights[self.get_name(tail, head)] = weight

    def filter_points(self, points):
        """Return the set of ``points`` that are nodes of the graph."""
        return {pt for pt in points if pt in self.data}

    def get_name(self, src, dst):
        """Return the key under which the directed edge src -> dst is stored."""
        return src + "___" + dst

    def get_weight(self, src, dst):
        """Return the weight of edge src -> dst, or None if it is not stored."""
        return self.weights.get(self.get_name(src, dst))

    def eng_word(self, word):
        """Return True when ``word`` contains no underscore.

        This is the only check performed; underscore-containing words are
        excluded from the graph by ``__init__``.
        """
        return '_' not in word

    def get_avg_deg(self):
        """Return the mean number of neighbours per node (0.0 if no nodes)."""
        if not self.data:
            # Avoid ZeroDivisionError on an empty graph.
            return 0.0
        return sum(len(neighbors) for neighbors in self.data.values()) / len(self.data)

    def get_node_num(self):
        """Return the number of nodes that have at least one outgoing edge."""
        return len(self.data)

    def add(self, src, dst):
        """Record the directed edge src -> dst, ignoring duplicates."""
        neighbors = self.data.setdefault(src, [])
        if dst not in neighbors:
            neighbors.append(dst)

    def get_neighbors(self, pt):
        """Return the neighbour list of ``pt`` (empty list if unknown)."""
        return self.data.get(pt, [])

    def get_hops_set(self, srcs, hop):
        """Return every node reachable from ``srcs`` in at most ``hop`` steps.

        The result always contains ``srcs`` themselves, including any that
        are not nodes of the graph.
        """
        reached = set(srcs)
        frontier = set(srcs)
        step = 0
        while step < hop:
            step += 1
            frontier = {
                n
                for pt in frontier
                for n in self.get_neighbors(pt)
                if n not in reached
            }
            if not frontier:
                # Nothing new was discovered, so further hops cannot add nodes.
                break
            reached |= frontier
        return reached

    def find_neigh_in_set(self, src, points):
        """Return the set of ``points`` that are direct neighbours of ``src``.

        Always returns a set; it is empty when ``src`` is not in the graph.
        """
        if src not in self.data:
            return set()
        # Build the lookup once instead of scanning the list per candidate.
        neighbors = set(self.data[src])
        return {pt for pt in points if pt in neighbors}

    def find_paths(self, srcs, dsts):
        """Find nodes that connect ``srcs`` to ``dsts`` through one hop each.

        Every node ``w`` within one hop of ``srcs`` (``srcs`` included) is
        examined. When ``w`` has at least one neighbour in ``srcs`` and at
        least one in ``dsts``, the entry
        ``[neighbours_in_srcs, w, neighbours_in_dsts]`` is added.

        Returns:
            A list of such three-element lists; the two neighbour
            collections are sets.
        """
        res = []
        for w in self.get_hops_set(srcs, 1):
            x = self.find_neigh_in_set(w, srcs)
            y = self.find_neigh_in_set(w, dsts)
            if x and y:
                res.append([x, w, y])
        return res

    def show_paths(self, srcs, dsts):
        """Print each entry returned by ``find_paths(srcs, dsts)``."""
        for path in self.find_paths(srcs, dsts):
            print(path)

    def get_dis(self, dst, srcs, max_hop=3):
        """Return the hop distance from ``dst`` to the nearest node in ``srcs``.

        Performs a breadth-first search from ``dst`` limited to ``max_hop``
        levels. Returns 0 when ``dst`` is itself in ``srcs``.

        NOTE: when no node of ``srcs`` is found, the final hop counter is
        returned, which equals ``max_hop`` for a non-negative integer
        ``max_hop``. "Not reachable" is therefore indistinguishable from
        "exactly ``max_hop`` hops away".
        """
        if dst in srcs:
            return 0
        visited = {dst}
        frontier = [dst]
        step = 0
        while step < max_hop:
            step += 1
            next_frontier = []
            for pt in frontier:
                for n in self.get_neighbors(pt):
                    if n in srcs:
                        return step
                    if n not in visited:
                        visited.add(n)
                        next_frontier.append(n)
            frontier = next_frontier
        return step


def get_conceptnet():
    """Load the ConceptNet triple file and build a KnowledgeGraph from it.

    Reads ``../ConceptNet/conceptnet_cleaned_final.txt`` (relative to the
    current working directory) as UTF-8. Each non-blank line is stripped and
    split on ``'|||'`` into one record. Blank lines are skipped because they
    carry no fields for the graph constructor to index.

    Returns:
        The KnowledgeGraph built from all records.
    """
    triples = []
    with open('../ConceptNet/conceptnet_cleaned_final.txt', encoding='utf-8') as f:
        # Stream the file line by line instead of holding every raw line
        # in memory alongside the parsed records.
        for line in f:
            record = line.strip()
            if not record:
                continue
            triples.append(record.split('|||'))
    print('graph loaded')
    return KnowledgeGraph(triples)


# Shared graph instance, built once when this module is loaded.
# one_scene_intersection reads this module-level name directly instead of
# receiving the graph as an argument.
graph = get_conceptnet()


def one_scene_intersection(scene):
    """Annotate one scene with the graph nodes shared by its context and target.

    The two-hop neighbourhoods of ``scene['context_kws']`` and
    ``scene['target_kws']`` are looked up in the module-level ``graph`` and
    intersected. The scene dict is modified in place:

    * ``scene['intersect_nodes']`` - list of the shared nodes;
    * ``scene['outline_mask']`` - one 0/1 flag per keyword of
      ``bedding_kws + ending_kws``, 1 when that keyword is a shared node.

    When either keyword list is empty, no lookup is done: the node list is
    empty, the mask is all zeros and every counter in the result is 0.

    Returns:
        dict with the (mutated) ``scene`` plus the counters ``ending_cnt``,
        ``bedding_cnt``, ``bedding_total``, ``ending_total`` and
        ``intersect_total``.
    """
    stats = {'scene': scene, 'ending_cnt': 0, 'bedding_cnt': 0, 'bedding_total': 0,
             'ending_total': 0, 'intersect_total': 0}

    # Guard clause: nothing to intersect without both keyword lists.
    if not (scene['target_kws'] and scene['context_kws']):
        scene['intersect_nodes'] = []
        scene['outline_mask'] = [0] * len(scene['bedding_kws'] + scene['ending_kws'])
        return stats

    context_seed = set(scene['context_kws'])
    target_seed = set(scene['target_kws'])
    bedding = scene['bedding_kws']
    ending = scene['ending_kws']
    outline = bedding + ending

    # Two-hop neighbourhoods of each keyword group.
    context_nodes = graph.get_hops_set(context_seed, hop=2)
    target_nodes = graph.get_hops_set(target_seed, hop=2)
    intersect_nodes = context_nodes & target_nodes

    # Outline keywords that are also shared nodes.
    kept = set(outline) & intersect_nodes

    scene['intersect_nodes'] = list(intersect_nodes)
    scene['outline_mask'] = [1 if kw in kept else 0 for kw in outline]

    stats['intersect_total'] = len(intersect_nodes)
    stats['ending_total'] = len(ending)
    stats['bedding_total'] = len(bedding)
    stats['ending_cnt'] = sum(1 for kw in ending if kw in intersect_nodes)
    stats['bedding_cnt'] = sum(1 for kw in bedding if kw in intersect_nodes)
    return stats


def test(data):
    """Placeholder: ignore ``data`` and return the constant 1."""
    return 1


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


def add_node(graph, in_file_name, out_file_name):
    """Annotate every scene of a JSON file and write the result as JSON.

    Loads the scenes from ``in_file_name``, runs ``one_scene_intersection``
    on each one in a process pool, prints aggregate statistics and dumps the
    annotated scenes to ``out_file_name``.

    Args:
        graph: Not used by this function; the worker function reads the
            module-level ``graph`` instead. Kept so existing callers work.
        in_file_name: Path of the input JSON file holding the scenes.
        out_file_name: Path of the JSON file to write.
    """
    with open(in_file_name, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print('file = ', in_file_name)
    print('size = ', len(data))

    result = []
    bedding_cnt = ending_cnt = bedding_total = ending_total = intersect_total = 0
    with Pool(cpu_count()) as pool:
        # imap yields results in input order, so `result` lines up with `data`.
        outcomes = pool.imap(one_scene_intersection, data)
        # Passing the total lets the progress bar show completion.
        for outcome in tqdm(outcomes, total=len(data)):
            result.append(outcome['scene'])
            bedding_cnt += outcome['bedding_cnt']
            ending_cnt += outcome['ending_cnt']
            bedding_total += outcome['bedding_total']
            ending_total += outcome['ending_total']
            intersect_total += outcome['intersect_total']

    # A zero denominator must not raise here: the results have not been
    # written yet and would be lost.
    print('bedding_cnt = {}, total = {}, ratio = {}'.format(
        bedding_cnt, bedding_total, _safe_ratio(bedding_cnt, bedding_total)))
    print('ending_cnt = {}, total = {}, ratio = {}'.format(
        ending_cnt, ending_total, _safe_ratio(ending_cnt, ending_total)))
    print('intersect_total = {}, mean = {}'.format(
        intersect_total, _safe_ratio(intersect_total, len(data))))

    with open(out_file_name, 'w', encoding='utf-8') as f:
        json.dump(result, f, indent=2, separators=[',', ':'])


# Run the pipeline only when executed as a script. add_node starts a process
# pool; with the "spawn" start method every worker re-imports this module,
# and an unguarded loop would run again inside each worker.
if __name__ == '__main__':
    for split in ['test', 'valid', 'train']:
        # Alternative input/output pairs kept from earlier runs:
        # add_node(graph, f'../data/{split}_add_node_ending_onecard.json', f'../data/{split}_add_node_onecard.json')
        # add_node(graph, f'../data/{split}_add_node_onecard_low.json', f'../data/{split}_add_node_onecard_low_intersect.json')
        # NOTE: input and output paths are identical here, so each file is
        # overwritten in place with its annotated version.
        add_node(graph, f'../data/{split}_add_node_onecard_low.json', f'../data/{split}_add_node_onecard_low.json')
