import random; random.seed(42)

import sys; sys.path.append("G:/내 드라이브/[1] CCADD N CBDL/[1] Personal Research/2022_MSR_drug_repositioning/[2] Code/EC_title_recommendation")
from EC_recommend_baseline import BertClsEcTitle

import torch

from EC_recommend_baseline_eval import get_fname_ec_title_model
from EC_recommend_baseline_eval import import_ec_title_model_and_tokenizer, import_eval_dataset
from EC_recommend_baseline_eval import tokenize_ecs_titles, inference_for_titles



def inference_for_title(title: str, top_K: int, ranking_by_topics: bool):
    """Return the first ``top_K`` selected ECs for a single title.

    Thin single-title wrapper around ``inference_for_titles``: the title is
    wrapped in a one-element list, and the entry for that only title is taken
    from the third value the batch function returns.

    This function reads the module-level names ``ec_title_model``,
    ``tokenizer``, ``repr_ecs``, ``input_ids_ec`` and ``attention_mask_ec``.
    In this file they are assigned only inside the ``__main__`` block, so they
    must exist before this function is called.

    Args:
        title: Title text to run inference for.
        top_K: Number of leading results to keep. It is applied as a plain
            slice, so a negative value drops items from the end instead.
        ranking_by_topics: Forwarded unchanged to ``inference_for_titles``.
            NOTE(review): presumably switches on topic-aware ranking --
            confirm against that function.

    Returns:
        The first ``top_K`` items of the per-title result.
    """
    batch_result = inference_for_titles(
        [title],
        ec_title_model,
        tokenizer,
        repr_ecs,
        input_ids_ec,
        attention_mask_ec,
        ranking_by_topics,
    )
    # The batch call yields exactly three values; only the last one is used.
    _unused_first, _unused_second, selected_per_title = batch_result

    # A single title was submitted, so its results sit at index 0.
    selected_for_title = selected_per_title[0]
    return selected_for_title[:top_K]


if __name__ == "__main__":
    # Interactive demo: load the trained EC-title model once, build a pool of
    # candidate ECs, then rank that pool for every title typed at the prompt.

    # Build the checkpoint file name from the training configuration.
    # NOTE(review): dir_best_model is not passed to anything below; presumably
    # the helper functions locate the model directory themselves -- confirm.
    dir_best_model = "C:/Users/Admin/Desktop/EC_title_baseline_MSRA_0712/EC_title_best_models"
    input_type = 'only_title'
    sample_n = 'use_total_positive'  # 1000000
    pnr = 1
    Ent = 15
    lr = '1e-05'
    fname_ec_title_model = get_fname_ec_title_model(input_type, sample_n, lr, pnr, Ent)

    # Load the model and its tokenizer, on GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ec_title_model, tokenizer = import_ec_title_model_and_tokenizer(fname_ec_title_model, device)

    # Import the evaluation dataset; only its first returned value is used.
    ecs_topicn = 300  # [50, 100, 150, 200, 300, 'total_{topicn}']
    category = "total"
    CT_num = 100
    topNs = [1]
    repr_ecs, _, _, _ = import_eval_dataset(ecs_topicn, category, CT_num, topNs)

    # Keep a random subset of at most repr_ecs_n candidates. The shuffle is
    # reproducible because the random module is seeded at the top of the file.
    repr_ecs_n = 50000
    random.shuffle(repr_ecs)
    repr_ecs = repr_ecs[:repr_ecs_n]

    # Tokenize the candidates once up front so every query reuses the tensors.
    # Each repr_ecs item is a 2-item pair; only its first element is tokenized.
    repr_ecs_only_ec = [ec for ec, _ in repr_ecs]
    input_ids_ec, attention_mask_ec = tokenize_ecs_titles(tokenizer, repr_ecs_only_ec)

    # Query loop: one title per iteration until the user submits a blank line.
    while True:
        title = input("Enter a title of clinical trial (or press Enter to exit): ")
        # A whitespace-only title carries no content, so it also means "exit".
        if not title.strip():
            break

        # Re-prompt on bad input instead of crashing: a ValueError here would
        # throw away the already loaded model and the tokenized candidates.
        top_K = None
        while top_K is None:
            raw_top_K = input("Enter the integer for top_K: ")
            try:
                parsed_top_K = int(raw_top_K)
            except ValueError:
                print(f"Invalid top_K {raw_top_K!r}: please enter a positive integer.")
                continue
            # Zero prints nothing, and a negative value would slice from the
            # end of the ranking rather than take the best results.
            if parsed_top_K < 1:
                print(f"Invalid top_K {parsed_top_K}: please enter a positive integer.")
                continue
            top_K = parsed_top_K

        top_K_repr_ecs = inference_for_title(title, top_K=top_K, ranking_by_topics=True)

        # Disabled experiment, kept for reference:
        # #topic modeling
        # from bertopic import BERTopic
        # bertopic_model = BERTopic()
        # bertopic_model.fit(top_K_repr_ecs)

        # Print the results as a 1-based numbered list.
        for index, top_repr_ec in enumerate(top_K_repr_ecs):
            print(f"{index + 1}. {top_repr_ec}")
