import os    
import tqdm
import jsonlines
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

# Top-k setting (an alternative value noted by the author was 1000000).
# Not used by the code visible in this script.
topk: int = 1

# Score transform applied by expand(): 'original', 'square' or 'exp'.
strategy: str = 'original'

# Threshold mode for deciding whether an instance is attributable:
#   True  -> look up a fixed threshold per query language in `threshold`;
#   False -> derive a threshold per instance from that instance's own CTI
#            scores (mean/std based).
# Originally annotated "inter or intra" -- TODO confirm which is which.
val: bool = True

def expand(num):
    """Map a raw score through the transform selected by the global ``strategy``.

    'original' returns *num* unchanged, 'square' returns ``num ** 2``, and any
    other value (the documented choice being 'exp') returns ``np.exp(num)``.
    """
    transforms = {
        'original': lambda value: value,
        'square': lambda value: value ** 2,
    }
    # Any strategy name not listed above, including 'exp', uses exponentiation.
    return transforms.get(strategy, np.exp)(num)

# Directory of JSON-lines record files; the main loop iterates over its entries.
dir_path_input = './record-full/'
# Directory holding per-instance pecora result files named <lang>-<idx>.json,
# which the main loop reads (despite the "output" name).
dir_path_output = './pecora-results-full/'

files = os.listdir(dir_path_input)

# Per-language CTI thresholds -- full dataset, top1.
# Alternative settings kept for reference:
#   mean - 0.5SD: {'te': 1.3244373356358734, 'bn': 0.5818090328144909,
#                  'ja': 0.7374812658326557, 'fi': 2.6087579173802933,
#                  'ru': 1.3846934058095868}
#   mean + 0SD:   {'te': 3.719334466382861, 'bn': 2.4173193188011646,
#                  'ja': 3.4615943045867605, 'fi': 5.882648429386318,
#                  'ru': 4.833288226388395}
# Active setting: mean + 1SD.
threshold = {
    'te': 8.509128727876837,
    'bn': 6.088339890774511,
    'ja': 8.90982038209497,
    'fi': 12.430429453398368,
    'ru': 11.73047786754601,
}


# Evaluate each eligible record file: for every record, load its CTI scores,
# predict "attributable" when any score reaches the threshold, compare with the
# gold `ais` label, and report accuracy plus ROC AUC of the max CTI score.
# sorted() makes the per-language report order deterministic; os.listdir()
# returns entries in arbitrary order.
for fname in sorted(files):
    # Skip files whose names contain any of these markers.
    if any(tag in fname for tag in ('neg', 'val', 'train', 'nli')):
        continue

    # Language code = second '-'-separated field of the name before the first '.'.
    lang = fname.split('.')[0].split('-')[1]
    print('*****************************')
    print(lang)

    pos_ins = []  # max CTI score of each instance whose item['ais'] is truthy
    neg_ins = []  # max CTI score of each remaining instance

    accuracy = 0
    accuracy_tot = 0

    # JSON text is UTF-8 by specification and these files presumably hold
    # multilingual text, so decode explicitly instead of relying on the
    # platform's locale encoding.
    with open(dir_path_input + fname, encoding='utf-8') as f:
        # idx is the record's position in the file; it selects the matching
        # per-instance result file below.
        for idx, item in enumerate(jsonlines.Reader(f)):
            # Optional (disabled) filter to drop yes/no questions:
            #     if item['prediction'] in ["yes", "no"]: continue

            save_path = dir_path_output + lang + '-' + str(idx) + '.json'
            with open(save_path, encoding='utf-8') as r:
                res_pecora = json.load(r)
            cti_scores = res_pecora["cti_scores"]

            if val:
                # Fixed threshold for the record's query language.
                threshold_CTI = threshold[item['query_language']]
            else:
                # Threshold derived from this instance's own score distribution.
                threshold_CTI = 1.01 * np.mean(cti_scores) + 2.8 * np.std(cti_scores)

            # Predicted attributable iff at least one CTI score reaches the threshold.
            mark_attri = any(s >= threshold_CTI for s in cti_scores)

            # The instance's highest CTI score is its ranking score for ROC AUC.
            score = np.max(cti_scores)
            if item['ais']:
                pos_ins.append(score)
            else:
                neg_ins.append(score)

            if mark_attri == item['ais']:
                accuracy += 1
            accuracy_tot += 1

    print()
    if accuracy_tot:
        print("Accuracy:    {}/{}={}".format(accuracy, accuracy_tot, accuracy / accuracy_tot))
    else:
        # Empty file: avoid ZeroDivisionError.
        print("Accuracy:    no instances")

    # roc_auc_score raises ValueError unless both classes are present.
    if pos_ins and neg_ins:
        label = np.array([1] * len(pos_ins) + [0] * len(neg_ins))
        predict = np.array(pos_ins + neg_ins)
        roc_auc = roc_auc_score(label, predict)
        print("ROC AUC:     {}".format(roc_auc))
    else:
        print("ROC AUC:     undefined (needs both attributable and non-attributable instances)")
    print()
    print("=============")
