from InstructorEmbedding import Instructor as INSTRUCTOR
from pathlib import Path
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity
from sven.human_eval.problem_yaml import Problem
from transformers import AutoTokenizer
from tqdm import tqdm

import argparse
import json
import os
import random
import shutil
import yaml

import math

def get_args(argv=None):
    """Parse the command-line options for building retrieval-augmented eval data.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behavior.

    Returns:
        argparse.Namespace with ``score_method``, ``retriever_dir``,
        ``tokenizer_dir``, ``security_code_dir``, ``evaluation_dir`` and
        ``output_dir``.

    Raises:
        SystemExit: if a required option is missing or ``--score_method``
            is not one of the supported methods (argparse behavior).
    """
    parser = argparse.ArgumentParser()
    # Restrict to the methods the script knows how to dispatch, so a typo
    # fails at parse time instead of after the evaluation data was copied.
    parser.add_argument('--score_method', type=str, required=True,
                        choices=['random', 'BM25', 'INSTRUCTOR', 'None'])
    # Model directory for the INSTRUCTOR retriever.
    parser.add_argument('--retriever_dir', type=str, default=None)
    # Tokenizer directory for the BM25 retriever.
    parser.add_argument('--tokenizer_dir', type=str, default=None)
    parser.add_argument('--security_code_dir', type=str, required=True)
    parser.add_argument('--evaluation_dir', type=str, required=True)
    # Required, so a default would be unreachable; none is given.
    parser.add_argument('--output_dir', type=str, required=True)

    return parser.parse_args(argv)

def load_security_codes(args):
    """Read every file in ``args.security_code_dir``.

    Files are visited in sorted filename order. ``os.listdir`` order is
    arbitrary, and the downstream ranking and score pickle depend on this
    order. Entries that are not regular files (e.g. subdirectories) are
    skipped rather than crashing ``open``.

    Args:
        args: namespace with a ``security_code_dir`` attribute.

    Returns:
        List of ``[filename, file_contents]`` pairs.
    """
    security_codes = []
    for security_code_file in sorted(os.listdir(args.security_code_dir)):
        path = os.path.join(args.security_code_dir, security_code_file)
        if not os.path.isfile(path):
            continue
        with open(path, 'r', encoding='utf-8') as file:
            security_codes.append([security_code_file, file.read()])

    return security_codes

def load_querys(args):
    """Collect the prompt of every problem YAML in ``args.output_dir``.

    Files ending in ``.results.yaml`` are ignored. The remaining ``*.yaml``
    files are read in sorted path order via ``Problem.load``.

    Returns:
        Dict mapping ``problem.name`` to ``problem.prompt``.
    """
    yaml_paths = [
        candidate
        for candidate in sorted(Path(args.output_dir).glob("*.yaml"))
        if not candidate.name.endswith(".results.yaml")
    ]

    prompts_by_name = {}
    for yaml_path in tqdm(yaml_paths):
        with yaml_path.open() as handle:
            loaded = Problem.load(handle)
        prompts_by_name[loaded.name] = loaded.prompt

    return prompts_by_name

def random_retriever(querys, security_codes):
    """Give every (query, security code) pair an independent uniform score.

    This is a baseline ranking that ignores content. ``random.random()`` is
    drawn once per code, per query, in input order.

    Args:
        querys: mapping whose keys are query names.
        security_codes: iterable of ``(name, code)`` pairs.

    Returns:
        Dict ``query_name -> {security_code_name: score}``, with each inner
        dict ordered by ascending score.
    """
    ranked = {}
    for name in querys:
        drawn = {code_name: random.random() for code_name, _ in security_codes}
        ranked[name] = dict(sorted(drawn.items(), key=lambda item: item[1]))
    return ranked

def BM25_retriever(args, querys, security_codes):
    """Rank security codes for each query with BM25 over tokenizer token ids.

    Each security code and each query is tokenized with the tokenizer loaded
    from ``args.tokenizer_dir``. ``BM25Okapi`` then scores every code against
    every query. The consumer of this result (``generate_evaluation_data``)
    uses the *first* key of each ranking, so rankings are ordered by
    descending score (best match first), matching ``INSTRUCTOR_retriever``.

    Args:
        args: namespace with a ``tokenizer_dir`` attribute.
        querys: mapping ``query_name -> query text``.
        security_codes: list of ``[name, code]`` pairs.

    Returns:
        Dict ``query_name -> {security_code_name: bm25_score}``, with each
        inner dict ordered from highest to lowest score.
    """
    if not security_codes:
        # BM25Okapi divides by the corpus size, so an empty corpus would crash.
        return {query_name: {} for query_name in querys}

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir, model_max_length=4096)
    tokenizer.pad_token_id = tokenizer.eos_token_id

    idx_to_name = [security_code_name for security_code_name, _ in security_codes]
    corpus = [tokenizer(security_code)['input_ids'] for _, security_code in security_codes]

    bm25 = BM25Okapi(corpus)

    scores = {}
    for query_name, query in querys.items():
        query_idx = tokenizer(query)['input_ids']
        doc_scores = bm25.get_scores(query_idx)
        per_code = {}
        for idx, temp_score in enumerate(doc_scores):
            per_code[idx_to_name[idx]] = temp_score
        # Highest BM25 score first: callers take the first entry as the best match.
        scores[query_name] = {
            key: per_code[key]
            for key in sorted(per_code, key=per_code.get, reverse=True)
        }

    return scores

def INSTRUCTOR_retriever(args, querys, security_codes):
    """Rank security codes for each query by INSTRUCTOR embedding similarity.

    The model is loaded from ``args.retriever_dir``. Every security code is
    embedded once with the instruction ``'Represent the code: '``. Each query
    is embedded with ``'Represent a coding problem description: '`` and
    compared to all code embeddings by cosine similarity.

    Args:
        args: namespace with a ``retriever_dir`` attribute.
        querys: mapping ``query_name -> query text``.
        security_codes: list of ``[name, code]`` pairs.

    Returns:
        Dict ``query_name -> {security_code_name: similarity}``, with each
        inner dict ordered from most to least similar.
    """
    model = INSTRUCTOR(args.retriever_dir)

    if not querys or not security_codes:
        # Nothing to compare; avoid encoding/comparing empty batches.
        return {query_name: {} for query_name in querys}

    # The code embeddings do not depend on the query, so encode the corpus
    # once instead of once per query. This assumes model.encode is
    # deterministic for identical inputs (e.g. inference mode); TODO confirm.
    security_code_corpus = [['Represent the code: ', security_code]
                            for _, security_code in security_codes]
    embedding_security_codes = model.encode(security_code_corpus)

    scores = {}
    for query_name in tqdm(querys):
        query = [['Represent a coding problem description: ', querys[query_name]]]
        embedding_query = model.encode(query)

        # One row: similarity of this query to every security code.
        similarities = cosine_similarity(embedding_query, embedding_security_codes)

        ranking = {}
        for idx, similarity in enumerate(similarities[0]):
            ranking[security_codes[idx][0]] = similarity
        scores[query_name] = {
            key: ranking[key]
            for key in sorted(ranking, key=ranking.get, reverse=True)
        }

    return scores

def generate_evaluation_data(args, scores, security_codes):
    """Rewrite each problem YAML in ``args.output_dir`` with retrieved context.

    For every ``*.yaml`` file (excluding ``*.results.yaml``), the problem is
    loaded. If ``scores`` is non-empty, the code of the first-listed entry in
    ``scores[problem.name]`` is prepended to the prompt inside a
    triple-quoted string. Any ``\"\"\"`` inside that code is changed to
    ``'''`` so the wrapper stays valid. The problem is then dumped back to
    the same file. When ``scores`` is empty the files are re-dumped unchanged.

    Args:
        args: namespace with an ``output_dir`` attribute.
        scores: ``query_name -> ranking`` dict, best match listed first
            (presumably as produced by one of the retrievers), or ``{}``.
        security_codes: list of ``[name, code]`` pairs.
    """
    code_by_name = dict(security_codes)

    yaml_paths = [
        candidate
        for candidate in sorted(Path(args.output_dir).glob("*.yaml"))
        if not candidate.name.endswith(".results.yaml")
    ]

    for yaml_path in tqdm(yaml_paths):
        with yaml_path.open() as handle:
            problem = Problem.load(handle)
            lang = "py"

            ranking = scores[problem.name] if scores else {}
            # Only the top-ranked entry is used.
            for top_name in ranking:
                snippet = code_by_name[top_name]
                if lang == "py":
                    snippet = snippet.replace('"""', "'''")
                    problem.prompt = '"""\n' + snippet + '\n"""\n' + problem.prompt
                break

        with yaml_path.open("w") as handle:
            handle.write(Problem.dump(problem))

if __name__ == "__main__":
    # ``pkl`` is used below for the scores dump; it must be imported here
    # (the file's top-level imports do not provide it).
    import pickle as pkl
    import sys

    args = get_args()

    # Validate before copying anything, so a bad method does not leave a
    # half-created output directory behind (which would block a re-run).
    valid_methods = ("random", "BM25", "INSTRUCTOR", "None")
    if args.score_method not in valid_methods:
        sys.exit("unknown --score_method {!r}; expected one of {}".format(
            args.score_method, ", ".join(valid_methods)))

    if os.path.exists(args.output_dir):
        print("the output path \"{}\" has existed".format(args.output_dir))
        # Refusing to overwrite is a failure; report it with a non-zero status.
        sys.exit(1)
    shutil.copytree(args.evaluation_dir, args.output_dir)

    security_codes = load_security_codes(args)
    querys = load_querys(args)

    if "random" == args.score_method:
        scores = random_retriever(querys, security_codes)
    elif "BM25" == args.score_method:
        scores = BM25_retriever(args, querys, security_codes)
    elif "INSTRUCTOR" == args.score_method:
        scores = INSTRUCTOR_retriever(args, querys, security_codes)
    else:  # "None": no retrieval, prompts are left unchanged
        scores = {}

    # normpath strips a trailing separator, so "out/" yields "out_scores.pkl"
    # next to the directory rather than "out/_scores.pkl" inside it.
    scores_path = os.path.normpath(args.output_dir) + "_scores.pkl"
    with open(scores_path, 'wb') as file:
        pkl.dump(scores, file)

    generate_evaluation_data(args, scores, security_codes)
