import os
from collections import OrderedDict
import math
import numpy as np
import scipy.sparse as sp
import torch
from texttable import Texttable
import pandas as pd
import swifter
import json
import pandas as pd
from typing import List
from PIL import Image
from collections import namedtuple
import leidenalg
import igraph as ig
import networkx as nx
from tqdm import tqdm
from scipy.spatial import Delaunay
from scipy.spatial import Voronoi
from GPT_GNN.config import *

# Register tqdm's ``progress_apply`` on pandas objects; find_alignments() and
# to_graph() below rely on it.
tqdm.pandas()
# Global swifter defaults: no progress bar and no forced parallelism.
# NOTE(review): swifter is not referenced anywhere else in this chunk; confirm it is still needed.
swifter.set_defaults(
    npartitions=None,
    dask_threshold=1,
    scheduler="processes",
    progress_bar=False,
    progress_bar_desc=None,
    allow_dask_on_strings=False,
    force_parallel=False,
)
# Shared Leiden optimiser used by find_communities_leiden(); the fixed seed
# makes community assignments reproducible for identical inputs.
optimiser = leidenalg.Optimiser()
optimiser.set_rng_seed(37)
# One OCR word: its text, bounding box (left, top, right, bottom) in page units,
# block id, segment name and (for training data) its label.
Word = namedtuple('Word', ['text', 'left', 'top', 'right', 'bottom', 'block', 'segment', 'label'])


def randint():
    """Return a random integer in ``[0, 2**32 - 1)`` drawn from NumPy's global RNG.

    ``dtype=np.int64`` is explicit because the upper bound does not fit in a 32-bit
    default integer (e.g. on Windows with NumPy < 2), where the implicit dtype raises
    ``ValueError: high is out of bounds``. The result is converted back to a plain
    Python ``int``, matching what the default-dtype call returns.
    """
    return int(np.random.randint(2**32 - 1, dtype=np.int64))


def feature_vrdu(layer_data, graph):
    """Collect the node embedding features, masks and visual-token indices of ``graph``.

    :param layer_data: mapping whose keys are returned (as an array) as the node indices
    :param graph: object exposing ``node_feature['emb']``, ``node_mask['emb']`` and, when
        visual embeddings are enabled, ``node_vt_idx['emb']``
    :return: ``(feature, mask, vt_idx, indxs, attr)``; ``attr`` is the very same array as
        ``feature``, and ``vt_idx`` is empty when ``use_visual_embeds`` is false
        (``use_visual_embeds`` presumably comes from the ``GPT_GNN.config`` star import)
    """
    indxs = np.array(list(layer_data.keys()))
    # ``np.float`` / ``np.int`` were aliases of the builtins and were removed in NumPy 1.24;
    # passing the builtins yields exactly the same dtypes.
    feature = np.array(list(graph.node_feature['emb']), dtype=float)
    mask = np.array(list(graph.node_mask['emb']), dtype=int)
    vt_idx = np.array(graph.node_vt_idx['emb'], dtype=int) if use_visual_embeds else np.array([])
    return feature, mask, vt_idx, indxs, feature


def load_gnn(_dict):
    """Return the entries of ``_dict`` whose key contains ``'gnn'``, minus a 4-char prefix.

    Presumably ``_dict`` is a model ``state_dict`` and the intent is to strip a ``'gnn.'``
    prefix. The iteration order of ``_dict`` is preserved in the returned ``OrderedDict``.

    NOTE(review): the filter is a substring test, not a prefix test, so a key such as
    ``'att.gnn.w'`` is kept as ``'gnn.w'``; confirm that is intended.
    """
    # (A second, byte-identical definition used to follow this one and silently shadow it.)
    return OrderedDict((key[4:], value) for key, value in _dict.items() if 'gnn' in key)


def get_pct_offsets(words: "List[Word]", imagefile: str, image_width: float = 0, image_height: float = 0, iftraining: bool = True):
    """Build a per-word layout DataFrame with absolute and page-relative offsets.

    :param words: words exposing ``text``, ``left``, ``top``, ``right``, ``bottom``,
        ``block``, ``segment`` and, when ``iftraining`` is true, ``label``
    :param imagefile: page image path; stored in the ``file`` column and opened only to
        read the page size when ``image_width``/``image_height`` are not both positive
    :param image_width: page width used to normalise horizontal coordinates
    :param image_height: page height used to normalise vertical coordinates
    :param iftraining: when true, each word's ``label`` is stored in a ``tag`` column
    :return: one row per word with absolute coordinates, coordinates divided by the page
        size (``*_pct``), and the bounding box of the word's block (``segment_*_pct``).
        The ``tag`` column is present only when ``iftraining`` is true.
    """
    if image_width > 0 and image_height > 0:
        w, h = image_width, image_height
    else:
        # Only the size is needed; close the file handle immediately.
        with Image.open(imagefile) as im:
            w, h = im.size

    word_columns = ['file', 'page_width', 'page_height',
                    'left', 'top', 'right', 'bottom', 'center', 'middle',
                    'left_pct', 'top_pct', 'right_pct', 'bottom_pct', 'center_pct', 'middle_pct',
                    'block', 'segment', 'span']
    if iftraining:
        word_columns.append('tag')
    block_columns = ['segment_left_pct', 'segment_top_pct', 'segment_right_pct', 'segment_bottom_pct']

    records = []
    for word in words:
        left, top, right, bottom = word.left, word.top, word.right, word.bottom
        center, middle = (left + right) / 2, (top + bottom) / 2
        record = {
            'file': imagefile, 'page_width': w, 'page_height': h,
            'left': left, 'top': top, 'right': right, 'bottom': bottom,
            'center': center, 'middle': middle,
            'left_pct': left / w, 'top_pct': top / h, 'right_pct': right / w,
            'bottom_pct': bottom / h, 'center_pct': center / w, 'middle_pct': middle / h,
            'block': word.block, 'segment': word.segment, 'span': word.text,
        }
        if iftraining:
            record['tag'] = word.label
        records.append(record)

    df = pd.DataFrame(records, columns=word_columns)
    if df.empty:
        return df.reindex(columns=word_columns + block_columns)

    # Per-row block bounds via transform: aligned on the row index, so correct for any
    # block ids (joining on the aggregated frame's positional index was only correct
    # when block ids were exactly 0..k-1).
    by_block = df.groupby('block')
    df['segment_left_pct'] = by_block['left_pct'].transform('min')
    df['segment_top_pct'] = by_block['top_pct'].transform('min')
    df['segment_right_pct'] = by_block['right_pct'].transform('max')
    df['segment_bottom_pct'] = by_block['bottom_pct'].transform('max')
    return df[word_columns + block_columns]


def is_before(d):
    """Return True when the components of the offset vector ``d`` sum to a negative value."""
    total = sum(d)
    return total < 0.0


def update_hash(edges, df, idx, hsh, threshold, column='left_pct'):
    """Assign row ``idx`` to an alignment cluster on ``column`` and record cluster edges.

    ``hsh`` maps already-processed row positions to cluster ids. Row ``idx`` joins the
    cluster of the first processed row whose ``column`` value is within ``threshold`` of
    its own. Otherwise it opens a new cluster (``max id + 1``, or ``0`` for an empty ``hsh``).
    On joining, one tuple ``(member, idx, distance, weight, label)`` is appended to
    ``edges`` for every *other* member of that cluster. ``label`` is the first letter of
    ``column``, suffixed with ``'2'`` when the summed distance is not negative.

    :return: ``(hsh, edges)``, the same objects that were passed in, mutated in place
    """
    # NOTE(review): positional lookup (iloc) while callers pass index labels; this assumes
    # ``df`` has a 0..n-1 RangeIndex.
    dest = df.iloc[idx]
    dest_x = {'x': [dest['left_pct'], dest['right_pct'], dest['top_pct'], dest['bottom_pct']]}
    # Same orientation scheme as add_alignments(): left/center/right alignments are
    # weighted as vertical, top/middle/bottom as horizontal.
    orientation = 'h' if column in ('top_pct', 'middle_pct', 'bottom_pct') else 'v'

    cluster = None
    for k, v in hsh.items():
        if abs(dest[column] - df.iloc[k][column]) <= threshold:
            cluster = v
            break

    if cluster is None:
        hsh[idx] = max(hsh.values()) + 1 if hsh else 0
        return hsh, edges

    hsh[idx] = cluster
    for k, v in hsh.items():
        # Skip other clusters and the row itself (previously produced a zero-length self-loop).
        if k == idx or v != cluster:
            continue
        src = df.iloc[k]
        src_x = {'x': [src['left_pct'], src['right_pct'], src['top_pct'], src['bottom_pct']]}
        d = get_distance(src_x, dest_x)
        edges.append((k, idx, d, get_weight(d, orientation), column[0] if is_before(d) else column[0] + '2'))
    return hsh, edges


def add_clusters(row, lhash, chash, rhash, thash, mhash, bhash):
    """Write the six alignment-cluster ids of ``row`` (looked up by ``row.name``) into it.

    The ``left``, ``center``, ``right``, ``top``, ``middle`` and ``bottom`` entries are set
    in that order and the same (mutated) ``row`` is returned.
    """
    key = row.name
    assignments = (('left', lhash), ('center', chash), ('right', rhash),
                   ('top', thash), ('middle', mhash), ('bottom', bhash))
    for field, table in assignments:
        row[field] = table[key]
    return row


def find_clusters(dframe: pd.DataFrame):
    """Cluster rows by near-equal box offsets and collect the resulting alignment edges.

    Rows are visited in order (with a tqdm progress bar). For each row, update_hash() is
    called once per offset column (left, center, right, top, middle, bottom) with a
    threshold of 0.01, each column keeping its own cluster table.

    :param dframe: word-level layout frame with ``*_pct`` offset columns
    :return: ``(dframe, edges)``; the frame is returned unchanged (cluster ids are not
        written back), together with the edge tuples produced by update_hash()
    """
    threshold = 0.01
    offset_columns = ('left_pct', 'center_pct', 'right_pct', 'top_pct', 'middle_pct', 'bottom_pct')
    tables = {name: {} for name in offset_columns}
    edges = []
    for record in tqdm(dframe.itertuples()):
        for name in offset_columns:
            tables[name], edges = update_hash(edges, dframe, record.Index, tables[name], threshold, name)
    return dframe, edges


def get_distance(node1, node2):
    """Return the element-wise difference of the first four entries of ``node['x']``.

    In this module those entries are the left, right, top and bottom offsets.
    """
    return [node1['x'][axis] - node2['x'][axis] for axis in range(4)]


def calculate_vt_idx(page_width, page_height, left, top, image_size=224, patch_size=16):
    """Return the row-major index of the visual-token patch containing ``(left, top)``.

    The page is rescaled to ``image_size`` x ``image_size`` and split into square
    patches of ``patch_size`` pixels (14 x 14 = 196 patches with the defaults). The
    previous implementation used ``%`` (the offset *inside* a patch) instead of ``//``
    (the patch number), which produced float values outside ``[0, grid**2)``.
    Points on or beyond the page edge are clamped into the border patches.

    :return: an ``int`` in ``[0, (image_size // patch_size) ** 2)``
    """
    grid = image_size // patch_size
    col = int(left * image_size / page_width // patch_size)
    row = int(top * image_size / page_height // patch_size)
    col = min(max(col, 0), grid - 1)
    row = min(max(row, 0), grid - 1)
    return row * grid + col


def is_vertical(l):
    """Return True when the smallest-magnitude component of ``l`` is at index 0 or 1.

    In the distance vectors built by get_distance()/get_dist() those are the left and
    right offsets, i.e. the pair is vertically aligned. Ties go to the first index.
    """
    magnitudes = [abs(component) for component in l]
    closest_axis = magnitudes.index(min(magnitudes))
    return closest_axis < 2


def get_weight(l, orientation):
    """Return the Euclidean norm of ``l[:4]``, scaled for aligned pairs.

    With ``orientation == 'v'`` the norm is multiplied by 32 when is_vertical(l) holds.
    With ``orientation == 'h'`` it is multiplied by 64 when it does not. Any other
    orientation leaves the norm unscaled.
    """
    weight = math.sqrt(sum(l[axis] ** 2 for axis in range(4)))
    if orientation == 'v':
        if is_vertical(l):
            weight = weight * 32.0
    elif orientation == 'h':
        if not is_vertical(l):
            weight = weight * 64.0
    return weight


def add_edge(G, idx, j, label):
    """Add (or update) edge ``idx``-``j`` in ``G`` from the two nodes' ``x`` offsets.

    The edge stores the negated absolute offset differences as ``edge_attr``, a
    'v'-oriented get_weight() as ``weight`` and the given ``label``. Always returns True.
    """
    delta = get_distance(G.nodes[idx], G.nodes[j])
    attributes = {
        'edge_attr': [-abs(component) for component in delta],
        'weight': get_weight(delta, 'v'),
        'label': label,
    }
    G.add_edge(idx, j, **attributes)
    return True


def find_communities_louvain(g):
    """Return the Louvain communities of networkx graph ``g`` (resolution 0.3, seed 37)."""
    return nx.community.louvain_communities(g, resolution=0.3, seed=37)


def find_communities_leiden(g):
    """Partition ``g`` into communities by optimising modularity with Leiden.

    Uses the module-level, seeded ``optimiser``.

    :param g: an ``igraph.Graph``; a networkx graph is also accepted and converted with
        ``ig.Graph.from_networkx``, in which case vertex ``i`` is the ``i``-th node of
        ``g.nodes``
    :return: list of community ids, one per vertex, in vertex order
    """
    if isinstance(g, nx.Graph):
        # leidenalg partitions operate on igraph graphs only.
        g = ig.Graph.from_networkx(g)
    # An edgeless graph has no 'weight' edge attribute to look up; fall back to unweighted
    # there (each vertex then ends up in its own community).
    weights = 'weight' if g.ecount() > 0 else None
    part = leidenalg.ModularityVertexPartition(g, weights=weights)
    optimiser.optimise_partition(part)
    return part.membership


def create_edge(src, dest):
    """Annotate ``src`` with an edge to ``dest`` if they share an alignment cluster.

    The ``x`` distance is computed first. When ``src.name`` is greater than ``dest.name``,
    the cluster ids are checked in the order left, center, right, top, middle, bottom. The
    first equal one sets ``edge``, ``distance``, a 'v'-oriented ``weight`` and its
    one-letter ``label`` on ``src``. Otherwise all four fields are set to None.
    The same (mutated) ``src`` is returned.
    """
    distance = get_distance(src, dest)
    matched_label = None
    if not src.name <= dest.name:
        cluster_checks = (('left', 'l'), ('center', 'c'), ('right', 'r'),
                          ('top', 't'), ('middle', 'm'), ('bottom', 'b'))
        for column, tag in cluster_checks:
            if src[column] == dest[column]:
                matched_label = tag
                break

    if matched_label is None:
        src['edge'] = None
        src['distance'] = None
        src['weight'] = None
        src['label'] = None
    else:
        src['edge'] = (src.name, dest.name)
        src['distance'] = distance
        src['weight'] = get_weight(distance, 'v')
        src['label'] = matched_label
    return src


def add_nodes(row, train, get_label):
    """Attach model inputs to one layout row and return the (mutated) row.

    Sets the following fields:

    - ``x``: six geometric features followed by ``tokens_len`` padded token ids.
    - ``y``: ``get_label(row['tag'])`` when ``train``, else 0.
    - ``mask``: 1 for real tokens, 0 for padding.
    - ``block``: the block id as int.
    - ``segment``: the segment mapped through ``segment_label_to_idx``.
    - ``visual_token_idx``: the result of calculate_vt_idx().
    - ``image_path``: copied from ``file``.

    ``tokenizer``, ``tokens_len`` and ``segment_label_to_idx`` presumably come from the
    ``GPT_GNN.config`` star import.
    """
    block_id = int(row['block'])
    segment_id = int(segment_label_to_idx[row['segment']])
    text = str(row['span'])

    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))[:tokens_len]
    n_real = len(token_ids)
    n_pad = tokens_len - n_real
    token_ids = token_ids + [tokenizer.pad_token_id] * n_pad
    attention = [1] * n_real + [0] * n_pad

    visual_idx = calculate_vt_idx(row['page_width'], row['page_height'], row['left'], row['top'])
    # left, right, top, bottom, height, and width divided by (1 + number of characters)
    geometry = [
        row['left_pct'],
        row['right_pct'],
        row['top_pct'],
        row['bottom_pct'],
        abs(row['bottom_pct'] - row['top_pct']),
        abs(row['right_pct'] - row['left_pct']) / (1. + len(text)),
    ]
    label = get_label(row['tag']) if train else 0

    row['x'] = geometry + token_ids
    row['y'] = label
    row['mask'] = attention
    row['block'] = block_id
    row['segment'] = segment_id
    row['visual_token_idx'] = visual_idx
    row['image_path'] = row['file']
    return row


def add_edges(df, row):
    """Pair every row of ``df`` with ``row`` via create_edge() and store the results on ``row``.

    ``row['edges']``, ``row['edge_attrs']``, ``row['weights']`` and ``row['labels']`` receive
    the ``edge``, ``distance``, ``weight`` and ``label`` columns of the applied frame.
    """
    paired = df.apply(lambda other: create_edge(other, row), axis=1)
    edges, attrs, weights, labels = (paired[key] for key in ('edge', 'distance', 'weight', 'label'))
    row['edges'] = edges
    row['edge_attrs'] = attrs
    row['weights'] = weights
    row['labels'] = labels
    return row


def get_dist(node1, node2):
    """Return the six-component layout difference ``node1 - node2``.

    The components are:

    - left, right, top and bottom ``*_pct`` offsets;
    - box height;
    - box width divided by ``1 + len(span)``.
    """
    def _width_per_char(node):
        return (node['right_pct'] - node['left_pct']) / (1. + len(node['span']))

    deltas = [node1[key] - node2[key] for key in ('left_pct', 'right_pct', 'top_pct', 'bottom_pct')]
    deltas.append(node1['bottom_pct'] - node1['top_pct'] - node2['bottom_pct'] + node2['top_pct'])
    deltas.append(_width_per_char(node1) - _width_per_char(node2))
    return deltas


def add_alignments(src, dest, threshold=0.01):
    """Return an alignment edge tuple for rows ``src`` -> ``dest``, or None.

    Offsets are compared in the order left, center, right (weighted as vertical) and then
    top, middle, bottom (weighted as horizontal). The first one that differs by strictly
    less than ``threshold`` yields ``(src.name, dest.name, d, weight, label)``, where ``d``
    is get_dist(src, dest). ``label`` is the offset's initial, suffixed with ``'2'`` unless
    is_before(d). A row is never aligned with itself.
    """
    if src.name == dest.name:
        return None
    d = get_dist(src, dest)
    checks = (('left_pct', 'l', 'v'), ('center_pct', 'c', 'v'), ('right_pct', 'r', 'v'),
              ('top_pct', 't', 'h'), ('middle_pct', 'm', 'h'), ('bottom_pct', 'b', 'h'))
    for column, tag, orientation in checks:
        if abs(src[column] - dest[column]) < threshold:
            weight = get_weight(d, orientation)
            label = tag if is_before(d) else tag + '2'
            return (src.name, dest.name, d, weight, label)
    return None


def sanitize_edges(df):
    """Return the values of ``df.items()`` that are not None, in iteration order."""
    kept = []
    for _, value in df.items():
        if value is not None:
            kept.append(value)
    return kept


def find_alignments(df):
    """Compute alignment edges between every ordered pair of rows of ``df``.

    For each source row (with a tqdm progress bar, via ``progress_apply``), add_alignments()
    is evaluated against every row and the non-None results are kept. This is O(n^2)
    in the number of rows.

    :return: ``(df, edges)``, the unchanged frame and the flat list of edge tuples
        grouped by source row in frame order
    """
    def _alignments_from(src):
        candidates = df.apply(lambda dest: add_alignments(src, dest), axis=1)
        return sanitize_edges(candidates)

    per_source = df.progress_apply(_alignments_from, axis=1)
    flat = [edge for _, row_edges in per_source.items() for edge in row_edges]
    return df, flat


def to_graph(dframe: pd.DataFrame, get_label, train):
    """Convert a word-level layout frame into a networkx graph of alignment edges.

    Rows are sorted by ``(segment_bottom_pct, segment_left_pct)`` and renumbered
    0..n-1, so node ids are positions in that order. Each row becomes a node carrying the
    model inputs from add_nodes(). Each tuple from find_alignments() becomes an edge with
    ``edge_attr``, ``weight``, ``label`` and ``peer=1``.

    :param dframe: word-level layout frame (as built by get_pct_offsets)
    :param get_label: maps a ``tag`` value to the node label (used only when ``train``)
    :param train: whether node labels are read from the ``tag`` column
    :return: single-element list ``[G]``
    """
    df = dframe.sort_values(by=['segment_bottom_pct', 'segment_left_pct'], ascending=True, ignore_index=True)
    df, alignments = find_alignments(df)
    nodes = df.progress_apply(lambda row: add_nodes(row, train, get_label), axis=1)

    G = nx.Graph()
    for row in nodes.itertuples():
        G.add_node(row.Index, text=row.span, y=row.y, x=row.x, mask=row.mask, comm=0, block=row.block,
                   segment=row.segment, visual_token_idx=row.visual_token_idx, image_path=row.image_path)
    for src, dst, distance, weight, label in alignments:
        G.add_edge(src, dst, edge_attr=distance, weight=weight, label=label, peer=1)
    return [G]


def to_graph_relations(entities, relations, image_path):
    """Build a graph with one node per entity and one edge per given relation.

    :param entities: sequence of objects with ``text``, ``left``, ``top``, ``right``,
        ``bottom`` and ``block`` attributes; node ids are their positions in the sequence
    :param relations: pairs ``(i, j)`` of entity positions to connect
    :param image_path: page image; read only for its size and stored on every node
    :return: single-element list ``[G]``. Every node has ``y=1``, ``segment=1`` and
        ``comm=0``. Every edge has ``weight=1``, ``label='l'`` and ``peer=1``, and its
        ``edge_attr`` holds the negated absolute left/right/top/bottom differences.
    """
    # Only the image size is needed; close the file handle right away.
    with Image.open(image_path) as im:
        w, h = im.size

    G = nx.Graph()
    node_to_pos = {}
    for idx, entity in enumerate(entities):
        block = int(entity.block)
        text = str(entity.text)
        tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))[:tokens_len]
        n_tokens = len(tokens)
        tokens = tokens + [tokenizer.pad_token_id] * (tokens_len - n_tokens)
        mask = [1] * n_tokens + [0] * (tokens_len - n_tokens)
        vt_idx = calculate_vt_idx(w, h, entity.left, entity.top)
        left, top, right, bottom = entity.left, entity.top, entity.right, entity.bottom
        pos = [left / w, right / w, top / h, bottom / h,
               abs(bottom / h - top / h), abs(right / w - left / w) / (1. + len(text))]
        node_to_pos[idx] = pos
        G.add_node(idx, text=entity.text, y=1, x=pos + tokens, mask=mask, comm=0, block=block,
                   segment=1, visual_token_idx=vt_idx, image_path=image_path)

    for r in relations:
        # Only the four box offsets feed the edge attribute.
        d = [a - b for a, b in zip(node_to_pos[r[0]][:4], node_to_pos[r[1]][:4])]
        G.add_edge(r[0], r[1], edge_attr=[-abs(x) for x in d], weight=1, label='l', peer=1)
    return [G]


def to_graph_beta(dframe: pd.DataFrame, get_label, train, w,h, graph_type='beta', n_comms=10, hierarchical=False):
    """Build a word graph from a Delaunay triangulation of the words' box corners.

    Each row of ``dframe`` becomes a node keyed by its index label. The four corners of
    every word box (``*_pct`` coordinates scaled by ``w``/``h``) are triangulated. Each
    triangle side whose two corners belong to different words adds an edge via
    add_edge(). Nodes then receive a ``comm`` attribute and edges a ``peer`` flag
    (1 when both endpoints share a community, else 0).

    :param dframe: word-level layout frame (as built by get_pct_offsets)
    :param get_label: maps a ``tag`` value to the node label ``y`` (used only when ``train``)
    :param train: whether labels are read from ``tag``; otherwise ``y`` is 0
    :param w: horizontal scale applied to the ``*_pct`` coordinates before triangulation
    :param h: vertical scale applied to the ``*_pct`` coordinates before triangulation
    :param graph_type: ``'beta'`` assigns Leiden communities; ``'beta_comm0'`` puts every
        node in community 0
    :param n_comms: unused; kept for backward compatibility
    :param hierarchical: unused; kept for backward compatibility
    :yield: a single networkx graph (the frame is treated as one page)
    :raises ValueError: on first iteration, if ``graph_type`` is not supported
    """
    if graph_type not in ('beta', 'beta_comm0'):
        raise ValueError(f"unsupported graph_type: {graph_type!r}")

    # find_clusters() returns (frame, edges) without modifying the frame, and nothing here
    # used its edges. Binding its tuple result to ``df`` crashed at iterrows(), so the
    # frame is used directly (which also skips that O(n^2) pass).
    df = dframe
    G = nx.Graph()
    corners = []   # triangulation points, four per word
    node_map = []  # node_map[i] is the node id owning corners[i]
    for idx, row in df.iterrows():
        block = int(row['block'])
        s = str(row['span'])
        tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(s))[:tokens_len]
        n_tokens = len(tokens)
        tokens = tokens + [tokenizer.pad_token_id] * (tokens_len - n_tokens)
        mask = [1] * n_tokens + [0] * (tokens_len - n_tokens)
        vt_idx = calculate_vt_idx(row['page_width'], row['page_height'], row['left'], row['top'])
        pos = [row['left_pct'], row['right_pct'], row['top_pct'], row['bottom_pct'],
               abs(row['bottom_pct'] - row['top_pct']),
               abs(row['right_pct'] - row['left_pct']) / (1. + len(s))]
        y = get_label(row['tag']) if train else 0
        G.add_node(idx, y=y, x=pos + tokens, mask=mask, block=block, visual_token_idx=vt_idx, image_path=row['file'])
        for px, py in ((row['left_pct'], row['top_pct']), (row['left_pct'], row['bottom_pct']),
                       (row['right_pct'], row['top_pct']), (row['right_pct'], row['bottom_pct'])):
            corners.append([px * w, py * h])
            node_map.append(idx)

    if corners:  # Delaunay cannot triangulate an empty point set
        points = np.array(corners)
        simplices = Delaunay(points).simplices
        # Voronoi ridges are forwarded to remove(), which currently ignores them.
        ridges = [list(p) for p in Voronoi(points).ridge_points]
        for simplex in simplices:
            a, b, c = int(simplex[0]), int(simplex[1]), int(simplex[2])
            for p, q in ((a, b), (a, c), (b, c)):
                # Skip sides whose two corners belong to the same word.
                if not remove(p, q, ridges, node_map):
                    add_edge(G, node_map[p], node_map[q], 'beta')

    if graph_type == 'beta' and G.number_of_edges() > 0:
        # leidenalg works on igraph graphs. from_networkx numbers vertices in G.nodes
        # order, so the membership list is zipped back onto the networkx node ids
        # (indexing it by node id was only correct for ids 0..n-1).
        membership = find_communities_leiden(ig.Graph.from_networkx(G))
        node_to_comm = dict(zip(G.nodes, membership))
    elif graph_type == 'beta':
        # Without edges every word is its own community.
        node_to_comm = {node: i for i, node in enumerate(G.nodes)}
    else:
        node_to_comm = dict.fromkeys(G.nodes, 0)

    for node in G.nodes:
        G.nodes[node]['comm'] = node_to_comm[node]
    for src, tgt in G.edges:
        G[src][tgt]['peer'] = 1 if node_to_comm[src] == node_to_comm[tgt] else 0
    yield G


def remove(p1,p2, vors, node_map):
    """Return True when points ``p1`` and ``p2`` map to the same node in ``node_map``.

    ``vors`` is accepted but not used.
    """
    owner_a = node_map[p1]
    owner_b = node_map[p2]
    return bool(owner_a == owner_b)