import openai, json
from tqdm import tqdm
import re, os, jsonlines
from wikienv import search_step
import random
# Client configuration for the OpenAI-compatible endpoint used below.
# For real authentication, use a key that is listed in the server's
# --api-keys flag; if that flag is not set, `api_key` is ignored, which is why
# the placeholder "EMPTY" is sufficient here.
openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"

# Model name sent with every chat-completion request.
model = "vicuna-33b-v1.3"

# Candidate pronouns; the same four options are spelled out in the
# `pronoun_selection` prompt below.
pron = ['He', 'She', 'It', 'They']

# One-shot prompt asking for five expanded paraphrases that keep the
# bracketed word. Placeholders: {prompt}, {ground_truth}.
rewrite_template = '''Rewrite and expand the sentence, keep the highlighted word.\nQuestion: Singled Out debuted on [MTV]. \nAnswer:\n1.Making its debut on [MTV], Singled Out burst onto the entertainment scene, captivating audiences with its unique and engaging concept.\n2.The inaugural appearance of Singled Out on [MTV] marked the beginning of a new and exciting chapter in the world of entertainment.\n3.[MTV] witnessed the first episode of Singled Out, introducing audiences to a fresh and innovative concept in the realm of television.\n4.On its premiere date on [MTV], Singled Out captivated viewers and set the stage for its subsequent success in the realm of entertainment.\n5.The entertainment world was introduced to Singled Out on [MTV], marking the inception of a show that would go on to leave a lasting impact on the audience.\n\nRewrite and expand the sentence, keep the highlighted word.\nQuestion: {prompt} [{ground_truth}]. \nAnswer:'''
# One-shot fill-in-the-blank prompt. Placeholder: {si}.
fill_template = '''Fill the blank. Q:Abdullah Ibrahim is well-known for performing __. A:Piano. Q: {si} A:'''

# Layout of a final sample: the wiki observation followed by the rewritten
# prompt. construct_wiki fills {obs} and {r} with str.replace.
wiki_template = '''{obs} {r}'''
# One-shot prompt asking the model to pick a pronoun for a subject.
# construct_wiki substitutes {sub} with the bracketed subject, and relies on
# the second line of this prompt being the actual question.
pronoun_selection = '''Q: From 'He', 'She', 'It', 'They', choose the proper pronoun for [Abdullah Ibrahim]. A: He\nQ: From 'He', 'She', 'It', 'They', choose the proper pronoun for {sub}. A:'''

def _extract_pronoun(reply, echoed_question):
    """Reduce a raw chat reply to the pronoun text it contains.

    Removes any verbatim echo of the question, keeps only the first line, and
    if an ``A:`` marker is present keeps the (stripped) text right after it.
    """
    answer = reply.replace(echoed_question, '')
    if '\n' in answer:
        answer = answer.split('\n')[0]
    if 'A:' in answer:
        answer = answer.split('A:')[1].strip()
    return answer


def construct_wiki(requests, outputfile):
    """Build wiki-prefixed, pronoun-rewritten prompts for each request.

    For every request the subject is looked up with ``search_step``, the chat
    model is asked for a pronoun for the subject, and the request prompt is
    rendered with that pronoun in place of ``{}`` and prefixed with the wiki
    observation (see ``wiki_template``).

    Results are streamed to ``outputfile`` as JSON Lines, one record per
    request, so an interrupted run can be resumed: records already present in
    ``outputfile`` are kept and the corresponding leading requests are skipped.
    When the loop finishes, all records (previous and new) are also dumped as
    one JSON array next to it.

    Args:
        requests: Sequence of dicts providing ``case_id`` and
            ``requested_rewrite`` with ``subject``, ``prompt`` (containing a
            ``{}`` slot) and ``target_true['str']``.
        outputfile: Path of the JSON Lines file to create or append to.
    """
    # Reload records from a previous run so the final aggregated JSON contains
    # every case, not just the ones processed in this invocation.
    ress = []
    if os.path.exists(outputfile):
        with jsonlines.open(outputfile, 'r') as reader:
            ress = list(reader)
        mode = 'a'
    else:
        mode = 'w'
    break_point = len(ress)

    # The context manager guarantees the writer is closed (and flushed) even
    # if a lookup or an API call raises part-way through.
    with jsonlines.open(outputfile, mode) as outputf:
        for i, request in enumerate(tqdm(requests)):
            if i < break_point:
                continue
            rewrite = request['requested_rewrite']
            entity = rewrite['subject']
            target_true = rewrite['target_true']['str']

            # An earlier variant of this script looked up a randomly sampled
            # *different* subject here; the current one uses the request's own.
            obs = search_step(entity, target_true)
            if obs == '':
                obs = '<emptywiki>'

            question = pronoun_selection.replace('{sub}', '[' + entity + ']')
            # Second line of the prompt up to "A:" is the actual question; some
            # models echo it back, so it is removed from the reply.
            echoed_question = question.split('\n')[1].split("A:")[0]
            completion = openai.ChatCompletion.create(
                model=model,
                messages=[{"role": "user", "content": question}])
            thep = _extract_pronoun(
                completion.choices[0].message.content, echoed_question)
            if i % 100 == 0:
                # Periodic sample so the pronoun choices can be eyeballed.
                print(entity, thep)

            prompt = rewrite['prompt'].replace('{}', thep)
            sample = wiki_template.replace('{obs}', obs)
            sample = sample.replace('{r}', prompt)

            record = {'case_id': request['case_id'],
                      'wiki_pronoun_rewrite': [sample]}
            ress.append(record)
            outputf.write(record)

    # NOTE(review): the aggregate path is everything before the first '.json'
    # plus '.json' (e.g. 'out.jsonl' -> 'out.json'). If `outputfile` itself
    # ends in '.json' this is the same path and the JSON Lines file gets
    # overwritten - confirm callers always pass a '.jsonl' path.
    total_outputf = outputfile.split('.json')[0] + '.json'
    with open(total_outputf, 'w') as f:
        json.dump(ress, f, indent=2)


# Script entry: load the CounterFact requests from disk and build the
# wiki/pronoun samples for the first 2000 of them.
with open('/data/share/xx/EasyEdit/cf/counterfact.json', 'r') as counterfact_file:
    requests = json.load(counterfact_file)

# NOTE(review): the output path is blank in the source; it presumably has to be
# set to a JSON Lines path before running - confirm the intended location.
outputfile = ''
construct_wiki(requests[:2000], outputfile)