import os
import json
import torch
import argparse
import numpy as np
import torch.nn as nn

from tqdm import tqdm
from torch import optim
from IODataset import IODataset
from PathRanker import PathRanker
from TestDataset import TestDataset
from torch.utils.data import DataLoader


def collate_fn(batch):
    """Collate training samples into parallel per-field lists.

    Each sample is read positionally: ``[0]`` relation phrase, ``[1]`` question,
    ``[2]`` candidate relation-path local names, ``[3]`` candidate KGE ids,
    ``[4]`` label. KGE id lists of length two are kept as-is with mask
    ``[1.0, 1.0]``; any other list is replaced by its first id repeated twice
    with mask ``[1.0, 0.0]``.

    Returns:
        Tuple ``(phrases, questions, local_names, kge_ids, kge_masks, labels,
        positive_index)``. ``positive_index`` is the batch position of the first
        sample whose label equals 1, or -1 when no such sample exists.
    """
    def _pad_to_pair(path_ids):
        # Return (ids, mask); the original list object is reused when it already has two ids.
        if len(path_ids) == 2:
            return path_ids, [1.0, 1.0]
        return [path_ids[0], path_ids[0]], [1.0, 0.0]

    phrases = [sample[0] for sample in batch]
    questions = [sample[1] for sample in batch]
    local_names = [sample[2] for sample in batch]
    padded = [_pad_to_pair(sample[3]) for sample in batch]
    kge_ids = [ids for ids, _mask in padded]
    kge_masks = [mask for _ids, mask in padded]
    labels = [sample[4] for sample in batch]
    positive_index = next(
        (position for position, sample in enumerate(batch) if sample[4] == 1),
        -1,
    )
    # relation_phrases, questions, candidate_relation_path_names, candidate_relation_path_ids,
    # candidate_relation_path_masks, labels, pos_index
    return phrases, questions, local_names, kge_ids, kge_masks, labels, positive_index


def collate_fn_test(batch):
    """Collate evaluation samples into parallel per-field lists.

    Samples are read positionally like in training, except that ``[4]`` holds
    the raw candidate relation names/URLs instead of a label. KGE id lists with
    exactly two ids keep mask ``[1.0, 1.0]``; otherwise the first id is
    duplicated and the mask is ``[1.0, 0.0]``.

    Returns:
        Tuple ``(phrases, questions, local_names, kge_ids, kge_masks,
        raw_relation_names)``.
    """
    phrases = [sample[0] for sample in batch]
    questions = [sample[1] for sample in batch]
    local_names = [sample[2] for sample in batch]

    kge_ids, kge_masks = [], []
    for sample in batch:
        path_ids = sample[3]
        has_pair = len(path_ids) == 2
        kge_ids.append(path_ids if has_pair else [path_ids[0], path_ids[0]])
        kge_masks.append([1.0, 1.0] if has_pair else [1.0, 0.0])

    raw_relation_names = [sample[4] for sample in batch]
    # relation_phrases, questions, candidate_relation_path_names, candidate_relation_path_ids,
    # candidate_relation_path_masks, candidate_relation_path_urls
    return phrases, questions, local_names, kge_ids, kge_masks, raw_relation_names


def evaluate(relations_predict, relations_goldstandard):
    """Score one question's predicted relations against its gold relations.

    Duplicates in either argument are ignored: precision and recall are both
    computed over *distinct* relation identifiers. (Previously recall divided
    by the raw gold length, so duplicated gold entries capped recall below 1
    even for a perfect prediction.)

    Args:
        relations_predict: iterable of predicted relation identifiers.
        relations_goldstandard: iterable of gold relation identifiers.

    Returns:
        ``(precision, recall)`` as fractions in [0, 1]. When the gold set is
        empty the sentinel ``(100, 100)`` is returned; it lies outside [0, 1]
        so callers can detect and skip such questions. When there are gold
        relations but no predictions the result is ``(0, 0)``.
    """
    gold = set(relations_goldstandard)
    if not gold:
        return 100, 100
    predicted = set(relations_predict)
    if not predicted:
        return 0, 0
    hits = len(predicted & gold)
    return hits * 1.0 / len(predicted), hits * 1.0 / len(gold)


if __name__ == '__main__':
    dataset_name = 'QALD9'
    sota_score = 0.0
    parser = argparse.ArgumentParser()
    parser.add_argument('--freeze_bert', default=True, action='store_true')
    parser.add_argument('--max_len', type=int, default=8)
    parser.add_argument('--hidden_size', type=int, default=768)
    parser.add_argument('--kge_size', type=int, default=200)
    parser.add_argument('--bert_size', type=int, default=768)
    parser.add_argument('--batch_size', type=int, default=30)
    parser.add_argument('--max_eps', type=int, default=200)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--print_step', type=int, default=5)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--fusion_method', type=str, default='gate')
    parser.add_argument('--is_train', type=bool, default=True)
    parser.add_argument('--is_test', type=bool, default=True)
    if dataset_name == 'PathSQ':
        parser.add_argument('--train_data_path', default='../../experiments/train/LC-QuAD_train.json')
    else:
        parser.add_argument('--train_data_path', default='../../experiments/train/' + dataset_name + '_train.json')
    parser.add_argument('--test_data_path', default='../../experiments/test/' + dataset_name + '.json')
    parser.add_argument('--save_model_dict', default='/DATA_PATH/ImRL/experiments/models/model_base_' + dataset_name +'/')
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    parser.add_argument('--device', default=torch.device('cuda'))
    config = parser.parse_args()

    if not os.path.exists(config.save_model_dict):
        os.makedirs(config.save_model_dict)

    # model
    model = PathRanker(config)
    # mapping_params = list(map(id, model.phrase_kg_encoder1.parameters()))
    # mapping_params += list(map(id, model.phrase_kg_encoder2.parameters()))
    # mapping_params += list(map(id, model.phrase_kg_encoder3.parameters()))
    # rest_params = filter(lambda x: id(x) not in mapping_params, model.parameters())

    criterion_ce = nn.CrossEntropyLoss()
    criterion_mse = nn.MSELoss()
    opti = optim.Adam(model.parameters(), lr=config.lr)
    # opti = optim.Adam([
    #     {'params': model.phrase_kg_encoder1.parameters(), 'lr': 1e-3},
    #     {'params': model.phrase_kg_encoder2.parameters(), 'lr': 1e-3},
    #     {'params': model.phrase_kg_encoder3.parameters(), 'lr': 1e-3},
    #     {'params': rest_params, 'lr': 3e-4}
    # ])

    max_p, max_r, max_f1 = 0, 0, 0

    if config.is_train:
        # train
        train_dataset = IODataset(config=config, path=config.train_data_path)
        train_data_loader = DataLoader(dataset=train_dataset, shuffle=False, batch_size=config.batch_size,
                                       collate_fn=collate_fn)
        model.train()
        model.to(config.device)
        now_loss = 99
        cnt_early = 0
        total_loss = 0
        num_batch = 0
        for epoch in range(config.max_eps):
            print("========== Epoch {} ==========".format(epoch))
            print("========== Train ==========")
            model.train()
            for _index, (relation_phrases, questions, candidate_relation_path_names, candidate_relation_path_ids, candidate_relation_path_masks, labels, pos_index) in enumerate(tqdm(train_data_loader)):
                opti.zero_grad()
                labels = torch.tensor(labels).to(config.device)
                score, p_kg, r_kg_gold = model(relation_phrases, questions, candidate_relation_path_names, candidate_relation_path_ids, candidate_relation_path_masks, pos_index)
                loss = criterion_ce(score, torch.max(labels, dim=0)[1].reshape(-1)) + criterion_mse(p_kg, r_kg_gold)
                loss.backward()
                opti.step()
                total_loss += loss.mean().item()
                num_batch += 1
            total_loss = total_loss / num_batch
            print(total_loss)
            now_loss = total_loss

            if config.is_test:
                print("========== Test ==========")
                test_dataset = TestDataset(config=config, path=config.test_data_path)
                test_data_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=config.batch_size,
                                         collate_fn=collate_fn_test)
                # model.load_state_dict(torch.load(config.save_model_path))
                model.eval()
                model.to(config.device)
                result = []
                for _index, (relation_phrases, questions, candidate_relation_path_names, candidate_relation_path_ids, candidate_relation_path_masks, candidate_relation_path_urls) in enumerate(tqdm(test_data_loader)):
                    pred, p_kg, r_kg_gold = model(relation_phrases, questions, candidate_relation_path_names, candidate_relation_path_ids, candidate_relation_path_masks, -1)
                    result.append(candidate_relation_path_urls[torch.max(pred, dim=1)[1].item()])

                with open('../../experiments/test/' + dataset_name + '.json', 'r') as f:
                    data_test = json.load(f)

                cal_results = []
                cnt = 0
                ps = []
                rs = []
                for item in data_test:
                    goldd = item['gold']
                    golds = []
                    for go in goldd:
                        if go[-2:] == '-1':
                            golds.append(go[:-2])
                        else:
                            golds.append(go)
                    predicts = []
                    for triple in item['triples']:
                        for rel in result[cnt]:
                            if rel[-2:] == '-1':
                                predicts.append(rel[:-2])
                            else:
                                predicts.append(rel)
                        cnt += 1
                    cal_results.append({
                        'gold': golds,
                        'predict': predicts
                    })
                    p, r = evaluate(predicts, golds)
                    if p == 100 or r == 100:
                        continue
                    ps.append(p)
                    rs.append(r)
                    # print(predicts, golds)
                macro_p = np.mean(ps)
                macro_r = np.mean(rs)
                if macro_p + macro_r == 0:
                    f1 = 0
                else:
                    f1 = 2 * macro_p * macro_r / (macro_p + macro_r)
                print("Precision:{}, Recall:{}, F1:{}".format(macro_p, macro_r, f1))
                if f1 > sota_score and f1 > max_f1:
                    max_f1 = max(max_f1, f1)
                    max_p = macro_p
                    max_r = macro_r
                    torch.save(model.state_dict(), config.save_model_dict + str(f1) + '.pt')

    print(max_p, max_r, max_f1)

# if config.is_test:
#     print("========== Test ==========")
#     test_dataset = IODataset(config=config, path=config.test_data_path)
#     test_data_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=config.batch_size,
#                                   collate_fn=collate_fn)
#     # model.load_state_dict(torch.load(config.save_model_path))
#     # test
#     tp = 0
#     fp = 0
#     tn = 0
#     fn = 0
#     model.eval()
#     model.to(config.device)
#     for _index, (questions, candi_relations, labels) in enumerate(tqdm(test_data_loader)):
#         labels = torch.tensor(labels).to(config.device)
#         score = model(questions, candi_relations)
#         max_index = torch.max(score, dim=1)[1]
#         if labels[max_index].item() == 0:
#             fp += 1
#             fn += 1
#         else:
#             tp += 1
#             tn += 1
#         # pred = torch.max(pred, dim=1)[1].reshape(-1)
#         # for i in range(pred.shape[0]):
#         #     if pred[i].item() == 0 and labels[i].item() == 0:
#         #         tn += 1
#         #     elif pred[i].item() == 0 and labels[i].item() == 1:
#         #         fn += 1
#         #     elif pred[i].item() == 1 and labels[i].item() == 0:
#         #         fp += 1
#         #     elif pred[i].item() == 1 and labels[i].item() == 1:
#         #         tp += 1
#     print(tp, tn, fp, fn)
#
#     if tp+fp == 0 or tp+fn == 0:
#         print("Precision:{}, Recall:{}, F1:{}".format(0, 0, 0))
#         f1 = 0
#     else:
#         precision = tp*1.0 / (tp+fp)
#         recall = tp*1.0 / (tp+fn)
#         f1 = 2*precision*recall / (precision+recall)
#         max_f1 = max(max_f1, f1)
#         print("Precision:{}, Recall:{}, F1:{}".format(precision, recall, f1))
#     torch.save(model.state_dict(), config.save_model_dict + str(f1) + '.pt')
