from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


# Hugging Face model identifier handed to both from_pretrained() calls below.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
# Access token forwarded as `token=` when loading the tokenizer and the model.
# NOTE(review): "?" looks like a placeholder that must be replaced — confirm.
HUGGINGFACE_TOKEN = "?"
# JSON file (a file path, despite the _DIR suffix) that the generated contexts
# are written to at the end of the run, and read from when the cache is enabled.
# NOTE(review): the file name mentions llama3-70b while MODEL_ID is the 8B model — confirm.
GENERATED_ARCHIVE_DIR = '/path/to/gc_x10_typedNE_factual_llama3-70b.json'
# When True, the archive above is loaded first and seed triples already present
# in it are skipped instead of being regenerated.
LOAD_GENERATION_CACHE = False
# Test set file (also a file path): one example per line, four tab-separated
# fields, the second being a comma-separated "subject,relation,object" premise.
TESTSET_DIR = '/path/to/test_ordered.txt'
# Same layout as TESTSET_DIR, read line-by-line in parallel with it; its premise
# supplies the subject and object *types* used in the generation prompt.
TYPED_TESTSET_DIR = '/path/to/test_ordered_typed.txt'


# Release any cached, unused CUDA memory before the model weights are loaded.
torch.cuda.empty_cache()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HUGGINGFACE_TOKEN)
# device_map="auto" lets accelerate decide where each module lives (one GPU,
# several GPUs, or partially offloaded), so the model is already placed once
# from_pretrained() returns.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=HUGGINGFACE_TOKEN,
)
# NOTE: there is deliberately no `model.to('cuda')` here. Moving a model that
# accelerate has dispatched is redundant at best (it only triggers a warning)
# and fails outright when modules were offloaded or no GPU is available.
# Inputs are sent to `model.device` at generation time instead.

def get_llama3_response(instruction, input_prompt, max_token=256):
    """Generate one greedy-decoded chat completion from the loaded model.

    Args:
        instruction: Text sent as the ``system`` message.
        input_prompt: Text sent as the ``user`` message.
        max_token: Upper bound on the number of newly generated tokens.

    Returns:
        A ``(answer, probs)`` tuple. ``answer`` is the decoded completion with
        the prompt tokens and special tokens removed. ``probs`` is always an
        empty list; it is kept so callers can keep unpacking two values.
    """
    messages = [
        {"role": "system", "content": instruction},
        {"role": "user", "content": input_prompt},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Stop on either the regular EOS token or Llama 3's end-of-turn marker.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]

    outputs = model.generate(
        input_ids,
        # A single, unpadded prompt: every position is a real token.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=max_token,
        eos_token_id=terminators,
        # Set explicitly so generate() does not have to guess (and warn).
        pad_token_id=tokenizer.eos_token_id,
        # Greedy decoding. The sampling knobs are unset on purpose: they are
        # ignored when do_sample=False, and leaving values such as
        # temperature=0.0 / top_p=0.9 in place only produces a
        # generation-config warning on every call.
        do_sample=False,
        temperature=None,
        top_p=None,
    )

    # generate() returns prompt + completion; keep only the completion part.
    prompt_length = input_ids.shape[-1]
    completion_ids = outputs[0][prompt_length:]
    answer = tokenizer.decode(completion_ids, skip_special_tokens=True)

    probs = []

    return answer, probs

import json

# Ask PyTorch to release cached, unused CUDA memory before generation starts.
torch.cuda.empty_cache()

def _strip_list_number(line):
    """Remove a leading ``N.`` list marker (N = decimal digits) from *line*."""
    marker, dot, remainder = line.partition('.')
    if dot and marker.strip().isdecimal():
        return remainder.lstrip()
    return line


def generate_gpt_combine_1response(sub, rel, obj, sub_type, obj_type, para_num=10, expect_length=50,
                                   response_fn=None):
    """Ask the language model for ``para_num`` facts shaped like a seed triple.

    The model is prompted for a numbered list of ``sub_type|rel|obj_type``
    facts. Everything before the line starting with ``1.`` is discarded, blank
    lines are dropped, and the list number and any leading dashes are removed
    from each remaining line.

    Args:
        sub, rel, obj: The seed triple shown to the model as an example.
        sub_type, obj_type: Type labels used in the requested fact pattern.
        para_num: Number of facts requested.
        expect_length: Token budget per fact; the generation limit is
            ``expect_length * para_num``.
        response_fn: Optional callable ``(instruction, prompt, max_token=...)``
            returning ``(answer, probs)``. Defaults to ``get_llama3_response``;
            pass another callable to use a different backend.

    Returns:
        ``[embodied_triples, paragraphs]`` where ``embodied_triples`` is the
        list of cleaned fact strings and ``paragraphs`` is a list of empty
        strings of the same length.
    """
    if response_fn is None:
        response_fn = get_llama3_response

    embody_prompt = f'Write {para_num} facts in the form of "{sub_type}|{rel}|{obj_type}." like {sub}|{rel}|{obj}.'
    instr = "You are a helpful agent who will complete the given text."
    embody_ans, _ = response_fn(instr, embody_prompt, max_token=expect_length * para_num)

    # Blank lines would otherwise be counted (and stored) as empty facts.
    lines = [raw.strip() for raw in embody_ans.strip().split('\n')]
    lines = [text for text in lines if text]

    # Skip any preamble ("Here are 10 facts: ...") before the numbered list.
    start = next((pos for pos, text in enumerate(lines) if text.startswith('1.')), len(lines))
    triples_text = lines[start:]
    if len(triples_text) != para_num:
        print('Strange Length:', len(triples_text))

    # Strip the "N." marker as a prefix. str.lstrip(f'{i+1}. ') treats its
    # argument as a set of characters and would also eat leading digits that
    # belong to the fact itself (e.g. "1. 1984|..." -> "984|...").
    embodied_triples = [_strip_list_number(text).lstrip("-").strip() for text in triples_text]

    paragraphs = [''] * len(embodied_triples)

    return [embodied_triples, paragraphs]



# Maps each "sub|rel|obj" seed key to the list of contexts generated for it.
collection = {}
if LOAD_GENERATION_CACHE:
    # Start from the previously saved archive so known seeds can be skipped.
    with open(GENERATED_ARCHIVE_DIR, 'r') as cache_file:
        collection = json.loads(cache_file.read())

def _split_entities(triple, rel):
    """Return the lower-cased ``(subject, object)`` of a generated triple.

    The triple is first split around the relation text; if the relation does
    not occur in it, the triple is split on ``|`` and must have exactly three
    parts. When neither works, a diagnostic is printed and two empty strings
    are returned.
    """
    lowered = triple.lower().strip('.')
    rel_key = rel.lower()
    # str.split('') raises ValueError, so an empty relation goes straight to
    # the pipe-based fallback.
    parts = lowered.split(rel_key, 1) if rel_key else [lowered]
    if len(parts) == 2:
        # Trim whitespace on both sides of the separator pipe, so that
        # "a | rel | b" yields "a" / "b" rather than "a |" / "| b".
        return (parts[0].strip().strip('|').strip(),
                parts[1].strip().strip('|').strip())

    parts = lowered.split('|')
    if len(parts) == 3:
        return parts[0].strip(), parts[2].strip()

    print('Strange:', rel, triple)
    return '', ''


def generate_ctxt_and_add_to_collection(sub, rel, obj, sub_type, obj_type):
    """Generate facts for one seed triple and store them in ``collection``.

    The seed is keyed as ``"sub|rel|obj"``. If that key is already present in
    the module-level ``collection`` the call prints ``skip`` and does nothing.
    Otherwise each generated triple is stored as a dict with the keys
    ``sub_ent``, ``rel``, ``obj_ent`` (lower-cased entities, empty strings when
    the triple could not be parsed) and ``triple`` (the stripped original text).

    Returns:
        None. The result is written into ``collection``.
    """
    triple_t = sub + '|' + rel + '|' + obj
    if triple_t in collection:
        print('skip')
        return

    # The second element (placeholder paragraphs) carries no information here.
    triples, _ = generate_gpt_combine_1response(sub, rel, obj, sub_type, obj_type)

    contexts = []
    for raw_triple in triples:
        triple = raw_triple.strip()
        sub_e, obj_e = _split_entities(triple, rel)
        contexts.append({'sub_ent': sub_e, 'rel': rel, 'obj_ent': obj_e, 'triple': triple})
    collection[triple_t] = contexts
    return


# Read the plain test set and its entity-typed counterpart. The two files are
# consumed in parallel: line i of one is paired with line i of the other.
with open(TESTSET_DIR, 'r') as plain_file:
    dev_set = plain_file.readlines()
with open(TYPED_TESTSET_DIR, 'r') as typed_file:
    typed_dev_set = typed_file.readlines()

for i, line in enumerate(dev_set):
    # Every line must hold exactly four tab-separated fields; only the second
    # one (the premise) is used.
    _, premise, _, _ = line.strip().split('\t')
    _, typed_premise, _, _ = typed_dev_set[i].strip().split('\t')

    # A premise is a comma-separated "subject,relation,object" triple; the
    # typed premise contributes the subject and object types.
    sub, rel1, obj = premise.strip().split(',')
    sub_type, _, obj_type = typed_premise.strip().split(',')
    generate_ctxt_and_add_to_collection(sub, rel1, obj, sub_type, obj_type)

    # Progress indicator every ten examples.
    if i % 10 == 0:
        print(i)


# Persist everything gathered in this run (plus any loaded cache entries).
with open(GENERATED_ARCHIVE_DIR, 'w') as output:
    output.write(json.dumps(collection, indent=4))