from InstructorEmbedding import Instructor as INSTRUCTOR
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer

import argparse
import json
import os
import pickle as pkl
import random
import shutil

def get_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: optional list of argument strings; ``None`` (the default) means
            ``sys.argv[1:]``, exactly as ``argparse`` does.

    Returns:
        argparse.Namespace with ``score_method``, ``retriever_dir``,
        ``tokenizer_dir``, ``security_code_dir``, ``evaluation_dir`` and
        ``output_dir``.

    Exits (via ``parser.error``, status 2) when ``--score_method`` is not one
    of the supported methods, or when the model/tokenizer path that method
    needs is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--score_method', type=str, required=True,
                        choices=["random", "BM25", "INSTRUCTOR", "None"],
                        help="how security codes are ranked for each query")
    parser.add_argument('--retriever_dir', type=str, default=None,
                        help="INSTRUCTOR model path (needed for INSTRUCTOR)")
    parser.add_argument('--tokenizer_dir', type=str, default=None,
                        help="HuggingFace tokenizer path (needed for BM25)")
    parser.add_argument('--security_code_dir', type=str, required=True)
    # These two carry defaults, so they must not also be required=True
    # (a required option's default can never take effect).
    parser.add_argument('--evaluation_dir', type=str, default="sven/data_eval/trained")
    parser.add_argument('--output_dir', type=str, default="./retriever_trained")

    args = parser.parse_args(argv)

    # Fail at parse time instead of deep inside model/tokenizer loading.
    if args.score_method == "BM25" and args.tokenizer_dir is None:
        parser.error("--tokenizer_dir is required when --score_method is BM25")
    if args.score_method == "INSTRUCTOR" and args.retriever_dir is None:
        parser.error("--retriever_dir is required when --score_method is INSTRUCTOR")

    return args

def load_security_codes(args):
    """Read every security-code file directly inside ``args.security_code_dir``.

    Returns:
        dict mapping file name -> file contents, inserted in sorted file-name
        order. The retrievers iterate this dict and use stable sorts, so this
        insertion order decides how equal scores are ranked; sorting makes
        that independent of the filesystem's listing order.

    Subdirectories are skipped rather than failing on ``open``.
    """
    security_codes = {}
    for file_name in sorted(os.listdir(args.security_code_dir)):
        path = os.path.join(args.security_code_dir, file_name)
        if os.path.isdir(path):
            continue
        with open(path, 'r') as fh:
            security_codes[file_name] = fh.read()

    return security_codes

def load_querys(args):
    """Build one retrieval query per evaluation scenario.

    Walks ``args.evaluation_dir/<cwe>/<scenario>/``. For each scenario it
    reads ``info.json`` to get the ``language`` value, which is used as the file
    extension, then concatenates ``file_context.<lang>`` followed by
    ``func_context.<lang>``.

    Returns:
        dict mapping "<cwe>-<scenario>" to the concatenated context text, in
        directory-listing order.
    """
    def read_text(path):
        with open(path, 'r') as handle:
            return handle.read()

    query_by_scenario = {}
    for cwe_name in os.listdir(args.evaluation_dir):
        cwe_dir = os.path.join(args.evaluation_dir, cwe_name)
        for scenario_name in os.listdir(cwe_dir):
            scenario_dir = os.path.join(cwe_dir, scenario_name)
            with open(os.path.join(scenario_dir, "info.json"), 'r') as handle:
                language = json.load(handle)['language']

            pieces = [
                read_text(os.path.join(scenario_dir, prefix + language))
                for prefix in ("file_context.", "func_context.")
            ]
            query_by_scenario[cwe_name + "-" + scenario_name] = "".join(pieces)

    return query_by_scenario

def random_retriever(querys, security_codes):
    """Score every (query, security code) pair with ``random.random()``.

    A no-information baseline. Each query's inner dict is ordered by
    descending score, the same best-first order INSTRUCTOR_retriever produces,
    so the first key is the "top-ranked" code.

    Returns:
        dict: query name -> {security code name: random score in [0, 1)}.
    """
    scores = {}
    for query_name in querys:
        query_scores = {name: random.random() for name in security_codes}
        scores[query_name] = dict(
            sorted(query_scores.items(), key=lambda item: item[1], reverse=True)
        )

    return scores

def BM25_retriever(args, querys, security_codes):
    """Rank security codes for each query with Okapi BM25 over token ids.

    The security codes (corpus) and the queries are both tokenized with the
    tokenizer loaded from ``args.tokenizer_dir``, and BM25 is computed over
    the resulting ``input_ids`` sequences.

    Returns:
        dict: query name -> {security code name: BM25 score}, ordered by
        descending score (most relevant first).
    """
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir, model_max_length=4096)
    # NOTE(review): no padding is requested below, so this appears to have no
    # effect on the token ids; kept as in the original.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    code_names = list(security_codes)
    corpus = [tokenizer(security_codes[name])['input_ids'] for name in code_names]
    bm25 = BM25Okapi(corpus)

    scores = {}
    for query_name, query in querys.items():
        query_ids = tokenizer(query)['input_ids']
        pairs = zip(code_names, bm25.get_scores(query_ids))
        # Descending: a higher BM25 score means more relevant, and consumers
        # (generate_evaluation_data) use only the first entry.
        scores[query_name] = dict(sorted(pairs, key=lambda pair: pair[1], reverse=True))

    return scores

def INSTRUCTOR_retriever(args, querys, security_codes):
    """Rank security codes for each query by INSTRUCTOR embedding similarity.

    Queries are embedded with the instruction
    'Represent a coding problem description: ', and security codes with
    'Represent the code: '. The score is ``cosine_similarity`` between the two
    embeddings.

    Returns:
        dict: query name -> {security code name: cosine similarity}, ordered
        by descending similarity (most similar first).
    """
    model = INSTRUCTOR(args.retriever_dir)

    # Code embeddings do not depend on the query. Encode each code once here
    # instead of once per (query, code) pair. Each code is still encoded on its
    # own, exactly as before, so the similarity values are unchanged.
    code_embeddings = {
        name: model.encode([['Represent the code: ', code]])
        for name, code in security_codes.items()
    }

    scores = {}
    for query_name, query in querys.items():
        print("query_name: ", query_name)
        query_embedding = model.encode([['Represent a coding problem description: ', query]])
        similarities = {
            name: cosine_similarity(query_embedding, embedding)[0][0]
            for name, embedding in code_embeddings.items()
        }
        scores[query_name] = dict(
            sorted(similarities.items(), key=lambda item: item[1], reverse=True)
        )

    return scores

def _wrap_security_code(security_code, lang):
    """Return *security_code* wrapped so it is inert text in a *lang* file, or None.

    - ``"py"``: placed inside a triple-double-quoted string literal. Any inner
      '\"\"\"' is rewritten to "'''" so the literal cannot close early.
    - ``"c"``: fenced by ``#if 0`` / ``#endif``. "# else" is rewritten to
      "## else" (with a printed warning), presumably so such a line cannot act
      as an ``#else`` for the fence. NOTE(review): ``#else`` without a space is
      not rewritten.
    - any other language: None (nothing should be prepended).
    """
    if "py" == lang:
        security_code = security_code.replace('"""', "'''")
        return '"""\n' + security_code + '\n"""\n'
    if "c" == lang:
        if "# else" in security_code:
            print("the security code has '# else', it may be affect the result: ", security_code)
        security_code = security_code.replace("# else", "## else")
        return '#if 0\n' + security_code + '\n#endif\n'
    return None


def generate_evaluation_data(args, scores, security_codes):
    """Copy the evaluation tree and prepend the top-ranked security code to each scenario.

    ``args.evaluation_dir`` is copied to ``args.output_dir``. For every
    ``<cwe>/<scenario>`` in the copy, ``file_context.<lang>`` is first
    preserved as ``ori_file_context.<lang>``. Then the first-ranked code in
    ``scores["<cwe>-<scenario>"]`` is wrapped (see ``_wrap_security_code``) and
    prepended to it. If *scores* is empty, every context is left unchanged.

    Args:
        args: namespace with ``evaluation_dir`` and ``output_dir``.
        scores: query name -> ordered {security code name: score}. Only the
            first key of each inner dict is used, so it should be best-first.
        security_codes: security code name -> source text.

    Raises:
        SystemExit(1): if ``args.output_dir`` already exists (it is never
            overwritten).
        KeyError: if *scores* is non-empty but lacks a scenario, or names a
            code missing from *security_codes*.
    """
    if os.path.exists(args.output_dir):
        print("the output path \"{}\" has existed".format(args.output_dir))
        # Non-zero status: refusing to overwrite is a failure. The bare
        # ``exit()`` builtin is only injected by the ``site`` module.
        raise SystemExit(1)

    shutil.copytree(args.evaluation_dir, args.output_dir)

    for cwe_file in os.listdir(args.output_dir):
        cwe_dir = os.path.join(args.output_dir, cwe_file)
        for scenario_file in os.listdir(cwe_dir):
            scenario_dir = os.path.join(cwe_dir, scenario_file)
            with open(os.path.join(scenario_dir, "info.json"), 'r') as fh:
                lang = json.load(fh)['language']

            context_path = os.path.join(scenario_dir, "file_context." + lang)
            # Keep the untouched context next to the rewritten one.
            shutil.copyfile(context_path, os.path.join(scenario_dir, "ori_file_context." + lang))

            with open(context_path, 'r') as fh:
                file_context = fh.read()

            if scores:
                ranking = scores[cwe_file + "-" + scenario_file]
                top_name = next(iter(ranking), None)
                if top_name is not None:
                    wrapped = _wrap_security_code(security_codes[top_name], lang)
                    if wrapped is not None:
                        file_context = wrapped + file_context

            with open(context_path, 'w') as fh:
                fh.write(file_context)
            

if __name__ == "__main__":
    args = get_args()

    # generate_evaluation_data refuses to overwrite an existing output
    # directory. Check before the (possibly slow) scoring step and before the
    # scores pickle next to that directory is overwritten.
    if os.path.exists(args.output_dir):
        print("the output path \"{}\" has existed".format(args.output_dir))
        raise SystemExit(1)

    security_codes = load_security_codes(args)
    querys = load_querys(args)

    if "random" == args.score_method:
        scores = random_retriever(querys, security_codes)
    elif "BM25" == args.score_method:
        scores = BM25_retriever(args, querys, security_codes)
    elif "INSTRUCTOR" == args.score_method:
        scores = INSTRUCTOR_retriever(args, querys, security_codes)
    elif "None" == args.score_method:
        scores = {}
    else:
        # Previously fell through and crashed with NameError on `scores`.
        raise SystemExit("unknown --score_method: {}".format(args.score_method))

    # normpath strips a trailing separator, so "out/" yields "out_scores.pkl"
    # beside the directory rather than a path inside the not-yet-created one.
    scores_path = os.path.normpath(args.output_dir) + "_scores.pkl"
    with open(scores_path, 'wb') as file:
        pkl.dump(scores, file)

    generate_evaluation_data(args, scores, security_codes)
