import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import numpy as np
from utils import *
import random
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
from interpret import show
from interpret.glassbox import ExplainableBoostingClassifier

from group_lasso import LogisticGroupLasso
from sklearn.pipeline import Pipeline

from sklearn.metrics import f1_score, precision_score, recall_score

import argparse
# Command-line interface: `--setting` accepts one of 'lr', 'grpls' or 'ebm'.
# NOTE(review): main() parses the arguments but never reads `args.setting`
# anywhere in this file — confirm whether the option is still needed.
parser = argparse.ArgumentParser()
parser.add_argument('--setting', choices=['lr', 'grpls', 'ebm'])

import torch
import torch.nn as nn


def get_rank_score(cls_prob, y):
    """Compute MRR, Hits@1 and Hits@2 for per-sample class scores.

    Args:
        cls_prob: sequence of per-sample score sequences; position ``k`` of a
            row is the score of class ``k``.
        y: sequence of gold class indices, one per row of ``cls_prob``.

    Returns:
        Tuple ``(mrr, hits_at_1, hits_at_2)``.

    Ties are resolved pessimistically: the gold class is ranked below every
    class whose score is not strictly lower than its own.
    """
    assert len(cls_prob) == len(y)

    reciprocal_ranks = []
    top1_total = 0
    top2_total = 0

    for probs, label in zip(cls_prob, y):
        scores = dict(enumerate(probs))
        gold_score = scores[label]

        # Count the competitors that the gold class strictly beats.
        beaten = sum(
            1 for cls, score in scores.items()
            if cls != label and score < gold_score
        )
        rank = len(scores) - beaten  # worst-case rank under ties
        assert 1 <= rank <= len(scores)

        reciprocal_ranks.append(1.0 / rank)
        top1_total += 1 if rank == 1 else 0
        top2_total += 1 if rank in (1, 2) else 0

    total = len(y)
    return np.mean(reciprocal_ranks), 1.0 * top1_total / total, 1.0 * top2_total / total

def check_triple(in_no, V_knw, id2embed_ent, id2embed_rel, id2rel, id2ent, noisy_or, threshold=None, show_infer_step=False, incre=False):
    """Score how strongly the known entities support the entity ``in_no``.

    For every known entity with a positive weight and every whitelisted
    relation, the vector ``entity + relation`` is compared to the embedding of
    ``in_no`` by cosine similarity, and the similarity is scaled by the
    entity's weight.

    Args:
        in_no: id of the entity to infer.
        V_knw: dict mapping known entity id -> weight; only entries with a
            weight greater than 0 take part.
        id2embed_ent: entity embeddings, indexable by entity id.
        id2embed_rel: relation embeddings, iterated in relation-id order.
        id2rel: mapping relation id -> relation name.
        id2ent: mapping entity id -> entity name (used only for printing).
        noisy_or: if true, combine the scores with ``sum_prob``; otherwise
            take the maximum. ``sum_prob`` is not defined in this file
            (presumably it comes from ``utils``).
        threshold: if given, only scores ``>= threshold`` are kept.
        show_infer_step: print each kept inference when a threshold is set.
        incre: select the Chinese relation whitelist instead of the English
            one.

    Returns:
        ``(True, combined_score)`` if at least one score was kept, otherwise
        ``(False, 0)``.
    """
    # Exactly one whitelist applies per call. The previous if/elif chain also
    # applied the English whitelist when ``incre`` was set, which rejected
    # every relation and made the incremental mode always return (False, 0).
    if incre:
        allowed_relations = ('并发症', '病因', '相关疾病')
    else:
        allowed_relations = ('focus_of', 'associated_with', 'temporally_related_to')

    in_no_prob = []
    for knw, weight in V_knw.items():
        if not weight > 0:
            continue
        for idx_rel, vec_rel in enumerate(id2embed_rel):
            rel_name = id2rel[idx_rel]
            if rel_name not in allowed_relations:
                continue

            vec1 = np.add(id2embed_ent[knw], vec_rel)
            vec2 = id2embed_ent[in_no]
            cos_sim = float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))) * weight

            if threshold is not None:
                # Written as "not >=" so that NaN scores are dropped as well.
                if not cos_sim >= threshold:
                    continue
                if show_infer_step:
                    print(f'{id2ent[knw]} + {rel_name} -> {id2ent[in_no]}')
            in_no_prob.append(cos_sim)

    if not in_no_prob:
        return False, 0
    return True, (sum_prob(in_no_prob) if noisy_or else max(in_no_prob))

def extract_features(rules, id2symp, spec_relations, id2dise, ent2id, rel2id, embeddings, samples, threshold, noisy_or):
    """Build one rule-feature matrix per sample.

    Args:
        rules: rule collection forwarded to ``get_rule_feature``.
        id2symp: mapping symptom id -> symptom name.
        spec_relations: relation names used for the symptom/relation index.
        id2dise: mapping disease id -> disease name.
        ent2id: mapping entity name -> entity id.
        rel2id: mapping relation name -> relation id.
        embeddings: object whose ``solver`` mapping holds exactly two values,
            unpacked as (entity embeddings, relation embeddings).
        samples: iterable of records, each with a ``'symptoms'`` entry.
        threshold: forwarded to ``get_rule_feature``.
        noisy_or: forwarded to ``get_rule_feature``.

    Returns:
        List with one feature array per sample, in input order.
    """
    # Reverse lookup tables.
    symp2id = {name: idx for idx, name in id2symp.items()}
    dise2id = {name: idx for idx, name in id2dise.items()}
    id2rel = {idx: name for name, idx in rel2id.items()}
    id2ent = {idx: name for name, idx in ent2id.items()}

    id2embed_ent, id2embed_rel = embeddings.solver.values()

    # Index over every (symptom entity id, relation id) pair. Neither table is
    # read again in this function, but building them looks every symptom and
    # relation up in ent2id / rel2id, so an unknown name raises KeyError here.
    sr2id, id2sr = {}, {}
    for symptom in symp2id:
        for relation in spec_relations:
            pair = (ent2id[symptom], rel2id[relation])
            sr2id[pair] = len(sr2id)
            id2sr[len(id2sr)] = pair

    features = [
        get_rule_feature(id2dise, dise2id, ent2id, id2rel, id2ent, id2embed_ent, id2embed_rel, rules, sample['symptoms'], threshold=threshold, noisy_or=noisy_or)
        for sample in tqdm(samples)
    ]

    print(f'features:{np.asarray(features).shape}')

    return features


def get_rule_feature(id2dise, dise2id, ent2id, id2rel, id2ent, id2embed_ent, id2embed_rel, rules, symptoms, threshold, noisy_or):
    """Turn every rule into a signed, confidence-weighted disease column.

    Each rule is read as ``(premise names, [conclusion name, ...], confidence)``.
    Its column holds ``confidence`` at the conclusion's disease index and is
    scaled by a coefficient:

    * ``-1`` when any premise symptom is recorded as ``'False'``;
    * ``1`` when every premise symptom is already recorded;
    * otherwise the minimum premise value after the missing premises have been
      scored with ``check_triple``.

    Args:
        id2dise: mapping disease id -> name (only its length is used).
        dise2id: mapping disease name -> index.
        ent2id: mapping entity name -> entity id.
        id2rel, id2ent, id2embed_ent, id2embed_rel: forwarded to
            ``check_triple``.
        rules: mapping rule key -> rule triple as described above.
        symptoms: mapping symptom name -> ``'True'`` or ``'False'``.
        threshold, noisy_or: forwarded to ``check_triple``.

    Returns:
        Array of shape ``(len(id2dise), len(rules))``; an empty 1-D array when
        there are no rules.
    """
    columns = []
    for rule_key in rules:
        # Rebuilt for every rule because inferred premises are written into it.
        known = {
            ent2id[name]: {'True': 1.0, 'False': -1.0}[flag]
            for name, flag in symptoms.items()
        }
        column = np.zeros(len(id2dise))

        rule = rules[rule_key]
        premise_ids = [ent2id[name] for name in rule[0]]
        head = rule[1][0]
        strength = rule[2]
        column[int(dise2id[head])] = strength

        contradicted = any(known.get(pid) == -1 for pid in premise_ids)
        unknown = set(premise_ids) - set(known.keys())

        if contradicted:
            scale = -1
        elif not unknown:
            scale = 1
        else:
            # Estimate each missing premise by link prediction; later
            # estimates see the earlier ones through `known`.
            for in_no in unknown:
                _, prob = check_triple(in_no, known, id2embed_ent, id2embed_rel, id2rel, id2ent, threshold=threshold, noisy_or=noisy_or, show_infer_step=False)
                known[in_no] = prob
            scale = min(known[pid] for pid in premise_ids)

        columns.append(column * scale)

    return np.transpose(np.asarray(columns))


def train_and_valid(model, X_train, y_train, X_test, y_test):
    """Fit ``model`` and return its Hits@1 on the held-out data.

    Args:
        model: estimator exposing ``fit(X, y)`` and ``predict_proba(X)``.
        X_train, y_train: training features and labels.
        X_test, y_test: evaluation features and labels.

    Returns:
        Hits@1 as computed by ``get_rank_score``.

    Raises:
        KeyError: if the model exposes ``classes_`` and ``y_test`` contains a
            label that is not in it.
    """
    model.fit(X_train, y_train)

    # get_rank_score takes (per-class scores, gold column indices). The
    # previous call passed the model itself plus a third argument, which
    # raised TypeError on every invocation.
    cls_prob = model.predict_proba(X_test)

    # scikit-learn style estimators order the predict_proba columns by
    # `classes_`; translate labels to column indices when it is available.
    classes = getattr(model, 'classes_', None)
    if classes is None:
        targets = list(y_test)
    else:
        column_of = {label: col for col, label in enumerate(classes)}
        targets = [column_of[label] for label in y_test]

    _, hits_1, _ = get_rank_score(cls_prob, targets)
    return hits_1


def test_model(model, data, output_path='../PERFORMANCE_CogRepre_NeuroNet.json'):
    """Evaluate ``model`` on ``data``, print the metrics and dump them as JSON.

    Args:
        model: module called as ``model(inputs, rule_features)`` whose output
            has one row of class scores per call.
        data: iterable of ``(inputs, targets, rule_features)`` triples, where
            ``targets`` is a list of gold class indices.
        output_path: file the metric dictionary is written to. The default is
            the path this function always used, so each call with the default
            overwrites the result of the previous one.

    Returns:
        The same ``model`` (left in eval mode).
    """
    model.eval()

    y_test, y_pred, cls_prob = [], [], []
    # Gradients are not needed for evaluation; skip autograd bookkeeping.
    with torch.no_grad():
        for inputs, targets, rule_features in data:
            inputs = torch.tensor(inputs)
            rule_features = torch.FloatTensor(rule_features)
            y_test.extend(targets)

            output = model(inputs, rule_features).detach().cpu()
            cls_prob.append(list(output.numpy()[0]))
            y_pred.extend(list(torch.max(output, 1).indices.numpy()))

    mrr, hits_1, hits_2 = get_rank_score(cls_prob, y_test)

    macro_f1 = f1_score(y_test, y_pred, average='macro')
    macro_p = precision_score(y_test, y_pred, average='macro')
    macro_r = recall_score(y_test, y_pred, average='macro')

    print(f'macro p: {macro_p}, r: {macro_r}, f1: {macro_f1}\n')

    print(f'Hits@1: {hits_1}')
    print(f'Hits@2: {hits_2}')
    print(f'MRR: {mrr}')

    PERFORMANCE = {'macro_p': macro_p, 'macro_r': macro_r, 'macro_f1': macro_f1, 'Hits@1': hits_1, 'Hits@2': hits_2, 'MRR': mrr}

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(json.dumps(PERFORMANCE, ensure_ascii=False, indent=4))

    return model

def prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold, noisy_or, split):
    """Load one data split and compute its rule features and labels.

    Args:
        data_path: location handed to ``load_data``.
        rule_dict: rules forwarded to ``extract_features``.
        ent2id, rel2id, embeddings: knowledge-graph lookups forwarded to
            ``extract_features``.
        threshold, noisy_or: forwarded to ``extract_features``.
        split: split name handed to ``load_data``.

    Returns:
        Tuple ``(X, y, symptoms, diseases, records)``: the feature array, the
        label array (each record's ``'disease'`` entry), and the three values
        returned by ``load_data`` reordered so the records come last.
    """
    relations = ['focus_of', 'associated_with', 'temporally_related_to']
    records, symptom_table, disease_table = load_data(data_path, split)

    X = np.asarray(
        extract_features(rule_dict, symptom_table, relations, disease_table, ent2id, rel2id, embeddings, records, threshold, noisy_or)
    )
    y = np.asarray([record['disease'] for record in records])

    print(f'X shape:{X.shape}')
    print(f'y shape:{y.shape}')

    return X, y, symptom_table, disease_table, records

class NeuroNet(torch.nn.Module):
    """Linear classifier over KG-similarity features and rule features.

    For one sample, the embeddings of the active symptoms (all other rows are
    zeroed) are shifted by every relation embedding and compared to every
    disease embedding by cosine similarity. The similarities are concatenated
    with the externally computed rule features, flattened, and passed through
    a single linear layer followed by a sigmoid.

    Args:
        symptom_num: number of symptoms, i.e. rows of ``pretrained_weights``.
        embedding_dim: embedding width (stored, not used in the computation).
        disease_num: number of diseases, i.e. rows of ``dise_embeddings``.
        rule_num: number of rule-feature columns per disease.
        rel_embeddings: tensor of relation embeddings, one row per relation.
        dise_embeddings: tensor of disease embeddings, one row per disease.
        pretrained_weights: 2-D float tensor used to initialise the trainable
            symptom embedding table.
    """

    def __init__(self, symptom_num, embedding_dim, disease_num, rule_num, rel_embeddings, dise_embeddings, pretrained_weights):
        super().__init__()
        self.rule_num = rule_num
        self.symptom_num = symptom_num
        self.embedding_dim = embedding_dim
        self.disease_num = disease_num
        # Symptom embeddings start from the pretrained values and are fine-tuned.
        self.embeddings = nn.Embedding.from_pretrained(pretrained_weights, freeze=False)
        # Fixed (non-trainable) embeddings. Registering them as buffers makes
        # them follow `model.to(device)`; persistent=False keeps them out of
        # the state_dict so its keys are unchanged.
        self.register_buffer('rel_embeddings', rel_embeddings, persistent=False)
        self.register_buffer('dise_embeddings', dise_embeddings, persistent=False)
        self.fc = nn.Linear(self.disease_num * (len(rel_embeddings) * self.symptom_num + self.rule_num), self.disease_num)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs, rule_features):
        """Score every disease for a single sample.

        Args:
            inputs: 1-D tensor with one entry per symptom; non-zero marks an
                active symptom.
            rule_features: tensor of shape ``(disease_num, rule_num)``.

        Returns:
            Tensor of shape ``(1, disease_num)`` with values in (0, 1).
        """
        embeds = self.embeddings.weight  # (symptoms, dim)

        # Keep the rows of active symptoms and zero out the rest.
        active = (inputs.to(embeds.device) != 0).unsqueeze(1)  # (symptoms, 1)
        active_embeds = torch.where(active, embeds, torch.zeros_like(embeds))

        # Add every relation embedding to every symptom row, relation-major:
        # (relations, symptoms, dim) -> (relations * symptoms, dim).
        embeds_add_rel = active_embeds.unsqueeze(0) + self.rel_embeddings.unsqueeze(1)
        embeds_add_rel = embeds_add_rel.reshape(-1, embeds_add_rel.shape[-1])

        # Cosine similarity of every disease against every shifted symptom:
        # (diseases, relations * symptoms).
        kg_features = nn.functional.cosine_similarity(
            self.dise_embeddings.unsqueeze(1), embeds_add_rel.unsqueeze(0), dim=2, eps=1e-6
        )

        hidden = torch.cat((kg_features, rule_features.to(kg_features.device)), dim=1)
        output = self.fc(hidden.reshape(1, -1))
        return self.sigmoid(output)

def _build_samples(records, features, symptoms2id, diseases2id, num_symptoms, num_rules):
    """Pair each record's multi-hot symptom vector with its label and rule features.

    Records without any symptom marked ``'True'`` are skipped.

    Returns:
        List of ``(symptom_vector, [disease_index], rule_feature_matrix)``.
    """
    samples = []
    for record, feature in zip(records, features):
        symptom_vector = [0] * num_symptoms
        for name, flag in record['symptoms'].items():
            if flag == 'True':
                symptom_vector[int(symptoms2id[name])] = 1

        label = [int(diseases2id[record['disease']])]
        if not any(symptom_vector):
            continue

        samples.append((symptom_vector, label, feature[:, :num_rules]))
    return samples


def main():
    """Train NeuroNet on cached features and evaluate it on train and test data.

    Reads ``features_train.npy`` and ``features_test.npy`` from the working
    directory. Each file holds the 5-tuple returned by ``prepare_data`` for
    the corresponding split, saved with ``np.save``.
    """
    import time

    # Validates the command line; the parsed `--setting` value is not used below.
    parser.parse_args()
    start_time = time.time()

    CogKG_path = '/home/weizhepei/workspace/CogKG_english/'
    rule_path = CogKG_path + 'data/rule/disease_rule/'
    KG_path = CogKG_path + "data/KG/"

    rule_dict = load_rules(rule_path)
    ent2id, _id2ent, rel2id, _id2rel, embeddings = load_KG(KG_path, embed_version=None)
    id2embed_ent, id2embed_rel = embeddings.solver.values()

    # NOTE: allow_pickle=True unpickles arbitrary objects; only load feature
    # files that were generated locally.
    X_train, y_train, symptoms, diseases, data_train = np.load('features_train.npy', allow_pickle=True)

    symptoms2id = {name: idx for idx, name in symptoms.items()}
    diseases2id = {name: idx for idx, name in diseases.items()}

    # Embedding rows follow the iteration order of `symptoms` / `diseases`,
    # while the sample vectors are indexed by the ids stored in them.
    # NOTE(review): this assumes both dicts iterate in id order 0..n-1 — confirm.
    symptom_embeddings = np.asarray([id2embed_ent[ent2id[name]] for name in symptoms.values()])
    dise_embeddings = np.asarray([id2embed_ent[ent2id[name]] for name in diseases.values()])
    rel_embeddings = np.asarray(
        [id2embed_rel[rel2id[rel]] for rel in ['focus_of', 'associated_with', 'temporally_related_to']]
    )

    print(f'symptom_embeddings: {len(symptom_embeddings)}, {len(symptom_embeddings[0])}')
    print(f'dise_embeddings: {len(dise_embeddings)}, {len(dise_embeddings[0])}')
    print(f'rel_embeddings: {len(rel_embeddings)}, {len(rel_embeddings[0])}')

    assert len(data_train) == len(X_train)
    TRAIN_DATA = _build_samples(data_train, X_train, symptoms2id, diseases2id, len(symptoms), len(rule_dict))
    print(len(TRAIN_DATA))

    X_test, y_test, _, _, data_test = np.load('features_test.npy', allow_pickle=True)

    assert len(data_test) == len(X_test)
    # Built from the test records. The previous loop indexed the training
    # records here, so test inputs and labels came from the training set.
    TEST_DATA = _build_samples(data_test, X_test, symptoms2id, diseases2id, len(symptoms), len(rule_dict))
    print(len(TEST_DATA))

    model = NeuroNet(len(symptom_embeddings), len(symptom_embeddings[0]), len(dise_embeddings), len(rule_dict), torch.FloatTensor(rel_embeddings), torch.FloatTensor(dise_embeddings), pretrained_weights=torch.FloatTensor(symptom_embeddings))

    # List the trainable parameters.
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)

    # NOTE(review): NeuroNet.forward ends with a sigmoid, whereas
    # nn.CrossEntropyLoss is documented to expect unnormalised logits —
    # confirm that this combination is intended.
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    # Training loop: 50 epochs, one sample per optimisation step.
    for epoch in range(0, 50):
        print(f'Starting epoch {epoch+1}')
        current_loss = 0.0

        for i, sample in enumerate(TRAIN_DATA, 0):
            inputs, targets, rule_features = sample
            inputs = torch.tensor(inputs)
            targets = torch.tensor(targets)
            rule_features = torch.FloatTensor(rule_features)

            optimizer.zero_grad()
            outputs = model(inputs, rule_features)
            loss = loss_function(outputs, targets)
            loss.backward()
            optimizer.step()

            # Report the mean loss of every block of 500 samples.
            current_loss += loss.item()
            if i % 500 == 499:
                print('Loss after samples %5d: %.3f' %
                        (i + 1, current_loss / 500))
                current_loss = 0.0

    print('Training process has finished.')

    end_time = time.time()
    print(f'time cost:{(end_time - start_time) / 3600.0} hrs')

    # Both calls write their metrics to test_model's default output file, so
    # the file ends up holding the test-split metrics.
    test_model(model, TRAIN_DATA)
    test_model(model, TEST_DATA)

# Run the training/evaluation pipeline only when executed as a script.
if __name__ == '__main__':
    main()