import json
import os
import torch
import argparse
import numpy as np
import torch.nn as nn

from tqdm import tqdm
from torch import optim
from PathRanker import PathRanker
from TestDataset import TestDataset
from torch.utils.data import DataLoader


def collate_fn(batch):
    """Regroup a list of per-sample sequences into per-field lists.

    Each sample is indexed positionally: fields 0, 1, 2 and 4 are copied
    through unchanged, while field 3 (the candidate relation KGE ids) is
    normalised to exactly two entries and paired with a mask.

    Args:
        batch: Iterable of indexable samples with at least five fields.

    Returns:
        A 6-tuple of lists, in this order: relation phrases, questions,
        candidate relation path names, candidate relation path ids,
        candidate relation path masks, candidate relation path urls.
    """
    phrases = []
    questions = []
    local_names = []
    kge_ids = []
    kge_masks = []
    raw_names = []

    for sample in batch:
        phrases.append(sample[0])
        questions.append(sample[1])
        local_names.append(sample[2])
        raw_names.append(sample[4])

        ids = sample[3]
        if len(ids) != 2:
            # Not a two-entry id list: repeat the first id to fill the second
            # slot and zero that slot out in the mask.
            kge_ids.append([ids[0], ids[0]])
            kge_masks.append([1.0, 0.0])
            continue
        # Two ids: keep them as-is, both mask positions active.
        kge_ids.append(ids)
        kge_masks.append([1.0, 1.0])

    return phrases, questions, local_names, kge_ids, kge_masks, raw_names


def evaluate(relations_predict, relations_goldstandard):
    """Compute precision and recall of predicted relations against gold ones.

    Both inputs are compared as sets, so duplicates on either side do not
    affect the scores. (Previously recall was divided by the raw length of
    the gold list, so a gold list containing duplicates could never reach a
    recall of 1.0 even for a perfect prediction.)

    Args:
        relations_predict: Iterable of hashable predicted relations.
        relations_goldstandard: Iterable of hashable gold relations.

    Returns:
        A ``(precision, recall)`` tuple:
        * ``(100, 100)`` when there are no gold relations -- a sentinel
          value outside the normal [0, 1] range that callers can test for;
        * ``(0, 0)`` when there are gold relations but no predictions;
        * otherwise two floats in [0, 1].
    """
    predicted = set(relations_predict)
    gold = set(relations_goldstandard)

    if not gold:
        # Nothing to find: report the out-of-range sentinel.
        return 100, 100
    if not predicted:
        return 0, 0

    # The intersection is shared by both ratios, so compute it once.
    hits = len(predicted & gold)
    p_relation = hits * 1.0 / len(predicted)
    r_relation = hits * 1.0 / len(gold)
    return p_relation, r_relation


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` is not usable with argparse because ``bool('False')`` is
    True (any non-empty string is truthy); this parser reads the text instead.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False, matching what ``bool('')`` used to give.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def _strip_minus_one_suffix(relation):
    """Return *relation* without its last two characters when they are '-1'."""
    return relation[:-2] if relation[-2:] == '-1' else relation


if __name__ == '__main__':
    dataset_name = 'PathSQ'
    parser = argparse.ArgumentParser()
    # NOTE(review): default=True together with action='store_true' means this
    # option is always True; left unchanged to keep the CLI as it is.
    parser.add_argument('--freeze_bert', default=True, action='store_true')
    parser.add_argument('--max_len', type=int, default=8)
    parser.add_argument('--hidden_size', type=int, default=768)
    parser.add_argument('--kge_size', type=int, default=200)
    parser.add_argument('--bert_size', type=int, default=768)
    parser.add_argument('--batch_size', type=int, default=30)
    parser.add_argument('--max_eps', type=int, default=100)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--print_step', type=int, default=5)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--fusion_method', type=str, default='gate')
    # Real boolean parsing: with type=bool, "--is_train False" evaluated to True.
    parser.add_argument('--is_train', type=_str2bool, default=False)
    parser.add_argument('--is_test', type=_str2bool, default=True)
    parser.add_argument('--test_data_path', default='../../experiments/test/' + dataset_name + '.json')
    parser.add_argument('--save_model_path', default='/DATA_PATH/ImRL/experiments/models/model_base_' + dataset_name +'.pt')
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    parser.add_argument('--device', default=torch.device('cuda'))
    config = parser.parse_args()

    # model
    model = PathRanker(config)
    dataset = TestDataset(config=config, path=config.test_data_path)
    data_loader = DataLoader(dataset=dataset, shuffle=False, batch_size=config.batch_size, collate_fn=collate_fn)
    # Load the weights onto the CPU first so a checkpoint saved on a device
    # that is not visible here still loads; the model is moved afterwards.
    model.load_state_dict(torch.load(config.save_model_path, map_location='cpu'))
    model.eval()
    model.to(config.device)

    # One selected entry of candidate_relation_path_urls per batch.
    result = []
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        for (relation_phrases, questions, candidate_relation_path_names, candidate_relation_path_ids,
             candidate_relation_path_masks, candidate_relation_path_urls) in tqdm(data_loader):
            pred, _, _ = model(relation_phrases, questions, candidate_relation_path_names,
                               candidate_relation_path_ids, candidate_relation_path_masks, -1)
            # NOTE(review): .item() requires the argmax over dim=1 to hold a
            # single element, i.e. pred must have one row -- confirm against
            # PathRanker's output shape.
            best_index = torch.max(pred, dim=1)[1].item()
            result.append(candidate_relation_path_urls[best_index])

    # Read the gold annotations from the same file the predictions were made
    # on, rather than from a hard-coded path that ignores --test_data_path.
    with open(config.test_data_path, 'r', encoding='utf-8') as f:
        data_test = json.load(f)

    cal_results = []
    cnt = 0  # index into `result`; advanced once per entry of item['triples']
    ps = []
    rs = []
    for item in data_test:
        golds = [_strip_minus_one_suffix(go) for go in item['gold']]
        predicts = []
        for _ in item['triples']:
            predicts.extend(_strip_minus_one_suffix(rel) for rel in result[cnt])
            cnt += 1
        cal_results.append({
            'gold': golds,
            'predict': predicts
        })
        p, r = evaluate(predicts, golds)
        # evaluate() reports 100 when there are no gold relations; such items
        # are written to the results file but left out of the averages.
        if p == 100 or r == 100:
            continue
        ps.append(p)
        rs.append(r)

    if ps:
        macro_p = np.mean(ps)
        macro_r = np.mean(rs)
    else:
        # No scorable item: avoid np.mean([]) returning nan with a warning.
        macro_p = macro_r = 0.0
    score_sum = macro_p + macro_r
    # Guard the harmonic mean against a zero denominator.
    f1 = 2 * macro_p * macro_r / score_sum if score_sum > 0 else 0.0
    print(macro_p, macro_r, f1)

    results_path = '../../experiments/results/' + dataset_name + '.json'
    # Create the output directory if it is missing instead of failing on open().
    os.makedirs(os.path.dirname(results_path), exist_ok=True)
    with open(results_path, 'w') as f:
        json.dump(cal_results, f, indent=4)

    print(cnt)