import jsonlines
import os
from utils.python_utils import get_apis, get_args, transform, get_connected_subsequences
import numpy as np
import json

# code = """
# GET_person__person_id__images(person_id=[person.id for person in response_GET_tv__tv_id__credits.cast if person.order==1][0])
# """

# get_args(code)

def get_exact_match_seq(gt_code, gen_code):
    """Score whether two programs yield the same ``get_apis`` result.

    ``get_apis`` is applied to the ground-truth code first, then to the
    generated code, and the two results are compared with ``==``.

    Args:
        gt_code: Ground-truth (reference) code.
        gen_code: Model-generated code.

    Returns:
        1 when the two extracted results compare equal, otherwise 0.
    """
    same_apis = get_apis(gt_code) == get_apis(gen_code)
    return 1 if same_apis else 0
    
def get_exact_match_connected_subseq(gt_code, gen_code):
    """Score whether two programs yield identical connected subsequences.

    Both snippets are passed through ``get_connected_subsequences`` (ground
    truth first) and the two results are compared with ``==``.

    Args:
        gt_code: Ground-truth (reference) code.
        gen_code: Model-generated code.

    Returns:
        1 when the two results compare equal, otherwise 0.
    """
    reference = get_connected_subsequences(gt_code)
    candidate = get_connected_subsequences(gen_code)
    return 1 if reference == candidate else 0
    
def get_exact_match_full_code(gt_code, gen_code):
    """Score whether two programs produce exactly the same ``get_args`` output.

    This is the strictest argument-level metric: every function and every
    argument value extracted by ``get_args`` must compare equal, and no extra
    entries may appear on either side.

    Args:
        gt_code: Ground-truth (reference) code.
        gen_code: Model-generated code.

    Returns:
        1 on an exact match, otherwise 0.
    """
    matched = get_args(gt_code) == get_args(gen_code)
    if matched:
        return 1
    return 0
    
def get_exact_match_full_code_required_params(gt_code, gen_code):
    """Score whether the generated code covers every ground-truth argument.

    Every function found by ``get_args`` in the ground truth must also appear
    in the generated code. Each of its ground-truth arguments must be present
    there with an equal value. Extra functions or extra arguments in the
    generated code are ignored.

    Args:
        gt_code: Ground-truth (reference) code.
        gen_code: Model-generated code.

    Returns:
        1 if all ground-truth arguments are matched, otherwise 0.
    """
    expected = get_args(gt_code)
    generated = get_args(gen_code)
    for fn_name in expected:
        if fn_name not in generated:
            return 0
        expected_args = expected[fn_name]
        generated_args = generated[fn_name]
        for arg_name in expected_args:
            # Short-circuit: the value lookup only happens once presence is known.
            if (arg_name not in generated_args
                    or expected_args[arg_name] != generated_args[arg_name]):
                return 0
    return 1
    
def get_exact_match_function_wise(gt_code, gen_code):
    """Score each ground-truth function's arguments independently.

    For every function key returned by ``get_args`` on the ground truth, in
    iteration order, emit 1 when the generated code has the same key and its
    arguments compare equal with ``==``. Otherwise emit 0, which includes a
    function the generated code is missing.

    Args:
        gt_code: Ground-truth (reference) code.
        gen_code: Model-generated code.

    Returns:
        A list with one 0/1 score per ground-truth function.
    """
    expected = get_args(gt_code)
    generated = get_args(gen_code)
    return [
        1 if fn_name in generated and expected[fn_name] == generated[fn_name] else 0
        for fn_name in expected
    ]

def get_exact_match_function_wise_required_params(gt_code, gen_code):
    """Score per function whether the generated call covers all ground-truth args.

    For every function key returned by ``get_args`` on the ground truth, in
    iteration order, the score is 1 when the generated code has that function
    and every ground-truth argument is present there with an equal value.
    Extra generated arguments are ignored. A missing function scores 0.

    Args:
        gt_code: Ground-truth (reference) code.
        gen_code: Model-generated code.

    Returns:
        A list with one 0/1 score per ground-truth function.
    """
    expected = get_args(gt_code)
    generated = get_args(gen_code)
    results = []
    for fn_name in expected:
        if fn_name not in generated:
            results.append(0)
            continue
        want = expected[fn_name]
        have = generated[fn_name]
        # any() stops at the first missing or mismatched argument.
        mismatch = any(
            arg_name not in have or want[arg_name] != have[arg_name]
            for arg_name in want
        )
        results.append(0 if mismatch else 1)
    return results

def get_exact_match_scores(output_file_path, *, stop_token="<|endoftext|>"):
    """Compute averaged exact-match metrics over one JSON Lines output file.

    Each record is read with ``jsonlines`` and must provide
    ``"expected_output"`` and ``"generated_output"``. ``stop_token`` is
    stripped from the generated output. Both strings are then passed through
    ``transform`` and scored with the six ``get_exact_match_*`` helpers.

    Args:
        output_file_path: Path to the ``.jsonl`` file of model outputs.
        stop_token: Literal marker removed (every occurrence) from each
            generated output before scoring. Keyword-only; defaults to
            ``"<|endoftext|>"``, the value previously hard-coded here. Pass
            another value for models that emit a different end marker.

    Returns:
        A 6-tuple of ``np.mean`` values, in this order:
        full-code arg match, full-code required-param match, function-wise
        arg match, function-wise required-param match, API sequence match,
        connected-subsequence match. Function-wise metrics are averaged over
        all ground-truth functions across all records; the rest over records.
        A metric with no collected scores (e.g. an empty file) is ``nan``,
        and numpy emits a RuntimeWarning.
    """
    full_code_scores = []
    full_code_scores_req = []
    fn_wise_scores = []
    fn_wise_scores_req = []
    seq_scores = []
    connected_subseq_scores = []
    with jsonlines.open(output_file_path) as reader:
        for obj in reader:
            new_gt_code = transform(obj["expected_output"])
            gen_code = obj["generated_output"].replace(stop_token, "")
            new_gen_code = transform(gen_code)

            full_code_scores.append(
                get_exact_match_full_code(new_gt_code, new_gen_code))
            full_code_scores_req.append(
                get_exact_match_full_code_required_params(new_gt_code, new_gen_code))
            # Function-wise helpers return one score per ground-truth function,
            # so their lists are flattened across records.
            fn_wise_scores.extend(
                get_exact_match_function_wise(new_gt_code, new_gen_code))
            fn_wise_scores_req.extend(
                get_exact_match_function_wise_required_params(new_gt_code, new_gen_code))
            seq_scores.append(
                get_exact_match_seq(new_gt_code, new_gen_code))
            connected_subseq_scores.append(
                get_exact_match_connected_subseq(new_gt_code, new_gen_code))

    return (
        np.mean(full_code_scores),
        np.mean(full_code_scores_req),
        np.mean(fn_wise_scores),
        np.mean(fn_wise_scores_req),
        np.mean(seq_scores),
        np.mean(connected_subseq_scores),
    )


# output_file_path = "few_shot_prompting/codellama_output/output.jsonl"
# fc, fc_req, fn, fn_req, seq, subseq = get_exact_match_scores(output_file_path)

# print(fc, fc_req, fn, fn_req, seq, subseq)

# Models whose generations are scored.
models = ["codellama", "deepseek", "granite"]
# Prompting setups; each has one "<setting>/<model>_output/" directory per model
# and receives its own "<setting>/scores.json".
settings = ["few_shot_prompting", "few_shot_prompting_react"]
# Split name used as the key in scores.json -> suffix of the output file name.
splits = {"overall": "", "spotify": "_spotify", "tmdb": "_tmdb"}
# Metric names, in the order get_exact_match_scores returns its values.
metric_names = [
    "arg_match_full",
    "arg_match_full_required_params",
    "arg_match_function_wise",
    "arg_match_function_wise_required_params",
    "seq_match_full",
    "seq_match_connected_subsequence",
]

for setting in settings:
    # Fresh dict per setting so results can never leak between scores files.
    scores = {}
    for model in models:
        scores[model] = {}
        for split, suffix in splits.items():
            output_file_path = f"{setting}/{model}_output/output{suffix}.jsonl"
            metric_values = get_exact_match_scores(output_file_path)
            scores[model][split] = dict(zip(metric_names, metric_values))

    with open(f"{setting}/scores.json", "w") as out:
        json.dump(scores, out, indent=4)