import json
import os
import random
import re
import sys
import torch
from tqdm import tqdm

sys.path.append('./visual_argument_experiments/src/')
print(sys.path)
from visarg.others.utils import parse_annotation, parse_reasoning_steps, resize_image
from visarg.others.clip import load_clip, clip_encode_docs, clip_hard_neg
from visarg.models.colbert import RAG_searcher
from visarg.models.llava import load_llava, llava
# from visarg.models.llavanext import load_llavanext, llavanext
# from models.InstructBLIP import load_instructblip, instructblip
# from models.ofa import load_ofa, ofa
# from visarg.models.QwenVL import load_qwenVL, qwenVL
# from models.unifiedio2 import load_uio2, unifiedio2
# from visarg.models.otter import load_otter, otter
# from visarg.models.idefics import load_idefics, idefics2
# from visarg.models.cogvlm import load_cogvlm, cogvlm
from visarg.others.prompts import new_visual_premise_prompt, new_visual_premise_prompt_gpt

# Seed Python's RNG so the random.shuffle calls used for negative sampling
# (parse_premises) and option ordering (make_options) are repeatable.
random.seed(1)
# Names of the vision-language models to evaluate. premise_retrieval() and
# get_score() iterate over this list. Enabling an entry that is commented out
# here also requires enabling its matching import at the top of the file.
VLMS = [
  'llava', 
  # 'llavanext', 
  # 'instructblip',
  # 'ofa',
  # 'qwenVL_CHAT',
  # 'uio2',
  # 'idefics2',
  # "otter",
  # "tttt",
  # 'cogvlm'
  # 'lisa',
  # 'groma'
  ]

def _take(candidates, limit=2):
  """Return at most `limit` leading items of the iterable `candidates`."""
  picked = []
  for item in candidates:
    picked.append(item)
    if len(picked) == limit:
      break
  return picked


def _collect_annotations(storage, premises, global_vps, global_cps, show_progress=False):
  """Parse every record of one storage dict.

  Adds one entry per image path to `premises` (a later record with the same
  image path overwrites an earlier one) and appends the record's visual and
  commonsense premises to the `global_vps` / `global_cps` pools.
  """
  keys = storage.keys()
  if show_progress:
    keys = tqdm(keys, desc='visual_grounding')
  for key in keys:
    record = storage[key]
    # The fifth field returned by parse_annotation is not needed here.
    (vps, cps, c, rs, _ctg) = parse_annotation(record['annotation'])
    image_path = os.path.join('./dataset/images', record['image_url'].split('/')[-1])

    global_vps.extend(vps)
    global_cps.extend(cps)

    premises[image_path] = {
      'goldset_vps': vps,
      'goldset_cps': cps,
      'conclusion': c,
      'reasoning_steps': rs
    }


def _build_question(kind, idx, step_idxs, rs_text, goldset, goldset_idxs,
                    pool, pool_idxs, searcher, clip_idxs=None):
  """Build one multiple-choice question for premise `goldset[idx]`.

  `kind` is 'vp' or 'cp' and only affects the names of the first two keys.
  Every negative list holds at most two entries. The index lists are shuffled
  in place, in the order random -> semantic -> clip, so the sequence of RNG
  calls stays fixed for a given seed.
  """
  # Random negatives: premises from the global pool that are not gold for this image.
  random.shuffle(pool_idxs)
  random_negatives = _take(pool[i] for i in pool_idxs if pool[i] not in goldset)

  # Semantic negatives: gold premises of the same image that this reasoning step does not use.
  random.shuffle(goldset_idxs)
  semantic_negatives = _take(goldset[i] for i in goldset_idxs if i not in step_idxs)

  question = {
    f'{kind}_idx': idx,
    f'target_{kind}': goldset[idx],
    'rs_txt': rs_text,
    'random_negatives': random_negatives,
    'semantic_negatives': semantic_negatives,
  }

  if clip_idxs is not None:
    # CLIP negatives: non-gold premises among the indices returned by clip_hard_neg for the image.
    random.shuffle(clip_idxs)
    question['clip_negatives'] = _take(pool[i] for i in clip_idxs if pool[i] not in goldset)

  # ColBERT negatives: non-gold search hits for the target text. Hits scoring
  # 25 or more are skipped (presumably too close to the target -- TODO confirm).
  question['colbert_negatives'] = _take(
    hit['content']
    for hit in searcher.search(goldset[idx])
    if hit['score'] < 25 and hit['content'] not in goldset
  )
  return question


def parse_premises():
  """Build the premise-retrieval question file from the four annotation stores.

  Loads the cartoon, ads, social and new-cartoon storages, collects the gold
  visual premises (vp) and commonsense premises (cp) per image, and then, for
  every reasoning step of every image, creates one question per premise used
  by that step. Each question carries up to two negatives per strategy:
  random, semantic, ColBERT, and (vp only) CLIP.

  The result is written to ./out/premise_retrieval/total_premise.json; the
  directory is created if missing. Nothing is returned.

  NOTE(review): premise_retrieval() reads total_premise_{split_id}.json and a
  'colbert_clip_negatives' field, neither of which is produced here --
  presumably a separate post-processing step creates them; confirm.
  """
  premises = {}
  global_vps = []
  global_cps = []

  # Keep the input files open only while they are being read.
  with open('./dataset/cartoon_storage.json') as ctn, open('./dataset/social_storage.json') as soc, open('./dataset/ads_storage.json') as ads, open('./dataset/new_cartoon_storage.json') as new_ctn:
    cartoon_data = json.load(ctn)
    ads_data = json.load(ads)
    soc_data = json.load(soc)
    new_ctn_data = json.load(new_ctn)

  # Order matters: it fixes the layout of the global pools and decides which
  # record wins when two stores reference the same image file.
  _collect_annotations(cartoon_data, premises, global_vps, global_cps, show_progress=True)
  for storage in (ads_data, soc_data, new_ctn_data):
    _collect_annotations(storage, premises, global_vps, global_cps)

  vp_RAG = RAG_searcher("vp_index", global_vps, index=True)
  cp_RAG = RAG_searcher("cp_index", global_cps, index=True)
  clip_model, clip_preprocess = load_clip()
  docs_features = clip_encode_docs(clip_model, global_vps)

  for key in premises.keys():
    entry = premises[key]
    goldset_vps = entry['goldset_vps']
    goldset_cps = entry['goldset_cps']
    entry['vp_questions'] = []
    entry['cp_questions'] = []

    hard_negatives_idxes = clip_hard_neg(clip_model, clip_preprocess, key, docs_features).tolist()

    reasoning_steps = entry['reasoning_steps']

    # Rebuilt for every image so each image starts shuffling from the same sorted order.
    goldset_vps_idxs = list(range(len(goldset_vps)))
    goldset_cps_idxs = list(range(len(goldset_cps)))
    global_vps_idxs = list(range(len(global_vps)))
    global_cps_idxs = list(range(len(global_cps)))

    for rs in reasoning_steps:
      vp_idxs, cp_idxs, rs_text = parse_reasoning_steps(rs, reasoning_steps)

      # Only ask when more gold premises remain than the step uses by a
      # margin of two, i.e. enough left over to serve as semantic negatives.
      if len(goldset_vps) - len(vp_idxs) > 1:
        for vp_idx in vp_idxs:
          entry['vp_questions'].append(_build_question(
            'vp', vp_idx, vp_idxs, rs_text,
            goldset_vps, goldset_vps_idxs,
            global_vps, global_vps_idxs,
            vp_RAG, clip_idxs=hard_negatives_idxes,
          ))

      if len(goldset_cps) - len(cp_idxs) > 1:
        for cp_idx in cp_idxs:
          entry['cp_questions'].append(_build_question(
            'cp', cp_idx, cp_idxs, rs_text,
            goldset_cps, goldset_cps_idxs,
            global_cps, global_cps_idxs,
            cp_RAG,
          ))

  os.makedirs('./out/premise_retrieval', exist_ok=True)
  with open('./out/premise_retrieval/total_premise.json', 'w') as f:
    json.dump(premises, f)

def make_options(target, negatives):
  """Shuffle the target together with its negatives into a lettered option list.

  Args:
    target: the correct option text.
    negatives: iterable of distractor texts.

  Returns:
    (option_prompt, options, answer_idx) where `option_prompt` has one line
    per option in the form "A) text\n", "B) text\n", ...; `options` is the
    shuffled list; and `answer_idx` is the position of the first entry equal
    to `target` in `options`.

  Labels advance through the alphabet, so a fourth option is "D)" rather
  than a second "C)". Exactly one random.shuffle call is made.
  """
  options = [target, *negatives]
  random.shuffle(options)
  answer_idx = options.index(target)

  option_prompt = "".join(
    f"{chr(ord('A') + position)}) {option}\n"
    for position, option in enumerate(options)
  )

  return option_prompt, options, answer_idx

def _load_vlm_runner(name):
  """Load the model registered under `name` and return `run(image_path, prompt)`.

  The returned callable owns the only references to the loaded objects, so
  deleting it lets them be garbage-collected.

  Raises:
    ValueError: if `name` is not one of the supported model names.
  """
  if name == 'llava':
    processor, model = load_llava()
    return lambda image_path, prompt: llava(processor, model, image_path, prompt)
  if name == 'llavanext':
    processor, model = load_llavanext()
    return lambda image_path, prompt: llavanext(processor, model, image_path, prompt)
  if name == 'instructblip':
    processor, model = load_instructblip()
    return lambda image_path, prompt: instructblip(processor, model, image_path, prompt)
  if name == 'ofa':
    pipe = load_ofa()
    return lambda image_path, prompt: ofa(pipe, image_path, prompt)
  if name == 'qwenVL_CHAT':
    tokenizer, model = load_qwenVL()
    return lambda image_path, prompt: qwenVL(tokenizer, model, image_path, prompt)
  if name == 'uio2':
    preprocessor, model = load_uio2()
    return lambda image_path, prompt: unifiedio2(preprocessor, model, image_path, prompt)
  if name == 'idefics2':
    preprocessor, model = load_idefics()
    return lambda image_path, prompt: idefics2(preprocessor, model, image_path, prompt)
  if name == 'otter':
    preprocessor, model, tokenizer = load_otter()
    return lambda image_path, prompt: otter(preprocessor, model, tokenizer, image_path, prompt)
  if name == 'cogvlm':
    tokenizer, model = load_cogvlm()
    return lambda image_path, prompt: cogvlm(tokenizer, model, image_path, prompt)
  raise ValueError(f'unsupported VLM: {name!r}')


def _answer_vp_question(run_vlm, image_path, question):
  """Ask one visual-premise question under all five negative settings.

  Options for every setting are built first, then the prompts, then the
  model is queried, in that order, so the sequence of random.shuffle calls
  made by make_options does not depend on the model.

  Returns a dict with the raw model output, the gold answer index and the
  shuffled option list for each setting.
  """
  target_vp = question['target_vp']
  conclusion = question["rs_txt"]

  # (key of the raw output, prefix of the answer/options keys, key of the negatives)
  # NOTE(review): 'colbert_clip_negatives' is required here but is not part of
  # the non-empty check applied when selecting images -- confirm it is always present.
  settings = (
    ('easy', 'easy', 'random_negatives'),
    ('clip_result', 'hard_clip', 'clip_negatives'),
    ('colbert_result', 'hard_colbert', 'colbert_negatives'),
    ('semantic_result', 'hard_semantic', 'semantic_negatives'),
    ('colbert_clip_result', 'hard_colbert_clip', 'colbert_clip_negatives'),
  )

  built = [make_options(target_vp, question[negatives_key]) for _, _, negatives_key in settings]
  prompts = [new_visual_premise_prompt_gpt(option_prompt, conclusion) for option_prompt, _, _ in built]
  # Every setting is sent with its own prompt.
  outputs = [run_vlm(image_path, prompt) for prompt in prompts]

  record = {}
  for (result_key, _, _), output in zip(settings, outputs):
    record[result_key] = output
  for (_, prefix, _), (_, _, answer_idx) in zip(settings, built):
    record[f'{prefix}_answer'] = answer_idx
  for (_, prefix, _), (_, options, _) in zip(settings, built):
    record[f'{prefix}_vp_options'] = options
  return record


def premise_retrieval(target="vp", level="easy", vlm="llava", add_conclusion=False, split_id=0):
  """Run every model in VLMS on the premise-retrieval questions of one split.

  Reads ./out/premise_retrieval/total_premise_{split_id}.json, selects the
  images to evaluate, queries each model and writes the accumulated results
  to ./out/premise_retrieval/{model}_{split_id}.json after each model.

  Args:
    target: 'vp' evaluates visual-premise questions on images whose vp
      questions all have non-empty random, semantic, CLIP and ColBERT
      negatives. Any other value selects images that have cp questions, but
      no inference is performed for them.
    level, vlm, add_conclusion: accepted for compatibility and not used; the
      models that run are taken from the module-level VLMS list.
    split_id: suffix of the input and output file names.

  Raises:
    ValueError: if VLMS contains a model name that has no loader.
  """
  print(f'{VLMS[0]} processing in {split_id}')
  with open(f'./out/premise_retrieval/total_premise_{split_id}.json') as ctn:
    premises = json.load(ctn)
  print(len(premises.keys()))

  required_negatives = ('random_negatives', 'semantic_negatives', 'clip_negatives', 'colbert_negatives')
  total_results = {}

  if target == 'vp':
    for key in premises.keys():
      questions = premises[key].get('vp_questions', [])
      # Keep an image only if it has questions and none of them lacks a negative type.
      if questions and all(len(q[kind]) > 0 for q in questions for kind in required_negatives):
        total_results[key] = {}
  else:
    for key in premises.keys():
      if len(premises[key].get('cp_questions', [])) > 0:
        total_results[key] = {}

  # Only referenced by the commented-out GPT batch export that follows this function body.
  gpt_data = []

  for model_name in VLMS:
    run_vlm = _load_vlm_runner(model_name)

    for key in tqdm(total_results.keys()):
      # if key == './dataset/images/ads_116882552808599357.jpg':
      if key != '1212':
        # Result is unused here; call kept in case resize_image has side effects -- TODO confirm.
        base64_img = resize_image(key)
        if target == 'vp':
          total_results[key][model_name] = []
          for question in premises[key]['vp_questions']:
            total_results[key][model_name].append(_answer_vp_question(run_vlm, key, question))

    # Drop the model before the next one is loaded, whatever its name.
    del run_vlm
    torch.cuda.empty_cache()

    print('writing')

    with open(f'./out/premise_retrieval/{model_name}_{split_id}.json', 'w') as f:
      json.dump(total_results, f)
    
    # print(len(gpt_data))
    # with open('./gpt_batch_premise_retrieval.jsonl', 'w') as gpt_f:
    #   for g_d in gpt_data:
    #     json_line = json.dumps({
    #       "custom_id": g_d["custom_id"], 
    #       "method": "POST", 
    #       "url": "/v1/chat/completions", 
    #       "body": {
    #         "model": "gpt-4o", 
    #         "messages": [
    #           {
    #             "role": "user",
    #             "content": [
    #               {
    #                 "type": "text",
    #                 "text": g_d["user_message"]
    #               },{
    #                 "type": "image_url",
    #                 "image_url": {
    #                   "url": f"data:image/jpeg;base64,{g_d['base64_image']}",
    #                   "detail": "low"
    #                 }
    #               }
    #             ]
    #           }],
    #         "max_tokens": 20
    #       }})
    #     gpt_f.write(json_line + '\n')


       
    
def parse_result(raw_result):
  """Map a raw model answer to an option index.

  Short answers (fewer than 5 characters) are scanned for digit runs or
  single capital letters with an optional ')'. Longer answers are scanned
  for digit runs, capital-letter runs or single lowercase letters that are
  directly followed by ')'.

  Returns 0 for 'A'/'1', 1 for 'B'/'2', 2 for 'C'/'3', and -1 when the text
  does not yield exactly one candidate or the candidate is anything else.
  """
  if len(raw_result) < 5:
    pattern = r'(\d+|[A-Z])+\)?'
  else:
    pattern = r'(\d+|[A-Z]+|[a-z])+\)'

  # findall yields the text captured by the group for each match.
  candidates = re.findall(pattern, raw_result)
  if len(candidates) != 1:
    return -1

  label = candidates[0].split(')')[0]
  index_by_label = {'A': 0, '1': 0, 'B': 1, '2': 1, 'C': 2, '3': 2}
  return index_by_label.get(label, -1)
  

  

def _accuracy(scores):
  """Return the mean of a list of 0/1 scores, or None when the list is empty."""
  return sum(scores) / len(scores) if scores else None


def get_score(target="vp"):
  """Compute accuracies from a saved results file and dump them per model.

  For every model in VLMS, each answered question is scored under four
  settings (easy, hard clip, hard colbert, hard semantic), once with the
  model's own output and once with the stored 'gpt_*' output. Accuracies are
  kept overall and per domain: image keys containing 'cartoon' count as
  cartoon, all others as ads. The eight overall accuracies are printed and
  all 24 are written to {vlm}_premise_retrieval_score.json.

  A bucket that received no questions is reported as None (JSON null).

  Args:
    target: accepted for compatibility; not used.

  NOTE(review): the input path is fixed to qwenVL_CHAT.json regardless of
  VLMS and differs from the {vlm}_{split_id}.json files written by
  premise_retrieval(); the entries must also contain the 'gpt_*' fields.
  Confirm which file is meant to be scored.
  """
  # (label, key of the model output, key of the GPT output, key of the gold answer)
  settings = (
    ('easy', 'easy', 'gpt_easy', 'easy_answer'),
    ('hard clip', 'clip_result', 'gpt_clip_result', 'hard_clip_answer'),
    ('hard colbert', 'colbert_result', 'gpt_colbert_result', 'hard_colbert_answer'),
    ('hard semantic', 'semantic_result', 'gpt_semantic_result', 'hard_semantic_answer'),
  )
  domains = ('', 'cartoon ', 'ads ')  # '' is the overall bucket
  sources = ('', 'gpt ')              # '' is the evaluated model itself

  with open('./out/premise_retrieval/qwenVL_CHAT.json') as f:
    data = json.load(f)

  for vlm in VLMS:
    # Insertion order fixes the key order of the JSON that is written below.
    scores = {
      f'{domain}{source}{label} accuracy': []
      for domain in domains
      for source in sources
      for label, _, _, _ in settings
    }

    for key in data.keys():
      # if key == './dataset/images/cartoons_7svTc8V_.png':
      if key != '121212':
        domain = 'cartoon ' if 'cartoon' in key else 'ads '
        for entry in data[key][vlm]:
          for label, model_key, gpt_key, answer_key in settings:
            answer = entry[answer_key]
            for source, raw_key in (('', model_key), ('gpt ', gpt_key)):
              hit = 1 if parse_result(entry[raw_key]) == answer else 0
              scores[f'{source}{label} accuracy'].append(hit)
              scores[f'{domain}{source}{label} accuracy'].append(hit)

    print('='*20)
    print('BASEMODEL is ', vlm)
    for source in sources:
      for label, _, _, _ in settings:
        name = f'{source}{label} accuracy'
        print(name, _accuracy(scores[name]))

    score = {name: _accuracy(values) for name, values in scores.items()}
    with open(f'{vlm}_premise_retrieval_score.json', 'w') as out:
      json.dump(score, out)


if __name__ == "__main__":
  # Only the inference stage runs by default. Uncomment parse_premises() to
  # rebuild the question file, or get_score() to score saved outputs.
  # parse_premises()
  premise_retrieval()
  # get_score()