from inseq.commands.attribute_context.attribute_context import AttributeContextArgs, attribute_context
import os    
import tqdm
import jsonlines


# Directory holding the JSON-lines record files to attribute.
dir_path_input = './record/'
# Directory that receives one attribution result file per record.
dir_path_output = './pecora-results-full/'
# Forwarded as inseq's context_sensitivity_std_threshold.
# NOTE(review): None presumably disables std-based selection of
# context-sensitive tokens — confirm against the inseq documentation.
threshold = None

model_path = '/projects/0/prjs0888/datasets/CORA-models/mGEN_model'

# Decoding settings used when the model (re)generates outputs; hoisted out of
# the per-record loop because they never change.
GENERATION_KWARGS = {"num_beams": 4, "min_length": 1, "max_length": 20, "early_stopping": False}


def file_tag(fname):
    """Return the second '-'-separated field of the file name's stem.

    The stem is everything before the first '.', e.g. 'record-en.jsonl' -> 'en'
    and 'a-b-c.tar.gz' -> 'b'. Returns None when the stem contains no '-',
    i.e. when the file does not follow the '<prefix>-<tag>.<ext>' naming.
    """
    parts = fname.split('.')[0].split('-')
    return parts[1] if len(parts) > 1 else None


def output_path_for(fname, idx):
    """Return the result path for record number ``idx`` of input file ``fname``.

    The path is '<dir_path_output>/<tag>-<idx>.json', where <tag> comes from
    :func:`file_tag`. Returns None when ``fname`` has no tag.
    """
    tag = file_tag(fname)
    if tag is None:
        return None
    return os.path.join(dir_path_output, f"{tag}-{idx}.json")


def main():
    """Run context attribution for every record of every eligible input file.

    Files whose name contains 'neg' are skipped, as are non-regular files and
    names without a '-' tag. Each remaining file is read as UTF-8 JSON lines.
    Every record must provide 'query', 'passage_in_language' and 'prediction'.
    """
    os.makedirs(dir_path_output, exist_ok=True)

    # os.listdir order is arbitrary; sorting makes runs and logs reproducible.
    for fname in sorted(os.listdir(dir_path_input)):
        if 'neg' in fname:
            continue
        in_path = os.path.join(dir_path_input, fname)
        if not os.path.isfile(in_path):
            continue
        # Validate the name up front instead of failing with IndexError
        # partway through the file.
        if file_tag(fname) is None:
            print(f"Skipping {fname}: expected a '<prefix>-<tag>.<ext>' file name")
            continue

        print('*****************************')
        print(fname)
        # JSON Lines is UTF-8 by specification; do not rely on the locale default.
        with open(in_path, encoding='utf-8') as f:
            for idx, item in enumerate(jsonlines.Reader(f)):
                print("Current: {} - {}".format(fname, idx))
                # Model input is "<Q>: {query} <P>:{passage}", mirroring the
                # "<Q>: " + query + " <P>:" + passage prompt format.
                lm_rag_prompting_example = AttributeContextArgs(
                    model_name_or_path=model_path,
                    input_context_text=item['passage_in_language'],
                    input_current_text=f"<Q>: {item['query']}",
                    output_template="{current}",
                    input_template="{current} <P>:{context}",
                    show_intermediate_outputs=False,
                    attributed_fn="contrast_prob_diff",
                    output_current_text=item['prediction'],
                    context_sensitivity_std_threshold=threshold,
                    save_path=output_path_for(fname, idx),
                    # Pass a copy so a callee mutating it cannot leak changes
                    # into later records.
                    generation_kwargs=dict(GENERATION_KWARGS),
                )
                attribute_context(lm_rag_prompting_example)


if __name__ == "__main__":
    main()
