
import importlib
import json
import os
import re

from tqdm import tqdm

from visarg.others.utils import interUnion, parse_annotation

# Grounding-model registry: maps the lowercased model name accepted by
# ``load_model`` to the module name under ``visarg.models.grounding``.
# (The previous literal listed 'lisa' twice; the duplicate was redundant.)
MODEL_CLASSES = {
  'lisa': 'lisa',
  'uninext_h': 'UNINEXT_H',
  'llava': 'llava',
  'ofa': 'ofa',
  'qwenvl': 'QwenVL',
}


# Dataset split name -> annotation JSON file (paths are relative to the
# current working directory).
DATA = {
  'cartoon': './dataset/cartoon_storage.json',
  'social': './dataset/social_storage.json',
  'ads': './dataset/ads_storage.json',
  'new_cartoon': './dataset/new_cartoon_storage.json',
}

def load_model(model_name):
  """Return the grounding callable registered under ``model_name``.

  The lookup is case-insensitive: ``model_name`` is lowercased, mapped through
  ``MODEL_CLASSES`` to a module name under ``visarg.models.grounding``, and
  the attribute named after the lowercased model name is fetched from that
  module.

  Raises:
    ValueError: if ``model_name`` (lowercased) is not a key of
      ``MODEL_CLASSES``.
  """
  model_name = model_name.lower()
  try:
    module_name = MODEL_CLASSES[model_name]
  except KeyError:
    raise ValueError(f"No model found for {model_name}") from None
  # importlib.import_module returns the leaf module directly, replacing the
  # discouraged __import__(..., fromlist=...) idiom.
  module = importlib.import_module(f"visarg.models.grounding.{module_name}")
  return getattr(module, model_name)


# Visual-premise phrases containing any of these substrings are not sent to
# the grounding model (quoted text, text regions, speech bubbles, logos).
_NON_GROUNDABLE_MARKERS = ('"', 'text', 'bubble', 'logo')


def _mean(values):
  """Return the arithmetic mean of ``values``, or None when it is empty."""
  return sum(values) / len(values) if values else None


def _to_xyxy(gt):
  """Convert a ``{startX, startY, w, h}`` box dict to ``[x1, y1, x2, y2]``."""
  return [
      gt["startX"],
      gt["startY"],
      gt["startX"] + gt["w"],
      gt["startY"] + gt["h"],
  ]


def _select_targets(vps, gt_boxes):
  """Keep only groundable visual premises, paired with their GT boxes.

  Returns two parallel lists ``(phrases, boxes)``. A phrase containing any
  substring in ``_NON_GROUNDABLE_MARKERS`` is dropped together with its box.
  Assumes each phrase is a string — TODO confirm against parse_annotation.
  """
  phrases, boxes = [], []
  for phrase, box in zip(vps, gt_boxes):
    # Check the individual phrase (the old code tested the whole list, so
    # the filter never removed anything).
    if not any(marker in phrase for marker in _NON_GROUNDABLE_MARKERS):
      phrases.append(phrase)
      boxes.append(box)
  return phrases, boxes


def visual_grounding(args, *, iou_threshold=0.5):
  """Run a grounding model over a dataset split and report IoU metrics.

  Reads the annotation file ``DATA[args.data]``, grounds each image's
  groundable visual premises with ``load_model(args.ground_model)``, and
  compares every predicted box to its ground-truth box via ``interUnion``.
  Per-image results are written to
  ``<args.task1_result_path>/<ground_model>_<data>_result.json``.

  Printed metrics:
    * local iou / local score: mean over images of the per-image mean IoU /
      hit rate.
    * global iou / global score: mean over all compared boxes.
  A metric is printed as None when nothing was compared.

  Args:
    args: namespace with ``data``, ``ground_model`` and
      ``task1_result_path`` attributes.
    iou_threshold: a prediction counts as a hit when IoU > this value.
  """
  with open(DATA[args.data]) as f:
    data = json.load(f)
  print(len(data))
  ground = load_model(args.ground_model)

  results = {}
  local_scores = []   # per-image hit rate
  local_ious = []     # per-image mean IoU
  global_scores = []  # per-box hit flag (1 / 0)
  global_ious = []    # per-box IoU

  desc = f'Visual Grounding {args.ground_model} on data {args.data}'
  for key, value in tqdm(data.items(), desc=desc, leave=True):
    # Data sanity checks.
    if 'bbox' not in value:
      print(f"BBox not in {key}")
      continue

    image_path = os.path.join('./dataset/images', value["image_url"].split('/')[-1])
    vp, _cp, _c, _rs, _ctg = parse_annotation(value["annotation"])
    if len(value['bbox']) != len(vp):
      print("Wrong length", key)
      continue

    tgt_vps, gt_boxes = _select_targets(vp, value['bbox'])
    if not tgt_vps:
      # Nothing to ground: avoid calling the model with no phrases and
      # dividing by zero when averaging below.
      print("No groundable visual premise in", key)
      continue

    bboxes = ground(image_path, tgt_vps)

    image_result = {"vps": tgt_vps, "ious": [], "gts": [], "preds": []}
    image_hits = []
    for pred, gt in zip(bboxes, gt_boxes):
      gt_bbox = _to_xyxy(gt)
      iou = interUnion(pred, gt_bbox)
      hit = 1 if iou > iou_threshold else 0

      image_result["ious"].append(iou)
      image_result["gts"].append(gt_bbox)
      image_result["preds"].append(pred)
      image_hits.append(hit)
      global_ious.append(iou)
      global_scores.append(hit)

    results[image_path] = image_result
    # zip() stops at the shorter input, so the model may have returned no
    # comparable boxes; only record per-image averages when some exist.
    if image_result["ious"]:
      local_scores.append(_mean(image_hits))
      local_ious.append(_mean(image_result["ious"]))

  os.makedirs(args.task1_result_path, exist_ok=True)
  out_path = os.path.join(
      args.task1_result_path, f"{args.ground_model}_{args.data}_result.json"
  )
  with open(out_path, 'w') as r_file:
    json.dump(results, r_file)

  print('local iou : ', _mean(local_ious))
  print('local score : ', _mean(local_scores))

  print('global iou : ', _mean(global_ious))
  print('global score : ', _mean(global_scores))
