import networkx as nx
from networkx import Graph
import itertools


def extract_multi_graphs(data):
    """Build one ``nx.MultiDiGraph`` of ``(orig, dest)`` edges per item of *data*.

    For every item, ``spk_agents`` and ``adr_agents`` are walked in lockstep
    (presumably speaker / addressee ids of consecutive utterances -- confirm
    against the data loader) and turned into directed edges:

    * a destination equal to 17 is replaced by the most recent origin that
      differs from the current origin (16 until one has been seen);
    * pairs whose destination is -1, and self-loops, are discarded;
    * nodes 16 and 17 are removed from the finished graph together with
      their edges;
    * nodes 0..3 are always present, even when isolated.

    Args:
        data: Object supporting ``len()`` and integer indexing whose items
            expose ``spk_agents`` and ``adr_agents`` iterables.

    Returns:
        A list with one ``nx.MultiDiGraph`` per item, in input order.
        Repeated pairs are kept as parallel edges.
    """
    seed_user = 16         # initial value of the "previous origin" trackers
    placeholder_dest = 17  # destination resolved to the last different origin
    dropped_dest = -1      # destination whose pair is discarded
    always_present = range(4)

    graph_set = []
    # Index explicitly (rather than iterating) so objects that only implement
    # __len__/__getitem__ keep working; each item is fetched exactly once.
    for idx in range(len(data)):
        item = data[idx]

        last_user = seed_user
        last_diff_user = seed_user
        edges = []
        for orig, dest in zip(item.spk_agents, item.adr_agents):
            if last_user != orig:
                last_diff_user = last_user
            if dest == placeholder_dest:
                dest = last_diff_user
            last_user = orig
            # The filter runs on the *resolved* destination.
            if dest != dropped_dest and dest != orig:
                edges.append((orig, dest))

        graph = nx.MultiDiGraph()
        graph.add_edges_from(edges)
        # remove_nodes_from silently skips nodes that are not in the graph.
        graph.remove_nodes_from([seed_user, placeholder_dest])
        # add_nodes_from leaves already-present nodes untouched.
        graph.add_nodes_from(always_present)
        graph_set.append(graph)

    return graph_set

def extract_graphs(data):
    """Build one ``nx.DiGraph`` per item of *data* and report the connected ones.

    For every item, ``spk_agents`` and ``adr_agents`` are walked in lockstep
    (presumably speaker / addressee ids of consecutive utterances -- confirm
    against the data loader) and turned into directed edges:

    * a destination equal to 17 is replaced by the most recent origin that
      differs from the current origin (16 until one has been seen);
    * pairs whose destination is -1, and self-loops, are discarded;
    * nodes 16 and 17 are removed from the finished graph together with
      their edges;
    * nodes 0..3 are always present, even when isolated.

    Args:
        data: Object supporting ``len()`` and integer indexing whose items
            expose ``spk_agents`` and ``adr_agents`` iterables.

    Returns:
        A ``(graph_set, connected_indices)`` tuple. ``graph_set`` holds one
        ``nx.DiGraph`` per item, in input order (repeated pairs collapse
        into a single edge). ``connected_indices`` lists the positions in
        ``graph_set`` of the graphs that are connected when edge direction
        is ignored.
    """
    seed_user = 16         # initial value of the "previous origin" trackers
    placeholder_dest = 17  # destination resolved to the last different origin
    dropped_dest = -1      # destination whose pair is discarded
    always_present = range(4)

    graph_set = []
    connected_indices = []
    # Index explicitly (rather than iterating) so objects that only implement
    # __len__/__getitem__ keep working; each item is fetched exactly once.
    for idx in range(len(data)):
        item = data[idx]

        last_user = seed_user
        last_diff_user = seed_user
        edges = []
        for orig, dest in zip(item.spk_agents, item.adr_agents):
            if last_user != orig:
                last_diff_user = last_user
            if dest == placeholder_dest:
                dest = last_diff_user
            last_user = orig
            # The filter runs on the *resolved* destination.
            if dest != dropped_dest and dest != orig:
                edges.append((orig, dest))

        graph = nx.DiGraph()
        graph.add_edges_from(edges)
        # remove_nodes_from silently skips nodes that are not in the graph.
        graph.remove_nodes_from([seed_user, placeholder_dest])
        # add_nodes_from leaves already-present nodes untouched.
        graph.add_nodes_from(always_present)
        graph_set.append(graph)

        # Weak connectivity is connectivity of the undirected view, so no
        # undirected copy has to be built. The graph is never empty here
        # because nodes 0..3 were just added.
        if nx.is_weakly_connected(graph):
            connected_indices.append(idx)

    return graph_set, connected_indices