
import json
import os
import random
import re
import sys

# sys.path.append('./visual_argument_experiments/src/models/image2con/minigpt4/')

from tqdm import tqdm

sys.path.append('./visual_argument_experiments/src/')
# print(sys.path)
from visarg.others.utils import parse_annotation
# from models.InstructBLIP import instructblip
# from models.BLIP2 import blip2
# from models.unifiedio2 import unifiedio2
# from models.llavanext import llavanext
# from models.llava import llava
# from models.qwenVL import qwenVL
# from models.ofa import ofa
# from models.Openflamingo import openflamingo
# from models.MiniGPT_4 import minigpt_4

from visarg.models.colbert import RAG_searcher
from visarg.others.clip import clip_encode_docs, clip_hard_neg

from visarg.others.prompts import visual_premise_prompt, visual_premise_prompt1, visual_premise_prompt2, commonsense_premise_prompt

# ColBERT retrieval-score cutoff for hard-negative mining: a retrieved premise
# must score below this value to qualify as a hard negative.
COLBERT_THRESHOLD=25
# Seed the global RNG so random negative sampling and the shuffling of answer
# choices are repeatable between runs.
random.seed(2)


# Ids of cartoon entries that are left out of the premise set.
_EXCLUDED_CARTOON_IDS = frozenset({103, 144, 233, 716, 747, 174, 786})


def _add_entries(data, keys, premises, global_vps, global_cps):
  """Store the gold premises of data[key] for each key and extend the global pools."""
  for key in keys:
    vps, cps, conclusion, _, _ = parse_annotation(data[key]['annotation'])
    image_path = os.path.join('./dataset/images', data[key]['image_url'].split('/')[-1])
    global_vps.extend(vps)
    global_cps.extend(cps)
    premises[image_path] = {
      'goldset_vps': vps,
      'goldset_cps': cps,
      'conclusion': conclusion,
    }


def _colbert_hard_negatives(searcher, gold_premises, threshold, limit=4):
  """Return, for each gold premise, up to `limit` retrieved texts scoring below `threshold`."""
  per_premise = []
  for gold in gold_premises:
    picked = []
    for hit in searcher.search(gold):
      if len(picked) >= limit:
        break
      # Never offer another gold premise of the same image as a negative.
      if hit['score'] < threshold and hit['content'] not in gold_premises:
        picked.append(hit['content'])
    per_premise.append(picked)
  return per_premise


def _sample_random_negatives(pool, distinct_pool, gold_premises, count=10):
  """Draw up to `count` distinct random texts from `pool` that are not gold premises."""
  # Cap the target at the number of distinct eligible texts; without the cap
  # the rejection loop below would spin forever on a small pool.
  target = min(count, len(distinct_pool.difference(gold_premises)))
  negatives = []
  while len(negatives) < target:
    candidate = pool[random.randint(0, len(pool) - 1)]
    if candidate not in gold_premises and candidate not in negatives:
      negatives.append(candidate)
  return negatives


def parse_prmises():
  """Build the premise-retrieval dataset and write it to ./out/premise_retrieval/premises.json.

  For every annotated image in the cartoon, ads and social storage files the
  gold visual premises (vps), gold commonsense premises (cps) and conclusion
  are collected. Each image then receives:

  * 'colbert_hard_negative_vps' / 'colbert_hard_negative_cps': per gold
    premise, up to 4 retrieved texts with a score below COLBERT_THRESHOLD;
  * 'clip_hard_negative': texts selected by clip_hard_neg for the image,
    minus the image's own gold vps;
  * 'negative_vps' / 'negative_cps': up to 10 distinct random premises
    taken from other images.

  Returns nothing; the result is only written to disk.
  """
  # Each handle is bound to the file its name refers to, so the cartoon id
  # exclusion list below is applied to the cartoon data.
  with open('./dataset/cartoon_storage.json') as ctn, \
       open('./dataset/social_storage.json') as soc, \
       open('./dataset/ads_storage.json') as ads:
    cartoon_data = json.load(ctn)
    ads_data = json.load(ads)
    soc_data = json.load(soc)

  premises = {}
  global_vps = []
  global_cps = []

  cartoon_keys = [key for key in cartoon_data if int(key) not in _EXCLUDED_CARTOON_IDS]
  _add_entries(cartoon_data, tqdm(cartoon_keys, desc='visual_grounding'), premises, global_vps, global_cps)
  _add_entries(ads_data, ads_data, premises, global_vps, global_cps)
  _add_entries(soc_data, soc_data, premises, global_vps, global_cps)

  print(len(global_vps))
  print(len(global_cps))
  # NOTE(review): index=False presumably reuses an already-built index — confirm.
  RAG = RAG_searcher("vp_index", global_vps, index=False)
  cp_RAG = RAG_searcher("cp_index", global_cps, index=False)

  docs_features = clip_encode_docs(global_vps)
  print(len(docs_features))

  # Computed once so the per-image eligibility count stays cheap.
  distinct_vps = set(global_vps)
  distinct_cps = set(global_cps)

  for key, entry in premises.items():
    goldset_vps = entry['goldset_vps']
    goldset_cps = entry['goldset_cps']
    # Keys are created in a fixed order so the dumped JSON layout is stable.
    entry['negative_vps'] = []
    entry['negative_cps'] = []
    entry['clip_hard_negative'] = []
    entry['colbert_hard_negative_vps'] = _colbert_hard_negatives(RAG, goldset_vps, COLBERT_THRESHOLD)
    entry['colbert_hard_negative_cps'] = _colbert_hard_negatives(cp_RAG, goldset_cps, COLBERT_THRESHOLD)

    # clip_hard_neg returns indices into global_vps (docs_features was built from it).
    hard_negatives_idxes = clip_hard_neg(key, docs_features)
    entry['clip_hard_negative'] = [
      global_vps[idx] for idx in hard_negatives_idxes.tolist()
      if global_vps[idx] not in goldset_vps
    ]

    entry['negative_vps'] = _sample_random_negatives(global_vps, distinct_vps, goldset_vps)
    entry['negative_cps'] = _sample_random_negatives(global_cps, distinct_cps, goldset_cps)

  os.makedirs('./out/premise_retrieval', exist_ok=True)
  with open('./out/premise_retrieval/premises.json', 'w') as f:
    json.dump(premises, f)

# Maps every supported `vlm` option to the name of the callable that runs it.
_VLM_FUNCTION_NAMES = {
  'llava': 'llava',
  'instructblip': 'instructblip',
  'blip': 'blip2',
  'llavanext': 'llavanext',
  'unifiedio': 'unifiedio2',
  'qwenVL': 'qwenVL',
  'ofa': 'ofa',
  'openflamingo': 'openflamingo',
  'minigpt': 'minigpt_4',
}

# Difficulty levels accepted when retrieving visual premises.
_VP_LEVELS = ('easy', 'hard_clip', 'hard_colbert')


def _resolve_vlm(vlm):
  """Return the module-level callable for `vlm`, failing early with a clear message."""
  if vlm not in _VLM_FUNCTION_NAMES:
    raise ValueError(f'unknown vlm: {vlm!r}')
  function_name = _VLM_FUNCTION_NAMES[vlm]
  try:
    # Looked up lazily so only the selected backend has to be imported.
    return globals()[function_name]
  except KeyError:
    raise NameError(
      f"VLM backend '{function_name}' is not imported; enable its import at the top of this file"
    ) from None


def _format_choices(choices):
  """Render at most five choices as 'A)...' to 'E)...' lines."""
  return ''.join(f"{label}){choice}\n" for label, choice in zip('ABCDE', choices))


def _vp_negatives(entry, level, idx):
  """Return the negative visual premises for gold premise `idx` at `level`."""
  if level == 'easy':
    return entry['negative_vps'][:4]
  if level == 'hard_clip':
    return entry['clip_hard_negative'][:4]
  if level == 'hard_colbert':
    return entry['colbert_hard_negative_vps'][idx]
  raise ValueError(f'unknown level: {level!r}')


def premise_retrieval(target="vp", level="easy", vlm="llava", add_conclusion=False):
  """Ask a VLM to pick each gold premise out of a shuffled multiple-choice list.

  Args:
    target: "vp" tests visual premises; any other value tests commonsense premises.
    level: 'easy' (random negatives), 'hard_clip' or 'hard_colbert'. For
      commonsense premises every level other than 'easy' uses the ColBERT
      hard negatives.
    vlm: one of the keys of _VLM_FUNCTION_NAMES.
    add_conclusion: for visual premises, whether the conclusion is passed to
      the prompt. The commonsense prompt always receives the conclusion.

  Raises:
    ValueError: unknown `vlm`, or unknown `level` when target is "vp".
    NameError: the callable for `vlm` is not imported in this module.

  The per-image results are written to
  ./out/premise_retrieval/{vlm}_{target}_retrieval_results_{level}_conclusion_{add_conclusion}.json.
  """
  # Validate before any file or model work so a typo fails immediately.
  run_vlm = _resolve_vlm(vlm)
  if target == "vp" and level not in _VP_LEVELS:
    raise ValueError(f'unknown level: {level!r}')

  # NOTE(review): parse_prmises writes 'premises.json', not 'cartoon_premises.json' —
  # confirm this is the intended input file.
  with open('./out/premise_retrieval/cartoon_premises.json') as ctn:
    premises = json.load(ctn)

  total_results = {}
  for key in tqdm(premises.keys()):
    entry = premises[key]
    total_results[key] = []
    if target == "vp":
      for idx, goldset_vp in enumerate(entry['goldset_vps']):
        # List concatenation builds a fresh list, so shuffling never touches the stored negatives.
        vp_choices = [goldset_vp] + _vp_negatives(entry, level, idx)
        random.shuffle(vp_choices)
        vp_prompt = _format_choices(vp_choices)
        if add_conclusion:
          prompt = visual_premise_prompt2(vp_prompt, entry['conclusion'])
        else:
          prompt = visual_premise_prompt2(vp_prompt)
        total_results[key].append({
          'result': run_vlm(key, prompt),
          'gt': goldset_vp,
          'vp_choices': vp_choices
        })
    else:
      for idx, goldset_cp in enumerate(entry['goldset_cps']):
        if level == 'easy':
          cp_choices = [goldset_cp] + entry['negative_cps'][:4]
        else:
          cp_choices = [goldset_cp] + entry['colbert_hard_negative_cps'][idx]
        random.shuffle(cp_choices)
        prompt = commonsense_premise_prompt(_format_choices(cp_choices), entry['conclusion'])
        total_results[key].append({
          'result': run_vlm(key, prompt),
          'gt': goldset_cp,
          'cp_choices': cp_choices
        })

  out_path = f'./out/premise_retrieval/{vlm}_{target}_retrieval_results_{level}_conclusion_{str(add_conclusion)}.json'
  print(out_path)
  with open(out_path, 'w') as f:
    json.dump(total_results, f)

# Option label -> zero-based choice index: letters 'A'-'L' and numbers '1'-'12'.
_LABEL_TO_INDEX = {
  **{letter: position for position, letter in enumerate('ABCDEFGHIJKL')},
  **{str(position + 1): position for position in range(12)},
}

# Short replies may be a bare label ("B", "2)"); longer replies must carry a
# closing parenthesis after the label so ordinary words are not picked up.
_SHORT_ANSWER_RE = re.compile(r'(\d+|[A-Z])+\)?')
_LONG_ANSWER_RE = re.compile(r'(\d+|[A-Z]+|[a-z])+\)')


def get_score(target="vp", level="easy", vlm="llava", add_conclusion=False):
  """Score one premise-retrieval run and return its mean accuracy.

  Reads ./out/premise_retrieval/{vlm}_{target}_retrieval_results_{level}_conclusion_{add_conclusion}.json,
  extracts the option label(s) from every model reply and counts a sub-task as
  correct only when exactly one label was found and the choice it selects is
  the ground truth.

  Args:
    target: "vp" reads the 'vp_choices' key of each sub-task; any other value reads 'cp_choices'.
    level, vlm, add_conclusion: only used to build the result file name.

  Returns:
    The mean accuracy as a float, or 0.0 when the file holds no sub-tasks.
    The per-task accuracies and the mean are also printed.
  """
  results_path = f'./out/premise_retrieval/{vlm}_{target}_retrieval_results_{level}_conclusion_{str(add_conclusion)}.json'
  with open(results_path) as f:
    data = json.load(f)

  choices_key = 'vp_choices' if target == "vp" else 'cp_choices'
  accuracies = []
  for key in data.keys():
    for sub_task in data[key]:
      # Drop anything generated after an end-of-text marker.
      result = sub_task['result'].split('endoftext')[0]
      pattern = _SHORT_ANSWER_RE if len(result) < 10 else _LONG_ANSWER_RE
      matches = pattern.findall(result)
      print('=============')
      print(key)
      print(matches)

      answers = []
      for match in matches:
        label = match.split(')')[0]
        if label in _LABEL_TO_INDEX:
          answers.append(_LABEL_TO_INDEX[label])
        else:
          print('else')
      print(answers)

      accuracy = 0
      if len(answers) == 1:
        choices = sub_task[choices_key]
        gt = sub_task['gt']
        # A label beyond the offered choices is a wrong answer, not a crash.
        if answers[0] < len(choices):
          chosen = choices[answers[0]]
          # gt is a single string in current result files: compare for
          # equality, since `in` on a string would be a substring test.
          # A list of gold premises is still accepted.
          is_correct = chosen == gt if isinstance(gt, str) else chosen in gt
          if is_correct:
            accuracy = 1
      accuracies.append(accuracy)

  # An empty result file scores 0.0 instead of dividing by zero.
  accuracy_mean = sum(accuracies) / len(accuracies) if accuracies else 0.0
  print(accuracies)
  print(accuracy_mean)
  return accuracy_mean


# Script entry point: running the file rebuilds the premise/negative-sample
# dataset. The commented lines below are earlier ad-hoc invocations kept for
# reference.
if __name__ == "__main__":
  parse_prmises()
  # f1score()
  # with open('./visual_argument_experiments/out/premise_retrieval/openflamingo_vp_retrieval_results_easy_conclusion_True.json') as f:
  #   data = json.load(f)
  #   for key in data.keys():
  #     # print(data[key][0]['result'].split())
  #     data[key][0]['result'] = data[key][0]['result'].split('ANSWER:')[1]
  #     print(data[key][0]['result'])

  # get_score('vp', 'hard_colbert', 'qwenVL', add_conclusion=True)