import json
import os

from file_utils import *
from utils_pdg import get_dependence_line


# Input dataset. It is parsed with a single json.load call, so despite the
# .jsonl extension the file must contain one JSON document.
DATA_PATH = '../../datasets/v3/final.jsonl'
# Directory that receives the generated pdg_<LEVEL_SLICE>.jsonl file.
OUTPUT_PATH = './data'

def get_pdg_dependence(nodes_path, edges_path, change_lines, LEVEL_SLICE=1):
    """Slice the graph around the changed lines and index the resulting lines.

    Args:
        nodes_path: Path to a JSON file containing the list of graph nodes.
        edges_path: Path to a JSON file containing the list of graph edges.
        change_lines: Line numbers that were changed in this file version.
        LEVEL_SLICE: Passed through unchanged to ``get_dependence_line``.

    Returns:
        A four-element list ``[lines, c_lines, maps, invert]``:

        * ``lines``   - sorted, de-duplicated lines returned by
          ``get_dependence_line``.
        * ``c_lines`` - a fresh copy of ``change_lines`` (callers may mutate
          it without touching the input).
        * ``maps``    - ``line -> line - k`` where ``k`` is the number of
          changed lines strictly before ``line``.
        * ``invert``  - the reverse of ``maps``; when two lines map to the
          same position the larger line number wins.

        ``[[], [], {}, {}]`` is returned when ``change_lines`` is empty, when
        either file is missing, or when the graph fails the size check below.
    """
    maps = {}
    invert = {}
    fail_get_result = [[], [], maps, invert]
    if not change_lines:
        return fail_get_result
    if not is_path_exist(nodes_path) or not is_path_exist(edges_path):
        return fail_get_result

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(nodes_path, encoding='utf-8') as f:
        # Node entries without a 'lineNumber' entry are ignored.
        nodes = [n for n in json.load(f) if 'lineNumber' in n]
    with open(edges_path, encoding='utf-8') as f:
        edges = json.load(f)

    # NOTE(review): `== 10` rejects graphs with exactly ten line-bearing nodes
    # while accepting smaller ones; this looks like it was meant to be a lower
    # bound (e.g. `< 10`). Kept as-is because the intended threshold cannot be
    # confirmed from this file.
    if len(nodes) == 10 or len(edges) < 20:
        return fail_get_result

    # Copy so the caller can consume the list without mutating its input.
    c_lines = list(change_lines)
    lines = sorted(set(
        get_dependence_line(change_lines, nodes, edges, LEVEL_SLICE)))

    for line in lines:
        # Number of changed lines located strictly before this line.
        shift = sum(1 for c_line in c_lines if c_line < line)
        maps[line] = line - shift
        invert[line - shift] = line
    return [lines, c_lines, maps, invert]

def _flush_changed_lines(pending, invert, position, source, marker, diff_out):
    """Emit queued changed lines that sit at or before a context position.

    ``pending`` is consumed from the front. ``invert`` maps a normalized
    position back to a line number in ``source``. Nothing happens when
    ``position`` has no entry in ``invert``.
    """
    if position not in invert:
        return
    anchor = invert[position]
    # A loop rather than a single check: several changed lines may precede
    # the same context line.
    while pending and pending[0] <= anchor:
        diff_out.append(marker + source[pending[0]])
        pending.pop(0)


def parser_code_property_graph(row, LEVEL_SLICE=1):
    """Build a diff-like text from the slices of one commit/file row.

    Reads ``commit_id``, ``file_name``, ``add_lines``, ``delete_lines``,
    ``before_code`` and ``after_code`` from ``row``. Deleted lines are
    prefixed with ``-``, added lines with ``+`` and slice context lines with
    a space. The joined text is stored in ``row['pdg_<LEVEL_SLICE>']``.

    Args:
        row: Mutable mapping describing one changed file; updated in place.
        LEVEL_SLICE: Forwarded to ``get_pdg_dependence``; also used in the
            name of the key written to ``row``.

    Returns:
        ``row`` itself, or ``None`` when neither side yields changed lines.
    """
    commit_id = row['commit_id']
    file_name = row['file_name'].replace('/', '_')

    after_nodes_path, after_edges_path = get_cpg_after(commit_id, file_name)
    after_lines, added, after_map, after_invert = get_pdg_dependence(
        after_nodes_path, after_edges_path, row['add_lines'], LEVEL_SLICE)
    before_nodes_path, before_edges_path = get_cpg_before(commit_id, file_name)
    before_lines, deleted, before_map, before_invert = get_pdg_dependence(
        before_nodes_path, before_edges_path, row['delete_lines'], LEVEL_SLICE)
    if not added and not deleted:
        return None

    before_code = row['before_code']
    after_code = row['after_code']
    source_before = before_code.splitlines() if isinstance(before_code, str) else []
    source_after = after_code.splitlines() if isinstance(after_code, str) else []
    # Placeholder at index 0 so line numbers can be used directly as indexes.
    source_after.insert(0, '')
    source_before.insert(0, '')

    # Normalized positions of every slice line, from both versions.
    positions_after = [after_map[line] for line in after_lines]
    positions_before = [before_map[line] for line in before_lines]
    positions = sorted(set(positions_after + positions_before))

    diff_pdg = []
    for position in positions:
        # Deletions are flushed before additions, then the context line.
        _flush_changed_lines(
            deleted, before_invert, position, source_before, '-', diff_pdg)
        _flush_changed_lines(
            added, after_invert, position, source_after, '+', diff_pdg)
        if position in before_invert:
            diff_pdg.append(' ' + source_before[before_invert[position]])
        elif position in after_invert:
            diff_pdg.append(' ' + source_after[after_invert[position]])
        else:
            print('Parser pdg error:', row['commit_id'])

    # Drain changed lines located after the last context line. Emptiness is
    # tested explicitly instead of through a 0 sentinel, which could index an
    # empty list when a queued line number was not positive.
    # NOTE(review): when both queues are non-empty the deleted line is taken
    # only if its number is greater than the next added one; confirm that
    # this ordering is intended.
    while deleted or added:
        if deleted and (not added or deleted[0] > added[0]):
            diff_pdg.append('-' + source_before[deleted[0]])
            deleted.pop(0)
        else:
            diff_pdg.append('+' + source_after[added[0]])
            added.pop(0)

    row[f'pdg_{LEVEL_SLICE}'] = '\n'.join(diff_pdg)
    print(f"{row['commit_id']} finish")
    return row


def parser_pdg_code_change(LEVEL_SLICE=1):
    """Process every row of the dataset and write the rows that produced a slice.

    The dataset at ``DATA_PATH`` is loaded as one JSON document and iterated
    row by row. Rows for which ``parser_code_property_graph`` returns
    ``None`` are skipped. The remaining rows are written as a single JSON
    array to ``OUTPUT_PATH/pdg_<LEVEL_SLICE>.jsonl``.

    Args:
        LEVEL_SLICE: Forwarded to ``parser_code_property_graph``; also part
            of the output file name.
    """
    # Context manager so the input handle is closed deterministically.
    with open(DATA_PATH, encoding='utf-8') as source_file:
        data = json.load(source_file)

    results = (parser_code_property_graph(row, LEVEL_SLICE) for row in data)
    samples = [res for res in results if res is not None]

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(OUTPUT_PATH, exist_ok=True)
    out_put = os.path.join(OUTPUT_PATH, f'pdg_{LEVEL_SLICE}.jsonl')
    with open(out_put, 'w') as file:
        json.dump(samples, file)


if __name__ == '__main__':
    # Script entry point: generate the output file with LEVEL_SLICE=4.
    parser_pdg_code_change(LEVEL_SLICE=4)