import json
import os
import glob
from nltk import word_tokenize
from nltk import pos_tag
# from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor
# from allennlp_models.pretrained import load_predictor

# Language label. Nothing in the visible part of this file reads it.
# NOTE(review): presumably the programming language of the commit dataset
# being processed — confirm, or remove if unused.
lang = 'java'


def nltk_tag(message):
    """POS-tag *message* with NLTK and return the tags as one string.

    The message is tokenized with ``word_tokenize`` and re-joined with single
    spaces.  Four placeholders that end up spaced out in that form
    (``< method_name >``, ``< file_name >``, ``< url >``, ``< version >``)
    are collapsed back to their compact ``<name>`` spelling before the text
    is split on spaces and handed to ``pos_tag``.

    The space-joined tokens and tags are printed; only the tag string is
    returned.
    """
    spaced = ' '.join(word_tokenize(message))

    # (spaced-out form, compact form) pairs, applied in this order.
    placeholder_repairs = (
        ('< method_name >', '<method_name>'),
        ('< file_name >', '<file_name>'),
        ('< url >', '<url>'),
        ('< version >', '<version>'),
    )
    for broken, compact in placeholder_repairs:
        spaced = spaced.replace(broken, compact)

    tagged = pos_tag(spaced.split(' '))
    token_line = ' '.join(word for word, _ in tagged)
    tag_line = ' '.join(label for _, label in tagged)

    print(token_line)
    print(tag_line)
    return tag_line


# Names recognised as ``<name>`` placeholder tokens.
_PLACEHOLDER_NAMES = (
    'file_name', 'version', 'url', 'enter', 'tab', 'iden',
    'issue_link', 'pr_link', 'otherCommit_link', 'method_name',
)
# Tag given to every placeholder and to every piece split off a token that
# had a placeholder fragment attached.
_PLACEHOLDER_TAG = 'XX'


def _split_glued_placeholder(token):
    """Separate placeholder fragments fused to the edges of *token*.

    A leading ``name>`` is split into ``name`` and ``>``; a trailing
    ``<name`` is split into ``<`` and ``name``.  Any text left between them
    is kept as its own piece, unless it is empty.

    Returns:
        The list of pieces in their original order, or ``None`` when the
        token has no placeholder fragment at either edge.
    """
    head = []
    tail = []
    core = token

    for name in _PLACEHOLDER_NAMES:
        prefix = name + '>'
        if core.startswith(prefix):
            head = [name, '>']
            # Slice instead of str.replace so only the leading fragment goes.
            core = core[len(prefix):]
            break

    for name in _PLACEHOLDER_NAMES:
        suffix = '<' + name
        if core.endswith(suffix):
            tail = ['<', name]
            core = core[:-len(suffix)]
            break

    if not head and not tail:
        return None
    # Drop an empty remainder so no blank token reaches the output.
    middle = [core] if core else []
    return head + middle + tail


def allennlp_tag(message, predictor):
    """Tag *message* with *predictor* and rebuild ``<name>`` placeholders.

    The predictor's tokens may contain placeholders broken into pieces
    (``<``, ``name``, ``>``) or fused to a neighbouring word (``name>word``,
    ``word<name``).  Fused tokens are first split apart, then every
    ``<`` + known name + ``>`` run is merged into a single ``<name>`` token
    tagged ``XX``.  Pieces split off a fused token are also tagged ``XX``.

    Args:
        message: Text passed to ``predictor.predict``.
        predictor: Object whose ``predict(message)`` returns a mapping with
            ``'tokens'`` and ``'pos_tags'`` sequences (assumed to be aligned
            one-to-one).

    Returns:
        A ``(tokens, tags, length)`` tuple: the space-joined tokens, the
        space-joined tags, and the number of tokens after merging.

    The joined tokens and tags are also printed after a separator line.
    """
    result = predictor.predict(message)

    # Pass 1: pull placeholder fragments off the tokens they are fused to.
    tokens = []
    tags = []
    for raw_token, raw_tag in zip(result['tokens'], result['pos_tags']):
        pieces = _split_glued_placeholder(str(raw_token))
        if pieces is None:
            tokens.append(raw_token)
            tags.append(raw_tag)
        else:
            tokens.extend(pieces)
            tags.extend([_PLACEHOLDER_TAG] * len(pieces))

    # Pass 2: merge each '<', name, '>' triple into one '<name>' token.
    merged_tokens = []
    merged_tags = []
    count = len(tokens)
    i = 0
    while i < count:
        is_placeholder = (
            i + 2 < count
            and tokens[i] == '<'
            and tokens[i + 1] in _PLACEHOLDER_NAMES
            and tokens[i + 2] == '>'
        )
        if is_placeholder:
            merged_tokens.append(tokens[i] + tokens[i + 1] + tokens[i + 2])
            merged_tags.append(_PLACEHOLDER_TAG)
            i += 3
        else:
            merged_tokens.append(tokens[i])
            merged_tags.append(tags[i])
            i += 1

    length = len(merged_tokens)
    # Join tokens (and tags) with single spaces.
    tokens = ' '.join(merged_tokens)
    tags = ' '.join(merged_tags)
    print('----------------------------------------------------------------------')
    print(tokens)
    print(tags)
    return tokens, tags, length


def find_files(start_path='.'):
    """Recursively collect the paths of all ``*.jsonl`` entries under *start_path*.

    Args:
        start_path: Directory to walk; defaults to the current directory.

    Returns:
        A list of matching paths, each built from the directory in which
        the entry was found.
    """
    file_paths = []
    for root, _dirs, _files in os.walk(start_path):
        # Escape the directory part so that names containing glob
        # metacharacters ('[', '*', '?') are matched literally instead of
        # being interpreted as a pattern.
        pattern = os.path.join(glob.escape(root), '*.jsonl')
        file_paths.extend(glob.glob(pattern))
    return file_paths


if __name__ == "__main__":
    # Load the constituency-parser predictor from the local archive.
    # Alternative loading approaches are kept below for reference.
    # (-1 means run on CPU.  Use constituency parsing when the phrase
    # structure of a sentence is needed, and dependency parsing when the
    # dependency relations between words are needed.)
    # archive = load_archive('./tools/elmo-constituency-parser-2020.02.10.tar.gz', -1)
    # predictor = Predictor.from_archive(archive, 'constituency-parser')
    predictor = Predictor.from_path('./tools/elmo-constituency-parser-2020.02.10.tar.gz')
    # predictor = load_predictor("./tools/elmo-constituency-parser-2020.02.10.tar.gz")

    input_path = './tmp_step1_mcmd'
    output_path = './tmp_step2_mcmd'

    file_paths = find_files(input_path)

    for file_path in file_paths:
        # Each input file is read as a single JSON document holding an
        # iterable of records; the context manager closes the handle.
        with open(file_path, encoding='utf-8') as source:
            data = json.load(source)

        new_samples = []
        for row in data:
            message = row['commit_message']
            tokens, tags, length = allennlp_tag(message, predictor)
            row['allennlp_len'] = length
            row['allennlp_tokens'] = tokens
            row['allennlp_tags'] = tags
            new_samples.append(row)
            print(row)

        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(output_path, exist_ok=True)
        # basename handles the platform's path separator, unlike split('/').
        # Outputs are flattened into output_path under the input file's name.
        target_path = os.path.join(output_path, os.path.basename(file_path))
        with open(target_path, "w") as file:
            json.dump(new_samples, file)
