import torch
import os
import glob
import torch.nn as nn
import numpy as np
import json
from transformers import BertTokenizer, BertModel
from multiprocessing import Pool


# Seed NumPy and PyTorch (CPU) up front so every run of this script is reproducible.
for _seed_rng in (np.random.seed, torch.manual_seed):
    _seed_rng(0)
del _seed_rng

# Module-wide flag: True when PyTorch can see a CUDA device.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    # Seed CUDA explicitly as well (recent PyTorch already does this in torch.manual_seed;
    # kept for explicitness).
    torch.cuda.manual_seed(0)


class ModelConfig:
    """Static settings for the BERT-BiLSTM commit-message filter.

    Every setting is a plain class attribute, so ``ModelConfig().x`` and
    ``ModelConfig.x`` give the same value.
    """

    # --- classifier architecture (used to build the model whose weights are
    #     loaded from ``checkpoint_path``) ---
    hidden_dim: int = 384
    n_layers: int = 2
    bidirectional: bool = True
    output_size: int = 2
    drop_prob: float = 0.55

    # --- runtime ---
    batch_size: int = 64  # not read by the per-row inference path in this script
    use_cuda: bool = USE_CUDA

    # --- filesystem locations (relative to the working directory) ---
    bert_path: str = '../../checkpoint/bert'
    checkpoint_path: str = '../../checkpoint/bert-bilstm-classifer-checkpoint/bert_bilstm_what.pth'
    data_input_path: str = './tmp_step4_mcmd'
    src_path: str = './tmp_step4_mcmd'
    output_path: str = './tmp_step5_mcmd'


class bert_lstm(nn.Module):
    """BERT encoder followed by an (optionally bidirectional) LSTM and a linear head.

    Token ids are encoded by BERT. The LSTM runs over BERT's last-layer token
    states. The final hidden state of the LSTM's top layer (forward and backward
    states concatenated when bidirectional) goes through dropout and a linear
    layer that produces ``output_size`` raw logits.
    """

    def __init__(self, bertpath, hidden_dim, output_size, n_layers, bidirectional=True, drop_prob=0.5):
        """Build the model.

        Args:
            bertpath: path/name passed to ``BertModel.from_pretrained``.
            hidden_dim: LSTM hidden size per direction.
            output_size: number of output logits.
            n_layers: number of stacked LSTM layers.
            bidirectional: whether the LSTM runs in both directions.
            drop_prob: dropout probability applied before the linear head.
        """
        super(bert_lstm, self).__init__()

        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.bidirectional = bidirectional

        self.bert = BertModel.from_pretrained(bertpath)
        # BERT is trained together with the head (True is the default; set explicitly).
        for param in self.bert.parameters():
            param.requires_grad = True

        # LSTM over BERT's token states. Read the width from the BERT config instead of
        # hard-coding 768, so non-base BERT checkpoints also work (768 for BERT-base).
        self.lstm = nn.LSTM(self.bert.config.hidden_size, hidden_dim, n_layers,
                            batch_first=True, bidirectional=bidirectional)

        self.dropout = nn.Dropout(drop_prob)

        # Linear classification head; it outputs raw logits (no sigmoid/softmax here).
        num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * num_directions, output_size)

    def forward(self, x, hidden=None):
        """Return logits of shape ``(batch, output_size)``.

        Args:
            x: batch-first tensor of token ids passed to BERT as ``input_ids``.
            hidden: optional ``(h0, c0)`` initial LSTM state, e.g. from
                :meth:`init_hidden`. ``None`` lets ``nn.LSTM`` start from zeros.
        """
        # Element 0 of BERT's output is the last layer's per-token hidden states.
        sequence_output = self.bert(x)[0]
        _, (hidden_last, _) = self.lstm(sequence_output, hidden)

        if self.bidirectional:
            # hidden_last is (n_layers * 2, batch, hidden_dim). The last two entries are
            # the top layer's forward and backward final states.
            hidden_last_out = torch.cat([hidden_last[-2], hidden_last[-1]], dim=-1)
        else:
            hidden_last_out = hidden_last[-1]

        out = self.dropout(hidden_last_out)
        return self.fc(out)

    def init_hidden(self, batch_size):
        """Return a zero ``(h0, c0)`` LSTM state for ``batch_size`` sequences.

        Both tensors have shape ``(n_layers * num_directions, batch_size, hidden_dim)``,
        dtype float32, and live on the same device as the model's parameters.
        """
        weight = next(self.parameters())
        num_directions = 2 if self.bidirectional else 1
        shape = (self.n_layers * num_directions, batch_size, self.hidden_dim)
        # Follow the model's real device rather than a global CUDA flag, so a CPU model
        # on a CUDA machine (or a model on a non-default GPU) gets matching states.
        return (torch.zeros(shape, dtype=torch.float32, device=weight.device),
                torch.zeros(shape, dtype=torch.float32, device=weight.device))


# Build the tokenizer and classifier once at import time; gen() uses these module-level objects.
_model_config = ModelConfig()
tokenizer = BertTokenizer.from_pretrained(_model_config.bert_path)
model = bert_lstm(_model_config.bert_path, _model_config.hidden_dim, _model_config.output_size,
                  _model_config.n_layers, _model_config.bidirectional, _model_config.drop_prob)

# Load the fine-tuned weights onto the CPU first, then move to the GPU if one is available.
model.load_state_dict(torch.load(_model_config.checkpoint_path, map_location=torch.device('cpu')))
if _model_config.use_cuda:
    model.cuda()
# This script only runs inference. nn.Module starts in training mode, which would keep
# dropout (drop_prob in the head, plus any dropout inside BERT) active and make
# predictions random, so switch to eval mode.
model.eval()


def find_files(start_path='.'):
    """Recursively collect the paths of all ``*.jsonl`` files under *start_path*.

    Every directory visited by ``os.walk`` is globbed for ``*.jsonl``. As with
    ``glob``, file names starting with a dot are not matched.

    Args:
        start_path: root directory to search (default: current directory).

    Returns:
        list of matching file paths, each prefixed by the walked directory.
    """
    file_paths = []
    for root, _dirs, _files in os.walk(start_path):
        # Escape the directory so glob metacharacters in folder names ('[', '*', '?')
        # are matched literally instead of being treated as part of the pattern.
        file_paths.extend(glob.glob(os.path.join(glob.escape(root), '*.jsonl')))
    return file_paths


def gen(rows):
    """Keep only the rows whose commit message the classifier labels as class 1 ("good").

    For each row, placeholder tokens in ``row['commit_message']`` are rewritten into
    plain-text markers. The text is then tokenized (padded/truncated to 200 tokens)
    and scored one row at a time by the module-level ``model``.

    Args:
        rows: iterable of dicts with at least ``'commit_message'`` and ``'diff_id'`` keys.

    Returns:
        list of the kept row dicts (the same objects, in input order).
    """
    # Inference only: eval mode disables dropout so predictions are deterministic.
    # Calling it here is idempotent and guarantees the right mode whatever happened before.
    model.eval()
    use_cuda = ModelConfig().use_cuda

    # Placeholder -> plain-text marker rewrites, applied in this order.
    # NOTE(review): '$pull request>' (trailing '>') and '<issue_link >' (inner space) look
    # like typos. They are kept verbatim because the checkpoint was presumably trained
    # with the same preprocessing; confirm against the training script before changing.
    replacements = (
        ('<enter>', '$enter'),
        ('<tab>', '$tab'),
        ('<url>', '$url'),
        ('<version>', '$version'),
        ('<pr_link>', '$pull request>'),
        ('<issue_link >', '$issue'),
        ('<otherCommit_link>', '$other commit'),
        ('<method_name>', '$method'),
        ('<file_name>', '$file'),
        ('<iden>', '$token'),
    )

    new_samples = []
    for row in rows:
        new_message = row['commit_message']
        for placeholder, marker in replacements:
            new_message = new_message.replace(placeholder, marker)

        inputs = tokenizer(
            new_message,
            return_tensors='pt',
            padding='max_length',  # current spelling of the deprecated pad_to_max_length=True
            max_length=200,
            truncation=True,
        )

        if use_cuda:
            inputs = {k: v.to('cuda') for k, v in inputs.items()}

        with torch.no_grad():
            hidden = model.init_hidden(batch_size=1)
            # NOTE(review): only input_ids are passed, so BERT also attends to padding
            # tokens (no attention_mask). bert_lstm.forward has no mask argument; presumably
            # training did the same. Confirm before adding one.
            output = model(inputs['input_ids'], hidden)

        # Decide on the raw logits. Applying sigmoid first cannot improve the argmax, and
        # in float32 it can saturate both logits to 1.0, turning the result into a tie
        # that argmax resolves to class 0.
        pred = torch.argmax(output, dim=1).item()

        if pred == 1:
            print(f"diff-{row['diff_id']} is good message")
            new_samples.append(row)

    return new_samples


if __name__ == '__main__':
    # NOTE(review): CUDA reads CUDA_VISIBLE_DEVICES when it first initialises. When a GPU
    # is available, the model was already moved to it during module import
    # (model.cuda()), so this line cannot change the device used by this process; it only
    # affects processes started afterwards. To pin the GPU, export the variable before
    # launching the script.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    model_config = ModelConfig()

    file_paths = find_files(model_config.data_input_path)

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(model_config.output_path, exist_ok=True)

    for file_path in file_paths:
        # os.path.basename respects the platform's path separator.
        file_name = os.path.basename(file_path)

        # Source records with the same file name under src_path, indexed by diff_id
        # (a later duplicate diff_id overrides an earlier one).
        with open(os.path.join(model_config.src_path, file_name), encoding='utf-8') as src_file:
            src_samples = json.load(src_file)
        src_map = {sample['diff_id']: sample for sample in src_samples}

        with open(file_path, encoding='utf-8') as data_file:
            data = json.load(data_file)

        print(f"{file_name}过滤前数据集长度：{len(data)}")

        # Run the classifier over every row in this process.
        filtered_data = gen(data)

        print(f"{file_name}过滤后数据集长度：{len(filtered_data)}")

        # Emit the source records for the kept diff_ids (KeyError if one is missing).
        clean_samples = [src_map[d['diff_id']] for d in filtered_data]

        print(f"{file_name}过滤后数据集长度：{len(clean_samples)}")

        with open(os.path.join(model_config.output_path, file_name), "w", encoding='utf-8') as out_file:
            json.dump(clean_samples, out_file)

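# Disabled earlier entry point, kept for reference: it classified the src_path records and ran gen() across 8 multiprocessing.Pool workers.
# if __name__ == '__main__':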
#     # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
#
#     model_config = ModelConfig()
#
#     file_paths = find_files(model_config.data_input_path)
#
#     if not os.path.exists(model_config.output_path):
#         os.makedirs(model_config.output_path)
#
#     for file_path in file_paths:
#
#         dir_path = os.path.join(model_config.src_path, file_path.split('/')[-1])
#         tmp_step3_samples = json.load(open(dir_path))
#         tmp_step3_map = {}
#
#         for sample in tmp_step3_samples:
#             tmp_step3_map[sample['diff_id']] = sample
#
#         data = json.load(open(file_path))
#         new_data = []
#         data_map = {}
#         for row in data:
#             new_data.append(tmp_step3_map[row['diff_id']])
#             data_map[row['diff_id']] = row
#
#         print(f"{file_path.split('/')[-1]}过滤前数据集长度：{len(new_data)}")
#
#         num_processes = 8
#
#         # 计算每个进程处理的数据子集的大小
#         chunk_size = len(new_data) // num_processes + 1
#         data_subsets = [new_data[i:i + chunk_size] for i in range(0, len(new_data), chunk_size)]
#
#         with Pool(processes=num_processes) as pool:
#             # 并发处理数据过滤
#             results = pool.map(gen, data_subsets)
#
#         # 合并结果
#         filtered_data = [item for sublist in results for item in sublist]
#
#         print(f"{file_path.split('/')[-1]}过滤后数据集长度：{len(filtered_data)}")
#
#         clean_samples = [data_map[d['diff_id']] for d in filtered_data]
#
#         print(f"{file_path.split('/')[-1]}过滤后数据集长度：{len(clean_samples)}")
#
#         dir_path = os.path.join(model_config.output_path, file_path.split('/')[-1])
#         with open(dir_path, "w") as file:
#             json.dump(clean_samples, file)
