
import json
import os
import nltk

nltk.download('punkt')

from visarg.others.utils import parse_annotation, parse_reasoning_steps
from tqdm import tqdm

# Maps a lower-cased model name (the key checked by load_model and
# load_prompt_func) to the module name that those functions import from
# `models.image2con` first and `models.text2con` as a fallback.
MODEL_CLASSES = {
  'blip2': 'BLIP2',
  'instructblip': 'InstructBLIP',
  'kosmos2': 'KOSMOS2',
  'llavanext': 'LLaVANeXT',
  'llava': 'LLaVa',
  'cogvlm': 'CogVLM',
  'qwenvlchat': 'QwenVLChat',
  'minigpt_4': 'MiniGPT_4',
  'openflamingo': 'Openflamingo',
  'ofa': 'ofa',
  'idefics2': 'idefics2',
  'otter': 'Otter',
  'unifiedio2': 'unifiedio2',
  'llama3': 'LLaMA3',
  'llama2': 'LLaMA2',
  'mistral': 'Mistral',
  'zephyr': 'Zephyr',
}

# Maps the dataset name given in `args.data` to the JSON file that
# intermediate_conclusion_deduction loads. Paths are relative to the current
# working directory.
DATA = {
  'cartoon': './dataset/cartoon_storage.json',
  'social': './dataset/social_storage.json',
  'ads': './dataset/ads_storage.json',
  'new_cartoon': './dataset/new_cartoon_storage.json',
}

# Instruction text passed as the `postfix` argument of every model prompt built
# in intermediate_conclusion_deduction.
PROMPT = "Given an image, visual premises, and commonsense premises, draw a simple conclusion directly connected to the commonsense premises in one sentence without unnecessary prefixes. ANSWER:"

def load_model(model_name):
  """Return the inference callable registered for ``model_name``.

  The name is lower-cased and looked up in ``MODEL_CLASSES`` to obtain a module
  name. ``models.image2con.<module>`` is tried first; if that import fails,
  ``models.text2con.<module>`` is imported instead. The attribute named after
  the lower-cased model name is then returned from the imported module.

  Args:
    model_name: Model identifier; matching is case-insensitive.

  Returns:
    The module attribute whose name equals the lower-cased ``model_name``.

  Raises:
    ValueError: If the lower-cased name is not a key of ``MODEL_CLASSES``.
    ImportError: If neither candidate module can be imported.
    AttributeError: If the imported module lacks the expected attribute.
  """
  model_name = model_name.lower()
  if model_name not in MODEL_CLASSES:
    raise ValueError(f"No model found for {model_name}")

  model_class = MODEL_CLASSES[model_name]
  try:
    module = __import__(f"models.image2con.{model_class}", fromlist=[model_name])
  except ImportError:
    # Only an import failure triggers the fallback. Other errors (and
    # KeyboardInterrupt/SystemExit) propagate instead of being hidden behind
    # an unrelated "module not found" for the text2con path.
    module = __import__(f"models.text2con.{model_class}", fromlist=[model_name])
  return getattr(module, model_name)

def load_prompt_func(model_name):
  """Return the ``prompt`` attribute of the module registered for ``model_name``.

  The name is lower-cased and looked up in ``MODEL_CLASSES`` to obtain a module
  name. ``models.image2con.<module>`` is tried first; if that import fails,
  ``models.text2con.<module>`` is imported instead.

  Args:
    model_name: Model identifier; matching is case-insensitive.

  Returns:
    The ``prompt`` attribute of the imported module.

  Raises:
    ValueError: If the lower-cased name is not a key of ``MODEL_CLASSES``.
    ImportError: If neither candidate module can be imported.
    AttributeError: If the imported module has no ``prompt`` attribute.
  """
  model_name = model_name.lower()
  if model_name not in MODEL_CLASSES:
    raise ValueError(f"No model's prompt function found for {model_name}")

  model_class = MODEL_CLASSES[model_name]
  try:
    module = __import__(f"models.image2con.{model_class}", fromlist=["prompt"])
  except ImportError:
    # Only an import failure triggers the fallback. Other errors (and
    # KeyboardInterrupt/SystemExit) propagate instead of being swallowed.
    module = __import__(f"models.text2con.{model_class}", fromlist=["prompt"])
  return getattr(module, "prompt")

def _parse_ic_number(step):
  """Return the integer that directly follows '-> IC' in a reasoning step."""
  tail = step.split('-> IC')[1]
  # Take the whole leading run of digits so that e.g. "IC12" yields 12 rather
  # than 1 (which would collide with the key generated for IC1).
  digits = tail[:len(tail) - len(tail.lstrip('0123456789'))]
  return int(digits)


def _numbered_premises(header, premises):
  """Return ``[header, '1. <p1>', '2. <p2>', ...]`` for the given premises."""
  return [header] + [str(i) + ". " + premise for i, premise in enumerate(premises, start=1)]


def _first_sentence(text):
  """Return the first sentence of ``text``, or ``text`` unchanged on failure."""
  try:
    return nltk.tokenize.sent_tokenize(text)[0]
  except Exception:
    # Best effort: empty output, non-string output or a tokenizer problem all
    # leave the raw model output in place.
    return text


def _dump_json(path, payload):
  """Write ``payload`` to ``path`` as JSON."""
  with open(path, 'w') as f:
    json.dump(payload, f)


def intermediate_conclusion_deduction(args):
  """Deduce intermediate conclusions (ICs) with a model and save the results.

  Loads the dataset selected by ``args.data``, splits every annotation into
  one target per reasoning step (the last step is excluded), prompts the model
  twice per target, and writes ground truths and predictions as JSON files
  into ``args.out_path``.

  Args:
    args: Namespace-like object. Attributes read here: ``data`` (key of
      ``DATA``), ``with_summarized``, ``deduct_model``, ``out_path`` and
      ``without_img``.

  Returns:
    None. Four files are written, named
    ``<model>_ic<detail>_{gts_pos,gts_all,res_pos,res_all}.json``.

  Raises:
    KeyError: If ``args.data`` is not a key of ``DATA``.

  Side effects:
    ``args.data`` is rewritten from ``"new_cartoon"`` to ``"newctn"``.
  """
  data_path = DATA[args.data]
  if args.with_summarized:
    with open('./dataset/gpt_summarize.json', encoding='utf-8') as f:
      summarized = json.load(f)
  with open(data_path, encoding='utf-8') as data_file:
    data = json.load(data_file)

  print(f'== Load Model : {args.deduct_model} ==')
  deduct = load_model(args.deduct_model)
  prompt_func = load_prompt_func(args.deduct_model)

  gts = {}
  res_pos = {}
  res_all = {}

  # Phase 1: build one target entry per intermediate conclusion.
  target_ics = {}
  for key in tqdm(data.keys(), desc='IC splitting', leave=True):
    try:
      if args.with_summarized:
        # Substitute the summarized text for each listed target span.
        for target in summarized[args.data][data[key]["image_url"]]:
          data[key]['annotation'] = data[key]['annotation'].replace(target["target"], target["gpt3.5"])
      (vps, cps, _, rs, _) = parse_annotation(data[key]['annotation'])
    except Exception as e:
      print(f"Exception on {key}: {e}")
      continue

    # The last reasoning step is skipped; only the preceding steps are used.
    for step in rs[:-1]:
      try:
        ic_num = _parse_ic_number(step)
        vp_idxs, cp_idxs, txt = parse_reasoning_steps(step, rs)
        target_ics[key + '_IC' + str(ic_num)] = {
          'vps': vps,
          'cps': cps,
          'gt': txt,
          'vp_idxs': vp_idxs,
          'cp_idxs': cp_idxs,
          'image_url': data[key]['image_url']
        }
      except Exception as e:
        print(f"Exception on {key}: {e}")
        continue

  # Loop-invariant preamble shared by every prompt.
  description = '"Visual Premises (VP)" are the important features presented in the images. "Commonsense Premises (CP)" are not visually depicted in the image but are commonly understood by people. '

  # Phase 2: query the model for every target.
  for key in tqdm(target_ics.keys(), desc='Deduct Intermediate conclusion', leave=True):
    target = target_ics[key]
    image_path = os.path.join('./dataset/images', target['image_url'].split('/')[-1])

    try:
      # VPs(positive), CPs(positive) -> IC: only premises selected by the
      # step's vp_idxs / cp_idxs.
      vps_pos = _numbered_premises("Visual Premises (VP):", [target['vps'][idx] for idx in target['vp_idxs']])
      cps_pos = _numbered_premises("Commonsense Premises (CP):", [target['cps'][idx] for idx in target['cp_idxs']])
      informations = '\n\n' + '\n'.join(vps_pos) + '\n\n' + '\n'.join(cps_pos) + '\n\n'
      prompt_pos = prompt_func(prefix=description + informations, postfix=PROMPT, need_base=False)

      # VPs(ALL), CPs(positive) -> IC: every visual premise of the sample.
      vps_all = _numbered_premises("Visual Premises (VP):", target['vps'])
      informations = '\n\n' + '\n'.join(vps_all) + '\n\n' + '\n'.join(cps_pos) + '\n\n'
      prompt_all = prompt_func(prefix=description + informations, postfix=PROMPT, need_base=False)

      con_pos = deduct(image_path, prompt_pos)
      con_all = deduct(image_path, prompt_all)
    except Exception as e:
      print(f"Error: {e} on key: {key}")
      continue

    # Post processing: keep only the first sentence of each model output.
    gts[key] = [target['gt']]
    res_pos[key] = [_first_sentence(con_pos)]
    res_all[key] = [_first_sentence(con_all)]

  os.makedirs(args.out_path, exist_ok=True)

  # NOTE: `without_img` only changes the output file name in this function;
  # the image path is passed to the model either way.
  detail = '_wo_img' if args.without_img else ''

  # NOTE(review): this mutates the caller's namespace and is kept for backward
  # compatibility. "newctn" is not a key of DATA, so calling this function
  # again with the same args object raises KeyError.
  if args.data == "new_cartoon":
    args.data = "newctn"
  detail += f"_{args.data}"

  if args.with_summarized:
    detail += '_summarized'

  # The ground truths are written twice, once next to each result file.
  file_prefix = args.deduct_model.lower() + '_' + 'ic' + detail
  _dump_json(os.path.join(args.out_path, file_prefix + '_gts_pos.json'), gts)
  _dump_json(os.path.join(args.out_path, file_prefix + '_gts_all.json'), gts)
  _dump_json(os.path.join(args.out_path, file_prefix + '_res_pos.json'), res_pos)
  _dump_json(os.path.join(args.out_path, file_prefix + '_res_all.json'), res_all)