import os
import pandas as pd

#------------------------------
def write_langs(file, region_dict, baseline):
    """Write a per-language comparison of geo-specific vs. baseline metrics.

    The output goes to ``file + ".lang_specific.csv"``. It has one row per
    language in ``region_dict``, sorted by language code, and interleaves the
    geo and baseline value for each of precision, recall and F-score,
    followed by the geo sample count.

    Args:
        file: Base path; ".lang_specific.csv" is appended to it.
        region_dict: Mapping of language code to a dict with the keys
            "Precision", "Recall", "F-Score" and "Samples".
        baseline: Mapping of language code to a dict with at least
            "Precision", "Recall" and "F-Score". A language present in
            ``region_dict`` but absent here gets empty baseline cells
            instead of raising ``KeyError``.
    """
    with open(file + ".lang_specific.csv", "w", encoding="utf-8") as f:
        f.write("Language," + "Geo_Precision," + "Baseline_Precision," + "Geo_Recall," + "Baseline_Recall," + "Geo_FScore," + "Baseline_FScore," + "Samples\n")

        for language in sorted(region_dict):
            geo = region_dict[language]
            # The two dicts are filtered independently by the caller's parser,
            # so a language may be missing from the baseline.
            base = baseline.get(language)

            cells = [language]
            for metric in ("Precision", "Recall", "F-Score"):
                cells.append(str(geo[metric]))
                cells.append("" if base is None else str(base[metric]))
            cells.append(str(geo["Samples"]))

            f.write(",".join(cells) + "\n")
    
#------------------------------
def get_dict(file):

    dict = {}
    
    with open(file, "r") as f:
        for line in f:
        
            line = line.strip().split()
            if len(line) > 1:
                if len(line[0]) == 3:
                    
                    if eval(line[4]) > 100:
                        dict[line[0]] = {}
                        dict[line[0]]["Precision"] = eval(line[1])
                        dict[line[0]]["Recall"] = eval(line[2])
                        dict[line[0]]["F-Score"] = eval(line[3])
                        dict[line[0]]["Samples"] = eval(line[4])

    return dict
#-------------------------------

def get_macro(dict):
    """Return the unweighted (macro) mean precision, recall and F-score.

    Args:
        dict: Mapping of language code to a metrics dict with "Precision",
            "Recall" and "F-Score" keys. The name shadows the builtin but is
            kept so keyword callers keep working.

    Returns:
        A ``(precision, recall, fscore)`` tuple of arithmetic means over all
        entries. For an empty mapping all three are NaN, rather than raising
        ``ZeroDivisionError``.
    """
    scores = list(dict.values())
    if not scores:
        nan = float("nan")
        return nan, nan, nan

    count = len(scores)
    return (
        sum(s["Precision"] for s in scores) / count,
        sum(s["Recall"] for s in scores) / count,
        sum(s["F-Score"] for s in scores) / count,
    )

#-------------------------------
# For every geo-specific OpenLID report in `root`:
#   - compute its macro metrics from the per-language rows;
#   - compute the baseline's macro metrics over the same set of languages;
#   - write a per-language comparison CSV.
# Then save a summary table to results.openlid.by_region.csv.
results = []
root = os.path.join(".", "Results_OpenLID")

# First get baselines for each language
baseline_full = get_dict(os.path.join(root, "eval_openlid.v1.baseline.full_results.txt"))
baseline_ftz = get_dict(os.path.join(root, "eval_openlid.v1.baseline.ftz_results.txt"))

# Now go through the non-baseline report files
for file in os.listdir(root):
    if not file.endswith(".txt") or "baseline" in file:
        continue

    if "full_results" in file:
        name = "full"
        baseline_source = baseline_full
    elif "ftz_results" in file:
        name = "compressed"
        baseline_source = baseline_ftz
    else:
        # Not a report type we can pair with a baseline. Such files were
        # previously processed and then dropped from the summary as "ERROR".
        continue

    parts = file.split(".")
    if len(parts) < 3:
        print("Cannot determine region from file name, skipping:", file)
        continue
    region = parts[2]
    path = os.path.join(root, file)
    print(region, name, file)

    # Get the dictionary of language metrics
    region_dict = get_dict(path)
    if not region_dict:
        print("No per-language metrics found, skipping:", file)
        continue

    # Take the support from this report's own "macro avg" row. Reset it per
    # file so a value from a previously processed file is never reused.
    support = None
    with open(path, "r") as f:
        for line in f:
            if "macro avg" in line:
                fields = line.split()
                if len(fields) >= 6:
                    support = fields[5]

    # Macro metrics recomputed from the per-language rows kept by get_dict
    new_precision, new_recall, new_fscore = get_macro(region_dict)
    print(region, new_precision, new_recall, new_fscore)

    # Restrict the baseline to the languages evaluated for this region
    baseline = {lang: scores for lang, scores in baseline_source.items() if lang in region_dict}

    baseline_precision, baseline_recall, baseline_fscore = get_macro(baseline)
    results.append([region, name, new_precision, baseline_precision, new_recall, baseline_recall, new_fscore, baseline_fscore, support])

    # Save language-specific results
    write_langs(file, region_dict, baseline)

df = pd.DataFrame(results, columns = ["Region", "Type", "Geo Precision", "Baseline Precision", "Geo Recall", "Baseline Recall", "Geo F-Score", "Baseline F-Score", "Samples"])
df = df.sort_values(["Type", "Region"])
print(df)

df.to_csv("results.openlid.by_region.csv")