import json
import os

import pandas as pd
import tqdm
from gensim.models import KeyedVectors, Word2Vec
from nltk.tokenize import word_tokenize
from scipy import spatial

from Scripts.Prompting import *

# Prompt template consumed by regen(), which substitutes its {query}, {info}
# and {info_updated} placeholders. The path is relative to the current working
# directory; os.path.join keeps it valid on both Windows and POSIX systems.
prompt_file = os.path.join("Scripts", "config_experiment", "prompts", "info2NL.md")

with open(prompt_file, "r", encoding="utf-8") as f:
    PROMPT = f.read()


def _atomic_json_dump(data, path):
    """Write ``data`` as indented JSON to ``path`` via a temp file and rename."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=4)
    # os.replace swaps the finished file into place in one step, so an
    # interruption during json.dump leaves the previous checkpoint intact
    # instead of a truncated file.
    os.replace(tmp_path, path)


def regen(PROMPT, input_file, model):
    """Regenerate a natural-language query for each record of a JSON file.

    For every record that does not yet have a ``"regenNL"`` key, the prompt
    template is filled from the record's ``"text"``, ``"info"`` and
    ``"info_examples"`` fields and sent to ``model`` via ``get_llm_response``.
    The first element of the response is truncated at each stop marker in
    turn, and the result is stored under ``"regenNL"``.

    The file is rewritten after every newly processed record. Records that
    already carry ``"regenNL"`` are skipped, so an interrupted run resumes
    where it stopped.

    Args:
        PROMPT: Prompt template containing ``{query}``, ``{info}`` and
            ``{info_updated}`` placeholders.
        input_file: Path to a JSON file holding a list of record dicts. It is
            updated in place.
        model: Model identifier forwarded to ``get_llm_response``.
    """
    # Applied in this order; "\n" keeps only the first line of the response.
    stop_markers = ("---", "\n", "End", "`", "query")

    with open(input_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    for d in tqdm.tqdm(data):
        if "regenNL" in d:
            continue  # already generated in an earlier run
        prompt = (
            PROMPT.replace("{query}", d["text"])
            .replace("{info}", str(d["info"]))
            .replace("{info_updated}", str(d["info_examples"]))
        )
        text = get_llm_response(model, prompt)[0]
        for marker in stop_markers:
            text = text.split(marker)[0].strip()
        d["regenNL"] = text

        # Checkpoint after each LLM call so progress survives interruptions.
        _atomic_json_dump(data, input_file)
    return


def get_scores(text1, text2):
    """Return the cosine similarity of the mean Word2Vec embeddings of two texts.

    Both texts are tokenized with NLTK's ``word_tokenize``. A Word2Vec model is
    trained on just those two token lists, and each text is represented by the
    mean of its token vectors. The function returns ``1 - cosine distance``
    between the two means.

    Returns ``nan`` when either text yields no tokens, because no embedding
    can be formed for it.

    NOTE(review): the embedding model is fit only on these two short texts, so
    its vectors stay close to their random initialisation. The score likely
    reflects token overlap more than meaning. Loading a pretrained model (e.g.
    via ``KeyedVectors.load_word2vec_format``) would likely give more
    meaningful similarities. Confirm which is intended.
    """
    tokens1 = word_tokenize(text1)
    tokens2 = word_tokenize(text2)
    if not tokens1 or not tokens2:
        # Indexing the model with an empty token list would raise.
        return float("nan")

    # A single worker plus a fixed seed removes thread-scheduling jitter, so
    # repeated runs give the same score. Gensim additionally requires
    # PYTHONHASHSEED to be fixed for reproducibility across interpreter
    # launches.
    model = Word2Vec([tokens1, tokens2], min_count=1, workers=1, seed=1)

    vector1 = model.wv[tokens1].mean(axis=0)
    vector2 = model.wv[tokens2].mean(axis=0)

    similarity = 1 - spatial.distance.cosine(vector1, vector2)
    return similarity


def scorer(output_file):
    """Score regenerated queries against their originals and save a CSV.

    Reads the JSON list of records in ``output_file``. For each record it
    stores ``get_scores(record["text"], regenerated_text)`` under
    ``"similarity"``. It then writes all records to a CSV with the same base
    name and a ``.csv`` extension.

    The regenerated text is read from ``"regenNL"``, which is the key
    ``regen`` writes, with a fallback to a legacy ``"regen"`` key. Records
    with neither key get a ``nan`` similarity rather than aborting the run.
    """
    with open(output_file, "r", encoding="utf-8") as f:
        data = json.load(f)

    for d in tqdm.tqdm(data):
        regenerated = d.get("regenNL", d.get("regen"))
        if regenerated is None:
            d["similarity"] = float("nan")
        else:
            d["similarity"] = get_scores(d["text"], regenerated)

    # splitext swaps only the real extension. str.replace would leave the path
    # unchanged when it lacks ".json", and the CSV would overwrite the JSON.
    csv_file = os.path.splitext(output_file)[0] + ".csv"
    pd.DataFrame(data).to_csv(csv_file, index=False)
    return


def main():
    """Regenerate and score the cached MBPP responses for the selected model.

    Only the ``gpt4-turbo`` directory under ``Responses`` and the
    ``agg_comp_imp_fewshot_examples`` info structure are processed. The cached
    input file is copied to a ``*_regen.json`` working file, unless that file
    already exists, in which case the earlier run is resumed. ``regen`` then
    updates the copy in place, and ``scorer`` scores it.
    """
    import shutil

    iteration = 0
    responses_dir = "Responses"
    cache_dir = os.path.join(responses_dir, "cache")

    for model in os.listdir(responses_dir):
        # Only this model is evaluated; this also excludes "original".
        if model != "gpt4-turbo":
            continue
        for info_struct in os.listdir(os.path.join(responses_dir, model)):
            if info_struct not in ["agg_comp_imp_fewshot_examples"]:
                continue
            # NOTE(review): these cache paths do not depend on info_struct;
            # confirm that is intended if more info structures are enabled.
            input_file = os.path.join(
                cache_dir, f"mbpp-insuff-ambiguous_{iteration}.json"
            )
            output_file = os.path.join(
                cache_dir, f"mbpp-insuff-ambiguous_{iteration}_regen.json"
            )
            # regen() rewrites the file it is given, so work on a copy and
            # keep the cached input untouched. An existing copy is reused so
            # an interrupted run resumes.
            if not os.path.exists(output_file):
                shutil.copyfile(input_file, output_file)
            regen(PROMPT, output_file, model)
            scorer(output_file)
            print(f"Done with {info_struct}")
