import json
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM



# --- Run configuration --------------------------------------------------------
BATCH_SIZE = 4  # maximum number of prompts sent to one model.generate call
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
# Read the Hugging Face access token from the environment rather than keeping a
# secret in source. When unset (None), transformers falls back to the token
# saved by `huggingface-cli login`, if any.
HF_TOKEN = os.environ.get("HF_TOKEN")
INPUT_PATH = 'path/to/test_ordered.txt'
# NOTE(review): the output path and model_tag say "llama3-70B" but MODEL_ID loads
# the 8B Instruct model -- confirm which model is intended before comparing runs.
OUTPUT_PATH = 'path/to/batch4_test_x10_0shot_ent_llama3-70B.txt'
COLLECTION_PATH = 'path/to/gc_x10_typedNE_factual_llama3.json'


model_tag = 'llama3-70B'
input_filename = 'test_ordered.txt'
output_filename = 'batch4_test_x10_0shot_ent_' + model_tag
prompt_type = 'entailment'  # template name passed to wrap_prompt()


# Release cached-but-unused GPU memory before loading the weights.
torch.cuda.empty_cache()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
# device_map="auto" already places the weights (possibly across several GPUs or
# with CPU/disk offload). Moving the model again with model.to('cuda') conflicts
# with accelerate's dispatch hooks: it warns, and it raises if anything was
# offloaded. So the model is used exactly where from_pretrained put it.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=HF_TOKEN,
)


def wrap_prompt(original_prm_triple, replaced_hyp_triple, tag):
    prompt = ''
    if tag == 'only_hyp':
        prompt = f'Statement: {replaced_hyp_triple}.\nQuestion: Is this statement true?\nChioces:\nA) Yes\nB) No\nC) Unknown\nAnswer: '
    if tag == 'entailment':
        prompt = f'Question:If {original_prm_triple}, then {replaced_hyp_triple}. Is that true or false?\nChoices:\nA) True\nB) False\nAnswer: '
    if tag == 'factual':
        prompt = f'Question: {replaced_hyp_triple}. Is that true or false? \nA) True\nB) False\nAnswer: '
    if tag == 'testing':
        prompt = f'Question: {replaced_hyp_triple}. Is that true or false? \nA) True\nB) False\nAnswer: '
    return prompt

def get_llama3_response_with_score(instructions, input_prompts, max_token=2):
    """Run one batched greedy generation with the module-level model/tokenizer.

    Each prompt becomes a two-message chat: the matching entry of
    ``instructions`` as the system message and the prompt as the user message.
    The batch is left-padded so that every row ends where generation starts.

    Args:
        instructions: System-message strings, one per prompt.
        input_prompts: User-message strings.
        max_token: Maximum number of new tokens generated per row.

    Returns:
        ``(answers, first_step_scores)``:
        ``answers`` is the decoded text of each full output sequence, with the
        prompt included and special tokens stripped.
        ``first_step_scores`` is the per-row score tensor for the first
        generated token, taken from ``outputs.scores[0]``. These are the
        processed, pre-softmax scores, not probabilities.

    Raises:
        ValueError: If the two input lists differ in length.
    """
    if len(instructions) != len(input_prompts):
        raise ValueError(
            f"instructions ({len(instructions)}) and input_prompts "
            f"({len(input_prompts)}) must have the same length"
        )

    message_batch = [
        [
            {"role": "system", "content": instr},
            {"role": "user", "content": prompt},
        ]
        for instr, prompt in zip(instructions, input_prompts)
    ]

    # Reuse EOS as the pad token (presumably the tokenizer defines none) and pad
    # on the left so the last position of every row is the generation position.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'

    # return_dict=True also returns the attention mask. generate() cannot infer
    # it here because pad == eos. Without it, shorter prompts would attend to
    # their left padding and their scores would be corrupted.
    encoded = tokenizer.apply_chat_template(
        message_batch,
        add_generation_prompt=True,
        return_tensors="pt",
        padding=True,
        return_dict=True,
    ).to(model.device)

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    # Greedy decoding. Sampling knobs (temperature/top_p) are omitted because
    # they have no effect with do_sample=False and only trigger warnings.
    outputs = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        pad_token_id=tokenizer.pad_token_id,
        max_new_tokens=max_token,
        eos_token_id=terminators,
        return_dict_in_generate=True,
        do_sample=False,
        output_scores=True,
    )

    answers = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    # outputs.scores has one entry per generated step; keep the first step only.
    return answers, outputs.scores[0]


# The system instruction is identical for every prompt.
INSTRUCTION = "Only return one mark A, B or C to answer the question."
# Vocabulary ids whose first-step probability is pooled into each prompt's
# "true" and "false" score. Looked up once instead of once per batch row.
TRUE_TOKEN_IDS = [tokenizer.convert_tokens_to_ids(t) for t in ['A', 'a', 'True', 'true']]
FALSE_TOKEN_IDS = [tokenizer.convert_tokens_to_ids(t) for t in ['B', 'b', 'False', 'false']]

# Each input line has four tab-separated fields: hypothesis, premise, label and
# one unused field. Hypothesis and premise are comma-separated triples,
# presumably (subject, relation, object).
with open(INPUT_PATH, 'r', encoding='utf-8') as input_file:
    dev_set = input_file.readlines()

# Maps "sub|rel|obj" premise keys to lists of contexts carrying
# 'sub_ent', 'rel' and 'obj_ent' fields.
with open(COLLECTION_PATH, 'r', encoding='utf-8') as collection_file:
    collection = json.load(collection_file)

line_count = 0
with open(OUTPUT_PATH, 'w', encoding='utf-8') as output_file:
    for line in dev_set:
        line_count += 1
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing empty line

        hyp, prm, label, _ = line.strip().split('\t')
        prm_triple = prm.split(',')
        hyp_triple = hyp.split(',')

        if line_count % 20 == 0:
            print(line_count)

        # A "reverse" hypothesis swaps the premise's subject and object.
        is_reverse = (
            prm_triple[0].lower() == hyp_triple[2].lower()
            and prm_triple[2].lower() == hyp_triple[0].lower()
        )

        prm_key = '|'.join(prm_triple[:3])
        if prm_key not in collection:
            print("Strange key:", prm_key)
            continue
        prm_ctxts = collection[prm_key]

        answer_seq = []  # one [true_prob, false_prob] pair per context
        buffered_prompts = []
        for i, ctxt in enumerate(prm_ctxts):
            # Instantiate premise and hypothesis with this context's entities,
            # keeping the hypothesis relation and its direction.
            original_prm_triple = ' '.join([ctxt['sub_ent'], ctxt['rel'], ctxt['obj_ent']])
            if is_reverse:
                replaced_hyp_triple = ' '.join([ctxt['obj_ent'], hyp_triple[1], ctxt['sub_ent']])
            else:
                replaced_hyp_triple = ' '.join([ctxt['sub_ent'], hyp_triple[1], ctxt['obj_ent']])
            buffered_prompts.append(
                wrap_prompt(original_prm_triple, replaced_hyp_triple, prompt_type)
            )

            # Flush when the batch is full or this was the last context.
            if len(buffered_prompts) >= BATCH_SIZE or (i + 1) >= len(prm_ctxts):
                _, scores = get_llama3_response_with_score(
                    [INSTRUCTION] * len(buffered_prompts), buffered_prompts, 2
                )
                buffered_prompts = []

                # generate() scores are unnormalized (pre-softmax). Summing raw
                # logits across several tokens is meaningless, so convert to
                # probabilities first, in float32 because the model runs in bfloat16.
                probs = torch.softmax(scores.float(), dim=-1)
                true_probs = probs[:, TRUE_TOKEN_IDS].sum(dim=-1).tolist()
                false_probs = probs[:, FALSE_TOKEN_IDS].sum(dim=-1).tolist()
                answer_seq.extend([t, f] for t, f in zip(true_probs, false_probs))

        output_file.write('\t'.join([prm, hyp, label, str(answer_seq)]) + '\n')



