import os
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")  # default to GPU 0, but respect a device list the caller already exported
import json
import numpy as np
from utils import *
import random
from tqdm import tqdm
from sklearn.linear_model import LogisticRegression
from interpret import show
from interpret.glassbox import ExplainableBoostingClassifier

from group_lasso import LogisticGroupLasso
from sklearn.pipeline import Pipeline

from sklearn.metrics import f1_score, precision_score, recall_score

import argparse
# Command-line interface: --setting selects which classifier main() trains and evaluates.
parser = argparse.ArgumentParser(description='Train and evaluate CogRepre feature classifiers.')
parser.add_argument('--setting', choices=['lr', 'grpls', 'ebm'],
                    help='classifier to run: lr (LogisticRegression), grpls (LogisticGroupLasso), '
                         'ebm (ExplainableBoosting)')


def get_rank_score(clf, X, y):
    """Rank-based evaluation of a fitted probabilistic classifier.

    For every sample, the gold label's predicted probability is compared with
    the probabilities of all classes in ``clf.classes_``. Ties are counted
    against the gold label (pessimistic rank): the rank is
    ``n_classes - (number of other classes with strictly lower probability)``.

    Args:
        clf: fitted estimator exposing ``predict_proba`` and ``classes_``.
        X: features passed unchanged to ``clf.predict_proba``.
        y: gold labels, indexable by position and aligned with ``X``.

    Returns:
        Tuple ``(mrr, hits@1, hits@2)``.
    """
    probas = clf.predict_proba(X)
    assert len(probas) == len(y)

    reciprocal_ranks = []
    top1 = 0
    top2 = 0
    for i, row in enumerate(probas):
        gold = y[i]
        label_prob = dict(zip(clf.classes_, row))
        gold_prob = label_prob[gold]

        n_below = sum(1 for label, p in label_prob.items() if label != gold and p < gold_prob)
        rank = len(label_prob) - n_below  # worst position among tied classes
        assert 1 <= rank <= len(label_prob)

        reciprocal_ranks.append(1.0 / rank)
        if rank <= 2:
            top2 += 1
            if rank == 1:
                top1 += 1

    return np.mean(reciprocal_ranks), 1.0 * top1 / len(y), 1.0 * top2 / len(y)

def check_triple(in_no, V_knw, id2embed_ent, id2embed_rel, id2rel, id2ent, noisy_or, threshold=None, show_infer_step=False, incre=False):
    """Score how strongly already-known entities imply entity ``in_no``.

    For every known entity ``knw`` with positive weight ``V_knw[knw]`` and every
    relation ``r`` that is not filtered out, the vector
    ``embed(knw) + embed(r)`` is compared with ``embed(in_no)`` by cosine
    similarity, and the similarity is scaled by that weight.

    Args:
        in_no: entity id whose support is being estimated.
        V_knw: mapping entity id -> evidence weight. Only entries with weight > 0
            are used. The mapping is only read, never modified.
        id2embed_ent: entity embedding table, indexable by entity id.
        id2embed_rel: relation embedding table. It is iterated row by row, so
            row i is relation id i. Rows are added to entity rows, so the two
            tables presumably share a dimension.
        id2rel: relation id -> relation name, used for filtering.
        id2ent: entity id -> name, only used when ``show_infer_step`` is set.
        noisy_or: aggregate the kept scores with ``sum_prob`` if True, else with ``max``.
        threshold: if given, scores below it are discarded.
        show_infer_step: print each accepted inference (only when a threshold is given).
        incre: restrict inference to the relations 并发症 (complication),
            病因 (cause) and 相关疾病 (related disease).

    Returns:
        ``(True, aggregated_score)`` if any score was kept, otherwise ``(False, 0)``.
    """
    # The target vector and its norm are identical for every (entity, relation)
    # pair, so compute them once instead of inside the innermost loop.
    target_vec = id2embed_ent[in_no]
    target_norm = np.linalg.norm(target_vec)

    scores = []
    for knw, weight in V_knw.items():
        # `not weight > 0` (rather than `weight <= 0`) also skips NaN weights.
        if not weight > 0:
            continue
        source_vec = id2embed_ent[knw]
        for idx_rel, vec_rel in enumerate(id2embed_rel):
            rel_name = id2rel[idx_rel]
            if incre and rel_name not in ['并发症', '病因', '相关疾病']:
                continue
            # 发病部位 (site of onset) and 所属科室 (department) are never used for inference.
            if rel_name in ['发病部位', '所属科室']:
                continue
            vec1 = np.add(source_vec, vec_rel)
            CosSim = float(np.dot(vec1, target_vec) / (np.linalg.norm(vec1) * target_norm)) * weight
            if threshold is None:
                scores.append(CosSim)
            elif CosSim >= threshold:
                if show_infer_step:
                    print(f'{id2ent[knw]} + {id2rel[idx_rel]} -> {id2ent[in_no]}')
                scores.append(CosSim)

    if scores:
        # noisy_or: a + b - a * b (via sum_prob); otherwise the strongest single score.
        return True, (sum_prob(scores) if noisy_or else max(scores))
    return False, 0

def extract_features(rules, id2symp, spec_relations, id2dise, ent2id, rel2id, embeddings, samples, threshold, noisy_or):
    """Build one feature matrix per sample: rule features next to embedding-based soft features.

    The soft-feature map has one row per disease and one column per
    (symptom, relation) pair in ``spec_relations``. Each cell holds
    cos(embed(symptom) + embed(relation), embed(disease)).

    Args:
        rules: rule dict forwarded to ``get_rule_feature``.
        id2symp: id -> symptom name vocabulary.
        spec_relations: relation names used for the soft features.
        id2dise: id -> disease name vocabulary. Ids are cast with ``int()``
            to index the map's rows.
        ent2id, rel2id: name -> embedding row index.
        embeddings: mapping with 'ent_embeddings.weight' and
            'rel_embeddings.weight' tables.
        samples: iterable of dicts whose 'symptoms' entry maps a symptom name
            to 'True'/'False'.
        threshold, noisy_or: forwarded to the link prediction inside ``get_rule_feature``.

    Returns:
        List with one ``np.hstack((rule_feature, soft_feature))`` array per sample.
    """
    symp2id = {j: i for i, j in id2symp.items()}
    dise2id = {j: i for i, j in id2dise.items()}
    id2rel = {j: i for i, j in rel2id.items()}
    id2ent = {j: i for i, j in ent2id.items()}
    id2embed_ent = embeddings['ent_embeddings.weight']
    id2embed_rel = embeddings['rel_embeddings.weight']

    soft_feature_map = np.zeros((len(dise2id), len(id2symp) * len(spec_relations)))

    # Column index of each (symptom entity, relation) pair. In the same pass,
    # cache embed(symptom) + embed(relation) and its norm: they do not depend
    # on the disease, so there is no need to recompute them per disease.
    sr2id, translated = {}, {}
    for s in symp2id:
        for r in spec_relations:
            key = (ent2id[s], rel2id[r])
            sr2id[key] = len(sr2id)
            vec1 = np.add(id2embed_ent[key[0]], id2embed_rel[key[1]])
            translated[key] = (vec1, np.linalg.norm(vec1))

    for dise in dise2id:
        row = int(dise2id[dise])
        vec2 = id2embed_ent[ent2id[dise]]
        norm2 = np.linalg.norm(vec2)
        for key, (vec1, norm1) in translated.items():
            soft_feature_map[row][sr2id[key]] = float(np.dot(vec1, vec2) / (norm1 * norm2))

    features = []
    for i in tqdm(samples):
        rule_feature = get_rule_feature(id2dise, dise2id, ent2id, id2rel, id2ent, id2embed_ent, id2embed_rel, rules, i['symptoms'], threshold=threshold, noisy_or=noisy_or)
        soft_feature = activate_soft_feature(soft_feature_map, sr2id, ent2id, rel2id, i['symptoms'], spec_relations)
        features.append(np.hstack((rule_feature, soft_feature)))

    return features

def activate_soft_feature(soft_feature_map, sr2id, ent2id, rel2id, symptoms, spec_relations):
    """Keep only the soft-feature columns of symptoms that are not reported absent.

    Args:
        soft_feature_map: 2-D array (rows x (symptom, relation) columns).
        sr2id: (symptom entity id, relation id) -> column index.
        ent2id, rel2id: name -> id lookups used to build the ``sr2id`` keys.
        symptoms: symptom name -> flag. Every flag other than the string
            'False' activates that symptom's columns.
        spec_relations: relation names whose columns are activated.

    Returns:
        A new array equal to ``soft_feature_map`` with every inactive column
        multiplied by zero.
    """
    mask = np.zeros_like(soft_feature_map)
    active = [name for name, flag in symptoms.items() if flag != 'False']
    for name in active:
        for relation in spec_relations:
            mask[:, sr2id[(ent2id[name], rel2id[relation])]] = 1
    return soft_feature_map * mask

def get_rule_feature(id2dise, dise2id, ent2id, id2rel, id2ent, id2embed_ent, id2embed_rel, rules, symptoms, threshold, noisy_or):
    """Build the rule-feature block for one sample.

    Each rule ``rules[key] = (premise names, [conclusion disease, ...], confidence)``
    yields one column. The column is a vector over diseases holding the
    rule's confidence at its conclusion, multiplied by a coefficient:

    * -1 if any premise is reported absent ('False');
    * 1 if every premise is among the reported symptoms;
    * otherwise the minimum premise score. Each unreported premise is scored
      by ``check_triple`` link prediction from the current evidence.

    Returns:
        Array of shape ``(len(id2dise), len(rules))``.
    """
    polarity = {'True': 1.0, 'False': -1.0}
    columns = []
    for key in rules:
        # Rebuild the evidence for every rule so that scores inferred for one
        # rule's missing premises never count as observed for the next rule.
        evidence = {ent2id[name]: polarity[flag] for name, flag in symptoms.items()}
        one_hot = np.zeros(len(id2dise))
        premise_ids = [ent2id[name] for name in rules[key][0]]
        target_disease = rules[key][1][0]
        one_hot[int(dise2id[target_disease])] = rules[key][2]

        unobserved = set(premise_ids) - set(evidence.keys())
        if any(pid in evidence and evidence[pid] == -1 for pid in premise_ids):
            weight = -1
        elif not unobserved:
            weight = 1
        else:
            for missing_id in unobserved:
                # Each inferred score is added to the evidence and can support
                # later inferences for the same rule.
                _, evidence[missing_id] = check_triple(missing_id, evidence, id2embed_ent, id2embed_rel, id2rel, id2ent, threshold=threshold, noisy_or=noisy_or, show_infer_step=False)
            weight = min(evidence[pid] for pid in premise_ids)

        columns.append(one_hot * weight)
    return np.transpose(np.asarray(columns))


def get_model(model_name, group=None, group_reg=None, lr=None, c=None):
    """Instantiate an unfitted classifier by name.

    Args:
        model_name: 'LogisticRegression', 'ExplainableBoosting' or 'LogisticGroupLasso'.
        group: per-column group labels for LogisticGroupLasso.
        group_reg: group-lasso regularisation strength.
        lr: learning rate for ExplainableBoostingClassifier.
        c: ``C`` (inverse regularisation strength) for the logistic regressions.

    Returns:
        The configured estimator. 'LogisticGroupLasso' is a Pipeline with a
        'variable_selection' step followed by a 'regressor' step.

    Raises:
        SystemExit: with message 'Invalid Model Name' for an unknown ``model_name``.
    """
    if model_name == 'LogisticRegression':
        return LogisticRegression(random_state=2021, max_iter=2000, solver='liblinear', penalty='l2', C=c)
    if model_name == 'ExplainableBoosting':
        return ExplainableBoostingClassifier(interactions=0, random_state=2021, learning_rate=lr, n_jobs=32)
    if model_name == 'LogisticGroupLasso':
        return Pipeline(
            memory=None,
            steps=[
                ("variable_selection",
                    LogisticGroupLasso(
                        groups=group,
                        group_reg=group_reg,
                        l1_reg=0,
                        scale_reg="inverse_group_size",
                        # NOTE(review): keyword spelled as passed here; check
                        # group_lasso's signature before "correcting" it.
                        supress_warning=True,
                        n_iter=80000,
                    )
                ),
                ("regressor",
                    LogisticRegression(random_state=2021, max_iter=2000, solver='liblinear', C=c)),
            ],
        )
    # The `exit` builtin is injected by the `site` module for interactive use
    # and is missing under `python -S`; raising SystemExit directly is the
    # program-safe equivalent.
    raise SystemExit('Invalid Model Name')


def train_and_valid(model, X_train, y_train, X_test, y_test, group=None):
    """Fit ``model`` on the training split and return its Hits@1 on the held-out split.

    ``group`` is accepted for call-site compatibility and is not used.
    """
    model.fit(X_train, y_train)
    rank_scores = get_rank_score(model, X_test, y_test)
    return rank_scores[1]

def train_and_test(model_name, model, X_train, y_train, X_test, y_test, PARAM_DICT):
    """Fit on train, report ranking and macro classification metrics on test, and save them.

    The metrics are printed and then written, together with ``PARAM_DICT``,
    to ``../PERFORMANCE_CogRepre_{model_name}.json`` (a path relative to the
    current working directory).

    Returns:
        The fitted ``model``.
    """
    model.fit(X_train, y_train)
    mrr, hits_1, hits_2 = get_rank_score(model, X_test, y_test)

    y_pred = model.predict(X_test)
    macro_f1 = f1_score(y_test, y_pred, average='macro')
    macro_p = precision_score(y_test, y_pred, average='macro')
    macro_r = recall_score(y_test, y_pred, average='macro')
    print(f'macro p: {macro_p}, r: {macro_r}, f1: {macro_f1}\n')

    ranking = (('Hits@1', hits_1), ('Hits@2', hits_2), ('MRR', mrr))
    for label, value in ranking:
        print(f'{label}: {value}')

    # PARAM_DICT entries are merged last; like dict.update, they override
    # clashing metric values.
    performance = {
        'macro_p': macro_p,
        'macro_r': macro_r,
        'macro_f1': macro_f1,
        **dict(ranking),
        **PARAM_DICT,
    }

    out_path = f'../PERFORMANCE_CogRepre_{model_name}.json'
    with open(out_path, 'w', encoding='utf-8') as fh:
        fh.write(json.dumps(performance, ensure_ascii=False, indent=4))

    return model

def prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold, noisy_or, split):
    """Load one data split and build a flat feature matrix plus group-lasso group labels.

    Args:
        data_path: directory passed to ``load_data``.
        rule_dict: rules forwarded to ``extract_features``.
        ent2id, rel2id, embeddings: KG vocabularies and embedding tables.
        threshold, noisy_or: link-prediction settings forwarded to ``extract_features``.
        split: split name passed to ``load_data`` (e.g. 'train', 'valid', 'test').

    Returns:
        X: array of shape (n_samples, n_diseases * n_features), one row per sample.
        y: array of gold disease labels.
        group: 1-based label for every column of X. All columns that hold the
            same per-disease feature (one rule or one symptom-relation pair)
            share a label, giving one group per feature across all diseases.
    """
    spec_relations = ['并发症', '病因', '相关疾病']
    data, symptoms_train, diseases_train = load_data(data_path, split)
    X = extract_features(rule_dict, symptoms_train, spec_relations, diseases_train, ent2id, rel2id, embeddings, data, threshold, noisy_or)
    y = np.asarray([i['disease'] for i in data])

    # Each sample contributes an (n_diseases, n_features) matrix, so X is 3-D here.
    X = np.asarray(X)
    diseases_num = X.shape[1]
    X = X.reshape(X.shape[0], -1)
    features_per_disease = X.shape[-1] // diseases_num

    # reshape flattens row-major, so column k holds (disease k // F, feature k % F).
    # Labelling by feature index therefore needs np.tile. np.repeat would label
    # contiguous chunks of columns, which mixes different features of one disease.
    group = np.tile(np.arange(1, features_per_disease + 1), diseases_num)
    return X, y, group

def grid_search_LR(model_name, data_path, KG_path, rule_dict, c_list, embed_versions=range(1000, 11000, 1000)):
    """Pick the KG embedding checkpoint and ``C`` by validation Hits@1, then evaluate on test.

    Every (checkpoint, C) pair is trained on 'train' and scored on 'valid'.
    The best pair is retrained on 'train' and evaluated on 'test' via
    ``train_and_test``, which also writes the metrics file.

    Args:
        model_name: name understood by ``get_model`` (expected 'LogisticRegression').
        data_path, KG_path, rule_dict: forwarded to the data and KG loaders.
        c_list: candidate ``C`` values.
        embed_versions: candidate embedding checkpoints passed to ``load_KG``.

    Raises:
        ValueError: if ``c_list`` or ``embed_versions`` is empty.
    """
    c_list = list(c_list)
    embed_versions = list(embed_versions)
    if not c_list or not embed_versions:
        raise ValueError('c_list and embed_versions must be non-empty')

    # Start below any achievable score so the first candidate is always taken.
    # Starting at 0 left the best checkpoint at 0 (never searched) whenever
    # every validation Hits@1 was 0.
    best_score = float('-inf')
    best_embedding, best_c = None, None
    for ver in embed_versions:
        ent2id, _, rel2id, _, _, embeddings = load_KG(KG_path, embed_version=ver)
        X_train, y_train, _ = prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold=None, noisy_or=False, split='train')
        X_valid, y_valid, _ = prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold=None, noisy_or=False, split='valid')
        for c in c_list:
            print(f'\nCogRepre + LR with KG training times {ver}, c {c}...')
            model = get_model(model_name, c=c)
            hits_1 = train_and_valid(model, X_train, y_train, X_valid, y_valid)
            if hits_1 > best_score:
                best_score, best_embedding, best_c = hits_1, ver, c

    PARAM_DICT = {'BEST_EMBEDDING': best_embedding, 'BEST_C': best_c}
    noisy_or = False

    print(f'\nModel {model_name}, with KG Traning Times {best_embedding}, c {best_c} ..')
    ent2id, _, rel2id, _, _, embeddings = load_KG(KG_path, embed_version=best_embedding)
    X_train, y_train, _ = prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold=None, noisy_or=noisy_or, split='train')
    X_test, y_test, _ = prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold=None, noisy_or=noisy_or, split='test')

    model = get_model(model_name, c=best_c)
    train_and_test(model_name, model, X_train, y_train, X_test, y_test, PARAM_DICT)
    
def grid_search_EBM(model_name, data_path, KG_path, rule_dict, lr_list, embed_versions=range(1000, 11000, 1000)):
    """Pick the KG embedding checkpoint and learning rate by validation Hits@1, then evaluate on test.

    Every (checkpoint, learning rate) pair is trained on 'train' and scored on
    'valid'. The best pair is retrained on 'train' and evaluated on 'test' via
    ``train_and_test``, which also writes the metrics file.

    Args:
        model_name: name understood by ``get_model`` (expected 'ExplainableBoosting').
        data_path, KG_path, rule_dict: forwarded to the data and KG loaders.
        lr_list: candidate learning rates.
        embed_versions: candidate embedding checkpoints passed to ``load_KG``.

    Raises:
        ValueError: if ``lr_list`` or ``embed_versions`` is empty.
    """
    lr_list = list(lr_list)
    embed_versions = list(embed_versions)
    if not lr_list or not embed_versions:
        raise ValueError('lr_list and embed_versions must be non-empty')

    # Start below any achievable score so the first candidate is always taken.
    # Starting at 0 left the best checkpoint at 0 (never searched) whenever
    # every validation Hits@1 was 0.
    best_score = float('-inf')
    best_embedding, best_lr = None, None
    for ver in embed_versions:
        ent2id, _, rel2id, _, _, embeddings = load_KG(KG_path, embed_version=ver)
        X_train, y_train, _ = prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold=None, noisy_or=False, split='train')
        X_valid, y_valid, _ = prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold=None, noisy_or=False, split='valid')
        for lr in lr_list:
            print(f'\nCogRepre + EBM with KG training times {ver}, lr {lr}...')
            model = get_model(model_name, lr=lr)
            hits_1 = train_and_valid(model, X_train, y_train, X_valid, y_valid)
            if hits_1 > best_score:
                best_score, best_embedding, best_lr = hits_1, ver, lr

    PARAM_DICT = {'BEST_EMBEDDING': best_embedding, 'BEST_LR': best_lr}
    noisy_or = False

    print(f'\nModel {model_name}, with KG Traning Times {best_embedding}, learning_rate {best_lr} ..')
    ent2id, _, rel2id, _, _, embeddings = load_KG(KG_path, embed_version=best_embedding)
    X_train, y_train, _ = prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold=None, noisy_or=noisy_or, split='train')
    X_test, y_test, _ = prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold=None, noisy_or=noisy_or, split='test')

    model = get_model(model_name, lr=best_lr)
    train_and_test(model_name, model, X_train, y_train, X_test, y_test, PARAM_DICT)
    
def grid_search_GroupLassoLR(model_name, data_path, KG_path, rule_dict, c_list, group_reg_list, search=False, embed_versions=range(1000, 11000, 1000)):
    """Train and evaluate the group-lasso + logistic-regression pipeline.

    With ``search=False`` (the default) the validation search is skipped and a
    fixed setting is used: embedding checkpoint 1000, C=1, group_reg=1e-5.
    ``c_list`` and ``group_reg_list`` are then ignored.

    With ``search=True``, every (checkpoint, C, group_reg) combination is
    trained on 'train' and scored by Hits@1 on 'valid', and the best one is used.

    Side effects: prints progress and metrics, writes the metrics JSON via
    ``train_and_test``, and saves weights.npy, bias.npy and mask.npy
    (regressor ``coef_``, ``intercept_`` and the variable-selection step's
    ``sparsity_mask_``) to the current directory.

    Raises:
        ValueError: if ``search`` is set and any candidate list is empty.
    """
    best_embedding, best_c, best_grp_reg = 1000, 1, 1e-5

    if search:
        c_list = list(c_list)
        group_reg_list = list(group_reg_list)
        embed_versions = list(embed_versions)
        if not (c_list and group_reg_list and embed_versions):
            raise ValueError('c_list, group_reg_list and embed_versions must be non-empty')
        # Start below any achievable score so the first candidate is always taken.
        best_score = float('-inf')
        for ver in embed_versions:
            ent2id, _, rel2id, _, _, embeddings = load_KG(KG_path, embed_version=ver)
            print(f'\nPerformance with KG training times {ver} ...')
            X_train, y_train, group = prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold=None, noisy_or=False, split='train')
            X_valid, y_valid, _ = prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold=None, noisy_or=False, split='valid')
            for c in c_list:
                for grp_reg in group_reg_list:
                    model = get_model(model_name, c=c, group_reg=grp_reg, group=group)
                    hits_1 = train_and_valid(model, X_train, y_train, X_valid, y_valid)
                    if hits_1 > best_score:
                        best_score = hits_1
                        best_embedding, best_c, best_grp_reg = ver, c, grp_reg

    PARAM_DICT = {'BEST_EMBEDDING': best_embedding, 'BEST_C': best_c, 'BEST_GRP_REG': best_grp_reg}
    noisy_or = False

    print(f'\nModel {model_name}, with KG Traning Times {best_embedding}, c {best_c} , Group Reg {best_grp_reg}..')
    ent2id, _, rel2id, _, _, embeddings = load_KG(KG_path, embed_version=best_embedding)
    X_train, y_train, group = prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold=None, noisy_or=noisy_or, split='train')
    X_test, y_test, _ = prepare_data(data_path, rule_dict, ent2id, rel2id, embeddings, threshold=None, noisy_or=noisy_or, split='test')

    model = get_model(model_name, c=best_c, group_reg=best_grp_reg, group=group)
    model = train_and_test(model_name, model, X_train, y_train, X_test, y_test, PARAM_DICT)

    # Inspect the fitted pipeline steps and persist the learned parameters.
    selector = model["variable_selection"]
    regressor = model["regressor"]
    sparsity_mask = selector.sparsity_mask_
    print(f'Selected Group: {selector.chosen_groups_}')
    coef = regressor.coef_
    bias = regressor.intercept_
    np.save('weights.npy', coef)
    np.save('bias.npy', bias)
    np.save('mask.npy', sparsity_mask)

    print(f'Coef:{coef.shape}')
    print(f"Number variables: {len(sparsity_mask)}")
    print(f"Number of chosen variables: {sparsity_mask.sum()}")

def main():
    """Command-line entry point: run the classifier chosen by ``--setting`` and report wall time.

    The CogKG root defaults to '/home/weizhepei/workspace/CogKG/' and can be
    overridden with the ``COGKG_PATH`` environment variable. Data, rule and KG
    directories are resolved beneath it.
    """
    import time
    args = parser.parse_args()
    if args.setting is None:
        # Without a setting none of the branches below runs, so fail loudly
        # instead of silently doing nothing.
        parser.error('the following argument is required: --setting')

    CogKG_path = os.environ.get('COGKG_PATH', '/home/weizhepei/workspace/CogKG/')
    data_path = os.path.join(CogKG_path, 'data/diagnose/aligned/')
    rule_path = os.path.join(CogKG_path, 'data/rule/disease_rule/')
    KG_path = os.path.join(CogKG_path, 'data/KG/miniKG/')

    rule_dict = load_rules(rule_path)

    start_time = time.time()

    if args.setting == 'lr':
        grid_search_LR(model_name='LogisticRegression', data_path=data_path, KG_path=KG_path, rule_dict=rule_dict, c_list=[0.1, 1, 10])
    elif args.setting == 'ebm':
        grid_search_EBM(model_name='ExplainableBoosting', data_path=data_path, KG_path=KG_path, rule_dict=rule_dict, lr_list=[1e-5, 1e-4, 1e-3, 1e-2])
    elif args.setting == 'grpls':
        # Wider grid used previously:
        # grid_search_GroupLassoLR(model_name='LogisticGroupLasso', data_path = data_path, KG_path = KG_path, rule_dict = rule_dict, c_list=[0.1, 1, 10], group_reg_list=[5e-6, 6e-6, 7e-6, 8e-6, 9e-6, 1e-5, 2e-5, 3e-5, 4e-5, 5e-5])
        grid_search_GroupLassoLR(model_name='LogisticGroupLasso', data_path=data_path, KG_path=KG_path, rule_dict=rule_dict, c_list=[1], group_reg_list=[1e-5])

    end_time = time.time()
    print(f'time cost:{(end_time - start_time) / 3600.0} hrs')


if __name__ == '__main__':
    main()
