import os
import glob
import random

import requests
from multiprocessing import Pool

import json


# Pool of GitHub API tokens; one is picked at random for each API request.
# Tokens are read from the GITHUB_TOKENS environment variable (comma- or
# whitespace-separated) instead of being hard-coded: credentials committed to
# source control are exposed to everyone with read access. Any token that was
# ever committed to this file should be revoked on GitHub.
# Duplicates are dropped so no single token is picked more often than others.
tokens = list(dict.fromkeys(os.environ.get('GITHUB_TOKENS', '').replace(',', ' ').split()))
if not tokens:
    print("Warning: no GitHub API token configured (set GITHUB_TOKENS)")


def cmp(elem):
    """Sort key returning the first item of ``elem`` (e.g. the start of a [start, end] span)."""
    first_item = elem[0]
    return first_item


def filter_tokens(length, tokens, tags):
    """Select token positions that look like code identifiers.

    ``tokens`` and ``tags`` are space-separated strings whose items are
    presumably aligned by position (token i has tag i) — confirm against the
    upstream tokenizer output. The first token (index 0) is never selected.
    A position ``i`` is selected when its token:

    * starts with ``'@'``, or
    * is alphanumeric and not all-lowercase, and either has a tag starting
      with ``"NN"`` or is wrapped in single-quote tokens (``' Foo '``).

    Args:
        length: number of tokens to consider (upper bound of the scan).
        tokens: space-separated token string.
        tags: space-separated tag string.

    Returns:
        ``(indices, token_list)`` where ``token_list`` is ``tokens`` split on
        single spaces and ``indices`` are the selected positions in it.
    """
    token_list = tokens.split(' ')
    tag_list = tags.split(' ')
    # ``length`` comes from the data and may disagree with the split result;
    # clamp it so a mismatch cannot raise IndexError and kill a Pool worker.
    upper = min(length, len(token_list))

    indices = []
    for i in range(1, upper):
        token = token_list[i]
        if token.startswith('@'):
            indices.append(i)
            continue
        if not token.isalnum() or token.islower():
            continue
        # A missing tag (tag list shorter than token list) counts as non-noun.
        tag = tag_list[i] if i < len(tag_list) else ''
        if tag.startswith("NN"):
            indices.append(i)
        elif token_list[i - 1] == "'" and i + 1 < len(token_list) and token_list[i + 1] == "'":
            indices.append(i)

    return indices, token_list


def request(url, timeout=30):
    """Fetch the list of changed files of a GitHub commit.

    ``url`` is a commit web URL (``https://github.com/<owner>/<repo>/commit/<sha>``);
    it is rewritten to the REST endpoint ``https://api.github.com/repos/<owner>/<repo>/commits/<sha>``.
    A random token from the module-level ``tokens`` pool authenticates the
    request when the pool is non-empty.

    NOTE(review): GitHub may paginate/truncate ``files`` for very large
    commits; only the first response page is used here — confirm if needed.

    Args:
        url: commit web URL.
        timeout: seconds to wait for the server before giving up.

    Returns:
        The ``files`` list from the JSON response, or ``None`` on any failure
        (non-200 status, network error/timeout, malformed response).
    """
    api_url = url.replace('https://github.com/', 'https://api.github.com/repos/').replace('/commit/', '/commits/')
    headers = {'Accept': 'application/vnd.github.v3.raw'}
    if tokens:
        # random.choice is uniform over the pool, unlike randint(0, 50) % len(tokens),
        # and cannot divide by zero when the pool is empty.
        headers["Authorization"] = f"token {random.choice(tokens)}"
    try:
        # Without a timeout a stalled connection would block this worker forever.
        response = requests.get(api_url, headers=headers, timeout=timeout)
        if response.status_code != 200:
            print(f"Failed to retrieve commit details ({response.status_code}) for {api_url}")
            return None
        files = response.json()['files']
    except Exception as e:
        # Best-effort: the caller decides whether to retry.
        print(f"{api_url}: {e}")
        return None
    print(f"{api_url} success retrieve")
    return files


def search_in_patches(url, indices, tokens, max_attempts=5, retry_delay=1.0):
    """Find which candidate tokens occur in the patches of a GitHub commit.

    The commit's file list is fetched with ``request``; failed fetches are
    retried with exponential backoff up to ``max_attempts`` times. If every
    attempt fails, the commit is skipped and ``([], [])`` is returned. Retries
    are bounded so that permanent failures (e.g. a deleted repository) cannot
    hang the worker.

    Args:
        url: commit web URL passed to ``request``.
        indices: positions in ``tokens`` to check.
        tokens: token list (as returned by ``filter_tokens``).
        max_attempts: maximum number of fetch attempts.
        retry_delay: base delay in seconds; doubled after each failure (capped at 30s).

    Returns:
        ``(found_indices, found_tokens)``: the checked positions whose token
        appears in at least one patch, and the distinct matched tokens in
        first-seen order. A token surrounded by ``'`` tokens is returned
        quoted, e.g. ``"'Foo'"``.
    """
    import time

    files = None
    for attempt in range(max_attempts):
        files = request(url)
        if files is not None:
            break
        if attempt + 1 < max_attempts:
            time.sleep(min(retry_delay * (2 ** attempt), 30.0))
    if files is None:
        print(f"{url}: giving up after {max_attempts} attempts")
        return [], []

    # Files without a textual diff (e.g. binaries) carry no 'patch' key.
    patches = [str(file['patch']) for file in files if 'patch' in file]

    found_indices = []
    found_tokens = []
    for index in indices:
        token = tokens[index]
        if not any(token in patch for patch in patches):
            continue
        quoted = 0 < index < len(tokens) - 1 and tokens[index - 1] == "'" and tokens[index + 1] == "'"
        found_tokens.append("'" + token + "'" if quoted else token)
        found_indices.append(index)

    # De-duplicate while keeping a deterministic (first-seen) order.
    return found_indices, list(dict.fromkeys(found_tokens))


def escape(message, replacement):
    """Return the ``[start, end)`` spans of every occurrence of ``replacement`` in ``message``.

    Occurrences are found left to right without overlapping each other.
    An empty ``replacement`` yields no spans: ``str.find('')`` matches at every
    position without advancing, so scanning for it would never terminate.

    Returns:
        A list of ``[start, end]`` pairs.
    """
    text = str(message)
    if not replacement:
        return []
    spans = []
    pos = text.find(replacement)
    while pos > -1:
        end = pos + len(replacement)
        spans.append([pos, end])
        pos = text.find(replacement, end)
    return spans


def get_unreplacable(message, replacement):
    """Return every character index of ``message`` covered by an occurrence of ``replacement``.

    Occurrences are found left to right without overlapping each other.
    An empty ``replacement`` covers nothing and returns ``[]``: ``str.find('')``
    matches at every position without advancing, so scanning for it would
    never terminate.

    Returns:
        A flat list of covered indices, in increasing order.
    """
    text = str(message)
    if not replacement:
        return []
    covered = []
    pos = text.find(replacement)
    while pos > -1:
        end = pos + len(replacement)
        covered.extend(range(pos, end))
        pos = text.find(replacement, end)
    return covered


def replace_tokens(message, tokens):
    """Replace whole-word occurrences of ``tokens`` in ``message`` with ``<iden>``.

    An occurrence is replaced only when:

    * the characters immediately before and after it are not alphanumeric
      (or it touches the start/end of the message), and
    * it does not overlap an existing placeholder such as ``<file_name>``.

    Nested or overlapping matches are merged. Matches separated by at most one
    character (e.g. ``Foo.Bar``) also collapse into a single ``<iden>``.

    Args:
        message: the commit message text.
        tokens: candidate identifier strings; empty strings are ignored.

    Returns:
        The message with each merged span replaced by ``<iden>``.
    """
    text = str(message)
    placeholders = ['<file_name>', '<version>', '<url>', '<enter>', '<tab>', '<issue_link>', '<pr_link>',
                    '<otherCommit_link>', '<method_name>']

    # Character positions already covered by a placeholder. A token such as
    # "otherCommit" must not be replaced inside "<otherCommit_link>".
    protected = set()
    for placeholder in placeholders:
        pos = text.find(placeholder)
        while pos > -1:
            end = pos + len(placeholder)
            protected.update(range(pos, end))
            pos = text.find(placeholder, end)

    # Collect [start, end) spans of whole-word, unprotected token occurrences.
    spans = []
    for token in tokens:
        if not token:
            continue  # find('') always matches and would never advance
        pos = text.find(token)
        while pos > -1:
            end = pos + len(token)
            starts_word = pos == 0 or not text[pos - 1].isalnum()
            ends_word = end == len(text) or not text[end].isalnum()
            if starts_word and ends_word and protected.isdisjoint(range(pos, end)):
                spans.append((pos, end))
            pos = text.find(token, end)

    # Merge spans that contain/overlap each other, or are separated by at most
    # one character, into maximal runs. A single sweep over sorted spans
    # always advances, so duplicates and partial overlaps cannot stall it.
    merged = []
    for start, end in sorted(spans):
        if merged and start - merged[-1][1] < 2:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])

    # Replace each merged span with the <iden> placeholder.
    pieces = []
    cursor = 0
    for start, end in merged:
        pieces.append(text[cursor:start])
        pieces.append("<iden>")
        cursor = end
    pieces.append(text[cursor:])
    return "".join(pieces)


def find_files(start_path='.', pattern='*.jsonl'):
    """Recursively collect files under ``start_path`` whose names match ``pattern``.

    Every directory yielded by ``os.walk(start_path)`` (including
    ``start_path`` itself) is globbed with ``pattern``. The directory part is
    passed through ``glob.escape`` so that directory names containing glob
    metacharacters (``[``, ``*``, ``?``) are matched literally instead of
    being interpreted as patterns.

    Args:
        start_path: root directory to search.
        pattern: glob pattern for file names (default ``*.jsonl``).

    Returns:
        A list of matching paths, each joined onto ``start_path``.
    """
    file_paths = []
    for root, _dirs, _files in os.walk(start_path):
        file_paths.extend(glob.glob(os.path.join(glob.escape(root), pattern)))
    return file_paths


# TODO: in messages such as "Added #finished(long nanos, Description description)", the method name is not recognized.
# if __name__ == "__main__":
#     input_path = './tmp_step2_mcmd/' + lang
#     output_path = './tmp_step3_mcmd/' + lang
#
#     file_paths = find_files(input_path)
#
#     for file_path in file_paths:
#
#         data = json.load(open(file_path))
#         new_samples = []
#         for row in data:
#             url = "https://github.com/" + row['repo'].replace('\n', '') + '/commit/' + row['sha'].replace('\n', '')
#             new_message = row['msg']
#             length = row['allennlp_len']
#             tokens = row['allennlp_tokens']
#             tags = row['allennlp_tags']
#             if len(new_message) > 0:
#                 indices, tokens = filter_tokens(length, tokens, tags)
#                 if len(indices) > 0:
#                     fount_indices, found_tokens = search_in_patches(url, indices, tokens)
#                     if len(fount_indices) > 0:
#                         new_message = replace_tokens(new_message, found_tokens)
#
#             new_message.replace('<enter>', '$enter').replace('<tab>', '$tab').\
#             replace('<url>', '$url').replace('<version>', '$versionNumber')\
#             .replace('<pr_link>','$pullRequestLink>').replace('<issue_link >','$issueLink')\
#             .replace('<otherCommit_link>','$otherCommitLink').replace("<method_name>","$methodName")\
#             .replace("<file_name>","$fileName").replace("<iden>","$token")
#
#             row['msg'] = new_message
#
#             new_samples.append(row)
#
#         if not os.path.exists(output_path):
#             os.makedirs(output_path)
#         dir_path = os.path.join(output_path, file_path.split('/')[-1])
#         with open(dir_path, "w") as file:
#             json.dump(new_samples, file)


# Placeholder marker -> final name written to the dataset, applied in order.
_PLACEHOLDER_NAMES = (
    ('<enter>', '$enter'),
    ('<tab>', '$tab'),
    ('<url>', '$url'),
    ('<version>', '$versionNumber'),
    ('<pr_link>', '$pullRequestLink'),
    ('<issue_link>', '$issueLink'),
    # Spaced variant kept from the earlier mapping in case input contains it
    # (presumably a typo of '<issue_link>' — confirm and drop if unused).
    ('<issue_link >', '$issueLink'),
    ('<otherCommit_link>', '$otherCommitLink'),
    ('<method_name>', '$methodName'),
    ('<file_name>', '$fileName'),
    ('<iden>', '$token'),
)


def _rename_placeholders(message):
    """Rewrite every ``<placeholder>`` marker in ``message`` to its ``$name`` form."""
    for placeholder, name in _PLACEHOLDER_NAMES:
        message = message.replace(placeholder, name)
    return message


def filter_data(rows):
    """Anonymize code identifiers in each row's commit message.

    For every row this:

    1. builds the commit URL from ``owner``/``repo``/``sha``;
    2. picks candidate identifier tokens with ``filter_tokens``;
    3. keeps those that occur in the commit's patches (``search_in_patches``);
    4. replaces them with ``<iden>`` (``replace_tokens``);
    5. renames all ``<placeholder>`` markers to their ``$name`` form.

    Rows are updated in place (``commit_message``) and returned in a new list.
    """
    new_samples = []
    for row in rows:
        # Source fields may carry stray newlines; strip them from every URL part.
        owner = row['owner'].replace('\n', '')
        repo = row['repo'].replace('\n', '')
        sha = row['sha'].replace('\n', '')
        url = f"https://github.com/{owner}/{repo}/commit/{sha}"

        new_message = row['commit_message']
        length = row['allennlp_len']
        # Named msg_tokens to avoid confusion with the module-level GitHub token pool.
        msg_tokens = row['allennlp_tokens']
        tags = row['allennlp_tags']
        if len(new_message) > 0:
            indices, msg_tokens = filter_tokens(length, msg_tokens, tags)
            if len(indices) > 0:
                found_indices, found_tokens = search_in_patches(url, indices, msg_tokens)
                if len(found_indices) > 0:
                    new_message = replace_tokens(new_message, found_tokens)

        # str.replace returns a new string; the result must be assigned back.
        row['commit_message'] = _rename_placeholders(new_message)
        new_samples.append(row)
    return new_samples


if __name__ == "__main__":
    input_path = './tmp_step2_mcmd/new_train.jsonl'
    output_path = './tmp_step3_mcmd/new_train.jsonl'
    # Rows before this offset are skipped. Presumably they were processed by an
    # earlier run whose output is already in output_path — confirm before
    # re-running, since results are appended.
    start_offset = 15000
    num_processes = 8

    # Despite the .jsonl suffix, both files are read/written with json.load /
    # json.dump, i.e. as a single JSON document (not line-delimited JSON).
    with open(input_path, encoding='utf-8') as f:
        data = json.load(f)[start_offset:]

    # Size of the data subset handled by each worker process.
    chunk_size = len(data) // num_processes + 1
    data_subsets = [data[i:i + chunk_size] for i in range(0, len(data), chunk_size)]

    with Pool(processes=num_processes) as pool:
        # Filter the subsets concurrently.
        results = pool.map(filter_data, data_subsets)

    # Flatten the per-process results.
    filtered_data = [item for sublist in results for item in sublist]

    # Append to results from earlier runs; start fresh if there are none yet.
    if os.path.exists(output_path):
        with open(output_path, encoding='utf-8') as f:
            cur_data = json.load(f)
    else:
        cur_data = []
    cur_data.extend(filtered_data)

    # Write to a temp file and swap it in with os.replace, so a crash mid-write
    # cannot destroy the previously accumulated output.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    tmp_path = output_path + '.tmp'
    with open(tmp_path, "w", encoding='utf-8') as file:
        json.dump(cur_data, file)
    os.replace(tmp_path, output_path)

