import os
import sys
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(base_dir)

import pickle
import json
import random
import torch

# Read all triples from the in_file
def get_ans_triples(in_file):
    """Collect the entity-span pairs of every relation in a JSON-lines file.

    Each line of ``in_file`` is parsed as one JSON object that must provide
    ``entityMentions`` (items with ``text`` and a two-element ``offset``) and
    ``relationMentions`` (items with ``em1Text`` and ``em2Text``).

    Returns:
        One list per input line, in file order. Each inner list holds one
        ``((start1, end1), (start2, end2))`` tuple per relation mention.

    Raises:
        KeyError: if a relation refers to a text that is not among the
            entity mentions of the same line.
    """
    all_triples = []
    with open(in_file, "rb") as handle:
        for raw_line in handle:
            record = json.loads(raw_line)

            # Map mention text -> (start, end); if the same text occurs more
            # than once, the last mention wins.
            span_of = {}
            for mention in record['entityMentions']:
                begin, end = mention['offset']
                span_of[mention['text']] = (begin, end)

            all_triples.append([
                (span_of[relation['em1Text']], span_of[relation['em2Text']])
                for relation in record['relationMentions']
            ])
    return all_triples

# Remove the given fields in the file
def remove_fields(file_name, fields_list=()):
    """Strip the given keys from every record of a pickled list, in place.

    The file is loaded with ``pickle``, each record (a dict) has every key
    named in ``fields_list`` removed if present, and the result is written
    back to the same path as a list.

    Args:
        file_name: Path of the pickle file to rewrite.
        fields_list: Iterable of keys to drop. Keys that a record does not
            have are ignored. Defaults to an empty (immutable) tuple, so
            calling without it rewrites the file unchanged.
    """
    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    with open(file_name, "rb") as f:
        data = pickle.load(f)

    for ins in data:
        for field in fields_list:
            # pop with a default replaces the "check then pop" double lookup.
            ins.pop(field, None)

    with open(file_name, "wb") as f:
        pickle.dump(list(data), f)

# Generate a random graph with n-1 randomly chosen edges in addition to the self-loops
def get_sparse_random_graph(sentLen):
    """Build a 0/1 adjacency matrix with self-loops plus ``sentLen - 1`` random edges.

    Args:
        sentLen: Number of nodes (rows and columns of the matrix).

    Returns:
        A ``sentLen`` x ``sentLen`` list of lists of ints. Every diagonal
        entry is 1, and exactly ``sentLen - 1`` distinct off-diagonal entries,
        drawn with ``random.sample``, are 1. Returns ``[]`` for
        ``sentLen <= 0``.
    """
    if sentLen <= 0:
        # Nothing to sample; avoids random.sample with a negative size.
        return []

    random_fw = [[0] * sentLen for _ in range(sentLen)]
    for i in range(sentLen):
        random_fw[i][i] = 1

    # Flattened (row-major) positions of all off-diagonal cells. Diagonal
    # cells are the multiples of sentLen + 1. Building the list directly
    # avoids removing items from a list while iterating over it.
    off_diagonal = [
        flat for flat in range(sentLen * sentLen) if flat % (sentLen + 1) != 0
    ]

    for flat in random.sample(off_diagonal, sentLen - 1):
        row, col = divmod(flat, sentLen)
        random_fw[row][col] = 1

    return random_fw

# Generate a random full graph
def get_full_random_graph(sentLen):
    """Build a dense adjacency matrix with random off-diagonal weights.

    Args:
        sentLen: Number of nodes (rows and columns of the matrix).

    Returns:
        A ``sentLen`` x ``sentLen`` list of lists of floats. Diagonal entries
        are ``1.0``; every other entry is an independent ``random.random()``
        draw in ``[0.0, 1.0)``.
    """
    # Off-diagonal weights are drawn in row-major order, one draw per cell,
    # so a seeded run yields the same matrix as filling by ascending
    # flattened index. Building rows directly avoids deleting from an index
    # list while iterating over it.
    return [
        [1.0 if row == col else random.random() for col in range(sentLen)]
        for row in range(sentLen)
    ]

# Generate a more noisy random graph by decreasing the weight of relational edges
def get_noisy_full_random_graph(sentLen, ans_triple):
    """Return a full random graph with the relational edges down-weighted.

    Starts from ``get_full_random_graph(sentLen)`` and, for each pair of
    spans in ``ans_triple``, overwrites every cell whose row lies in
    ``range(start1, end1)`` and whose column lies in ``range(start2, end2)``
    with a fresh ``random.uniform(0.01, 0.1)`` draw.

    Args:
        sentLen: Number of nodes in the graph.
        ans_triple: Iterable of ``((start1, end1), (start2, end2))`` items.

    Returns:
        The modified ``sentLen`` x ``sentLen`` matrix (list of lists).
    """
    noisy_adj = get_full_random_graph(sentLen)

    for span_pair in ans_triple:
        head_lo, head_hi = span_pair[0]
        tail_lo, tail_hi = span_pair[1]
        for row in range(head_lo, head_hi):
            for col in range(tail_lo, tail_hi):
                # Small weight: these are the edges linking the two entities.
                noisy_adj[row][col] = random.uniform(0.01, 0.1)

    return noisy_adj

# Generate an adjacency graph filled entirely with 1s
def get_full_graph(sentLen):
    """Return a ``sentLen`` x ``sentLen`` matrix in which every entry is 1.

    Each row is a separate list object, so rows can be edited independently.
    """
    return [[1 for _ in range(sentLen)] for _ in range(sentLen)]

# Generate a linear sequence adj graph
def get_linear_graph(sentLen):
    """Return the adjacency matrix of a left-to-right chain with self-loops.

    Entry ``[i][j]`` is 1 when ``j == i`` (self-loop) or ``j == i + 1``
    (edge to the next position), and 0 otherwise.
    """
    return [
        [1 if col in (row, row + 1) else 0 for col in range(sentLen)]
        for row in range(sentLen)
    ]

# Transform adj_fw into another kind of adjacency graph (e.g. a linear graph)
def transform_to_new_graph(in_file_name, out_file_name, main_file=None):
    """Replace each instance's ``adj_fw`` with a freshly generated random graph.

    Loads the pickled instance list from ``in_file_name``, generates a new
    adjacency matrix of the same size for every instance, and pickles the
    resulting list of ``{'id': ..., 'adj_fw': ...}`` dicts to
    ``out_file_name``.

    Args:
        in_file_name: Pickle file holding a list of dicts with ``id`` and
            ``adj_fw`` keys.
        out_file_name: Destination pickle file.
        main_file: Optional JSON-lines file read by ``get_ans_triples``; its
            i-th line must describe the i-th instance. When given, the edges
            between related entity spans are down-weighted. When omitted,
            a plain full random graph is generated (previously this case
            failed because ``ans_triples`` was never bound).

    Returns:
        The list of new instances that was written to ``out_file_name``.
    """
    with open(in_file_name, "rb") as f:
        data = pickle.load(f)

    ans_triples = get_ans_triples(main_file) if main_file else None

    new_data = []
    for i, ins in enumerate(data):
        print(f"Process ins{i}......")
        sentLen = len(ins['adj_fw'])

        # With no triples the noisy generator degenerates to a full random
        # graph, so an empty list is the natural fallback.
        triples = ans_triples[i] if ans_triples is not None else []
        adj_fw = get_noisy_full_random_graph(sentLen, triples)

        new_data.append({'id': ins['id'], 'adj_fw': adj_fw})

    with open(out_file_name, "wb") as f:
        pickle.dump(new_data, f)

    return new_data

def get_masked_graph(mask_file, adj_file, out_file):
    """Multiply every adjacency matrix by its mask and pickle the result.

    Both input files are pickled lists of dicts with an ``adj_fw`` key.
    Entries are paired by position (ids are not compared). Each ``adj_fw``
    in the adjacency list is replaced by the element-wise product
    ``adj * mask`` computed with torch, converted back to nested lists.

    Args:
        mask_file: Pickle file with the masks.
        adj_file: Pickle file with the adjacency matrices to mask.
        out_file: Destination pickle file; missing parent directories are
            created. A bare file name (no directory part) is accepted.

    Raises:
        AssertionError: if the two lists differ in length.
    """
    with open(mask_file, "rb") as f1, open(adj_file, "rb") as f2:
        masks = pickle.load(f1)
        adjs = pickle.load(f2)
    assert len(masks) == len(adjs)

    # Compute everything before opening the output, so a failure here does
    # not leave a truncated out_file behind.
    for adj_ins, mask_ins in zip(adjs, masks):
        adj = torch.tensor(adj_ins['adj_fw'])
        mask = torch.tensor(mask_ins['adj_fw'])
        adj_ins['adj_fw'] = (adj * mask).tolist()

    out_dir = os.path.dirname(out_file)
    if out_dir:
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        os.makedirs(out_dir, exist_ok=True)

    with open(out_file, "wb") as f:
        pickle.dump(adjs, f)

def get_dep_prune_mask(dep_file, gp_dep_file, out_file):
    """Build masks that zero out edges present in ``dep`` but absent in ``gp_dep``.

    Both input files are pickled lists of dicts with ``id`` and ``adj_fw``
    keys, paired by position. For each pair a ``seq_len`` x ``seq_len`` mask
    is produced (``seq_len`` is the number of rows of the ``dep`` matrix)
    whose entry is 0 where ``dep_adj[i][j] == 1`` and
    ``gp_dep_adj[i][j] == 0``, and 1 everywhere else. The list of
    ``{'id': ..., 'adj_fw': mask}`` dicts is pickled to ``out_file``.

    Args:
        dep_file: Pickle file with the reference adjacency matrices.
        gp_dep_file: Pickle file with the matrices to compare against.
        out_file: Destination pickle file; missing parent directories are
            created. A bare file name (no directory part) is accepted.

    Raises:
        AssertionError: if the lists differ in length or a pair of entries
            has different ids.
    """
    with open(dep_file, "rb") as f1, open(gp_dep_file, "rb") as f2:
        dep = pickle.load(f1)
        gp_dep = pickle.load(f2)
    # zip() would silently drop trailing entries on a length mismatch.
    assert len(dep) == len(gp_dep)

    masks = []
    for adj1, adj2 in zip(dep, gp_dep):
        assert adj1['id'] == adj2['id']

        dep_adj, gp_dep_adj = adj1['adj_fw'], adj2['adj_fw']
        seq_len = len(dep_adj)
        pruned = [
            [
                0 if dep_adj[i][j] == 1 and gp_dep_adj[i][j] == 0 else 1
                for j in range(seq_len)
            ]
            for i in range(seq_len)
        ]
        masks.append({'id': adj1['id'], 'adj_fw': pruned})

    out_dir = os.path.dirname(out_file)
    if out_dir:
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        os.makedirs(out_dir, exist_ok=True)

    with open(out_file, "wb") as f:
        pickle.dump(masks, f)


if __name__=='__main__':
    # Fixed seeds so the generated random graphs are reproducible.
    random.seed(512)
    torch.manual_seed(512)

    # Paths are relative to the current working directory.
    dataset = '../WebNLG'
    structure = "random7"

    # makedirs(exist_ok=True) lets the script be re-run and also creates a
    # missing "adj" parent directory; os.mkdir failed in both situations.
    out_dir = f"{dataset}/adj/{structure}"
    os.makedirs(out_dir, exist_ok=True)

    # Order matters: the splits share one RNG stream, so train, dev and test
    # must be processed in this sequence to reproduce earlier outputs.
    for split in ("train", "dev", "test"):
        transform_to_new_graph(
            f"{dataset}/adj/dep/{split}.adj",
            f"{out_dir}/{split}.adj",
            f"{dataset}/data/{split}.json",
        )
