import inseq
import sys
import argparse
import json
import numpy as np

# Defaults reproduce the values this script was originally hard-coded with;
# every one of them can be overridden from the command line.
DEFAULT_RANGE_LEFT = 133
DEFAULT_RANGE_RIGHT = 160
DEFAULT_TOP = 0.05
DEFAULT_CTI_WEIGHT = 1.0


def _parse_args(argv=None):
    """Parse the command line; ``--f`` (the JSON file to inspect) is required."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--f", type=str, required=True, help="Filename")
    parser.add_argument(
        "--range_left", type=int, default=DEFAULT_RANGE_LEFT,
        help="Smallest cti_idx (inclusive) shown in the CCI section",
    )
    parser.add_argument(
        "--range_right", type=int, default=DEFAULT_RANGE_RIGHT,
        help="Largest cti_idx (inclusive) shown in the CCI section",
    )
    parser.add_argument(
        "--top", type=float, default=DEFAULT_TOP,
        help="Fraction of the score range (below the max) that is highlighted",
    )
    parser.add_argument(
        "--cti_weight", type=float, default=DEFAULT_CTI_WEIGHT,
        help="Number of standard deviations above the mean for the CTI threshold",
    )
    return parser.parse_args(argv)


def _cti_threshold(cti_scores, weight):
    """Return ``mean + weight * std`` of the CTI scores."""
    return np.mean(cti_scores) + weight * np.std(cti_scores)


def _print_cti(data, cti_thres):
    """Print every output token with its index, starring those at/above the threshold."""
    print("CTI:")
    for idx, score in enumerate(data["cti_scores"]):
        token = data["output_current_tokens"][idx]
        if score >= cti_thres:
            print(f"{idx}**{token}**{score}")
        else:
            print(f"{idx}{token}")


def _print_context_scores(context_tokens, scores, top):
    """Print context tokens, starring those within the top ``top`` of the score range.

    Consecutive non-highlighted tokens are collected and printed as one list
    before each highlighted token (and once more at the end).
    """
    high = np.max(scores)
    low = np.min(scores)
    # The cut-off is relative to the score *range*, not a quantile of tokens.
    cci_thres = high - (high - low) * top
    pending = []
    for idx, score in enumerate(scores):
        if score >= cci_thres:
            print(pending)
            print(f"**{context_tokens[idx]}** {score}")
            pending = []
        else:
            pending.append(context_tokens[idx])
    print(pending)


def _print_cci(data, cti_thres, range_left, range_right, top):
    """Print the context scores of each qualifying entry in ``data["cci_scores"]``.

    An entry qualifies when its ``cti_score`` reaches ``cti_thres`` and its
    ``cti_idx`` lies in ``[range_left, range_right]``.
    """
    print("CCI:")
    for entry in data["cci_scores"]:
        # Skip rather than stop: this does not rely on the entries being
        # sorted by descending score, so later qualifying entries are kept.
        if entry["cti_score"] < cti_thres:
            continue
        if not range_left <= entry["cti_idx"] <= range_right:
            continue
        print("======")
        print(data["output_current"])
        print(entry["cti_idx"])
        print(entry["cti_token"])
        print("------")

        scores = entry.get("input_context_scores")
        if not scores:
            # No input-context scores for this entry (missing, None or empty):
            # report an empty run instead of crashing in np.max.
            print([])
            continue
        _print_context_scores(data["input_context_tokens"], scores, top)


def main(argv=None):
    """Load the JSON file given by ``--f`` and print the CTI and CCI reports.

    The file is expected to provide the keys ``cti_scores``,
    ``output_current_tokens``, ``output_current``, ``cci_scores`` and
    ``input_context_tokens``.
    NOTE(review): presumably this is the JSON saved by inseq's
    attribute-context command -- confirm against the producer.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = _parse_args(argv)

    with open(args.f, "r", encoding="utf-8") as f:
        data = json.load(f)

    cti_thres = _cti_threshold(data["cti_scores"], args.cti_weight)
    _print_cti(data, cti_thres)
    _print_cci(data, cti_thres, args.range_left, args.range_right, args.top)


if __name__ == "__main__":
    main()
