import json
import re
import base64
from PIL import Image
import io

def interUnion(boxA, boxB):
  """Return the intersection-over-union ratio of two boxes.

  Each box is indexed as ``box[0]..box[3]``; indices 0/1 are used as the
  lower corner and 2/3 as the upper corner (presumably ``x1, y1, x2, y2``).
  Every extent gets ``+1``, i.e. the corner coordinates are treated as
  inclusive.

  Returns:
    Overlap area divided by the area covered by either box.
  """
  def _extent(box):
    # Inclusive width times inclusive height of a single box.
    return (box[2]-box[0]+1) * (box[3]-box[1]+1)

  # Corners of the overlapping rectangle (may be "inverted" when disjoint).
  left, top = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
  right, bottom = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])

  # Clamp each side at zero so disjoint boxes contribute no overlap.
  overlap_w = max(0, right-left+1)
  overlap_h = max(0, bottom-top+1)
  overlap = overlap_w * overlap_h

  combined = _extent(boxA) + _extent(boxB) - overlap
  return overlap/combined

def parse_annotation(annotation):
  """Split an annotation text into its sections.

  The text is read line by line. A line containing "visual premise"
  (case-insensitive) opens the visual-premise section; after that, the
  headers "commonsense premise", "conclusion", "reasoning step" and
  "category" are only recognised while the section directly before them
  is active. Header lines themselves are never stored.

  Args:
    annotation: The full annotation as one newline-separated string.

  Returns:
    A tuple ``(vp, cp, c, rs, ctg)``:
      vp  -- visual premise lines with a leading "N. " prefix removed
      cp  -- commonsense premise lines with a leading "N. " prefix removed
      c   -- the last non-empty line of the conclusion section ("" if none)
      rs  -- reasoning step lines, verbatim
      ctg -- text after the first colon on the category line, stripped;
             "" when there is no category line or it has no colon
  """
  vp = []
  cp = []
  c = ""
  rs = []
  ctg = ""
  is_vp = False
  is_cp = False
  is_c = False
  is_rs = False
  numbering = re.compile(r'^\d+\.\s')

  for line in annotation.split('\n'):
    lowered = line.lower()
    if "visual premise" in lowered:
      is_vp = True
    elif is_vp and "commonsense premise" in lowered:
      is_vp = False
      is_cp = True
    elif is_cp and "conclusion" in lowered:
      is_cp = False
      is_c = True
    elif is_c and "reasoning step" in lowered:
      is_c = False
      is_rs = True
    elif is_rs and "category" in lowered:
      is_rs = False
      # A category line without a colon used to raise IndexError; treat it
      # as "no category" so callers can detect the malformed annotation.
      fields = line.split(':')
      ctg = fields[1].strip() if len(fields) > 1 else ""
    elif line != "":
      if is_vp:
        vp.append(numbering.sub('', line))
      elif is_cp:
        cp.append(numbering.sub('', line))
      elif is_c:
        c = line
      elif is_rs:
        rs.append(line)

  return (vp, cp, c, rs, ctg)

# def parse_reasoning_steps(target):
#   start_idx = target.index('(')
#   end_idx = target.index('): ')
#   tree = target.split('): ')[0]
#   text = target.split('): ')[1]
#   vp_pattern = r'VP\d+|vp\d+'
#   cp_pattern = r'CP\d+|cp\d+'
#   vps = re.findall(vp_pattern, tree.split('->')[0])
#   cps = re.findall(cp_pattern, tree.split('->')[0])

#   vp_idx = []
#   cp_idx = []
#   for vp in vps:
#     vp_idx.append(int(vp[-1]) - 1)
#   for cp in cps:
#     cp_idx.append(int(cp[-1]) - 1)
  
#   return vp_idx, cp_idx, text

  # target[start_idx: end_idx]

def track_tree(cur_tree, reasoning_steps, _active=None):
  """Collect the VP/CP labels a reasoning tree depends on, transitively.

  ``cur_tree`` is the part of a reasoning step before ``"): "``, e.g.
  ``"(VP1, IC1 -> C"``. Labels are read from the text left of the first
  ``->``. Every intermediate conclusion (``ICn``) found there is resolved
  by looking for steps in ``reasoning_steps`` whose right-hand side names
  that same ``ICn``, and recursing into them.

  Args:
    cur_tree: Tree header of the step being analysed.
    reasoning_steps: All reasoning step lines, each expected to look like
      ``"(<premises> -> <result>): <text>"``. Lines without ``->`` are
      skipped.
    _active: Internal. Tree headers on the current recursion path, used to
      break circular references. Callers should not pass it.

  Returns:
    ``(vps, cps)``: lists of label strings such as ``"VP1"`` / ``"CP2"``,
    direct ones first, then those reached through intermediate
    conclusions. Duplicates are kept.
  """
  ic_pattern = r'IC\d+|ic\d+'
  vp_pattern = r'VP\d+|vp\d+'
  cp_pattern = r'CP\d+|cp\d+'

  premise_side = cur_tree.split('->')[0]
  vps = re.findall(vp_pattern, premise_side)
  cps = re.findall(cp_pattern, premise_side)
  ics = re.findall(ic_pattern, premise_side)
  if not ics:
    return vps, cps

  # Trees currently being expanded; revisiting one would recurse forever.
  active = {cur_tree} if _active is None else _active | {cur_tree}

  for ic in ics:
    wanted = ic.upper()
    for rs in reasoning_steps:
      tree = rs.split('): ')[0]
      sides = tree.split('->')
      if len(sides) < 2:
        # Not a "premises -> result" line; nothing to resolve against.
        continue
      # Compare whole labels, case-insensitively, so that IC1 does not
      # match a step that actually concludes IC10.
      produced = {label.upper() for label in re.findall(ic_pattern, sides[1])}
      if wanted not in produced or tree in active:
        continue
      sub_vps, sub_cps = track_tree(tree, reasoning_steps, active)
      vps = vps + sub_vps
      cps = cps + sub_cps

  return vps, cps
      


def parse_reasoning_steps(target, reasoning_steps):
  """Resolve one reasoning step into premise indices and its text.

  ``target`` is split at the first ``"): "``: the part before it is the
  tree header handed to ``track_tree``, the part after it is the step's
  text.

  Args:
    target: A reasoning step line, ``"(<premises> -> <result>): <text>"``.
    reasoning_steps: All reasoning step lines, used to resolve
      intermediate conclusions.

  Returns:
    ``(vp_idx, cp_idx, text)`` where the index lists hold the label
    numbers minus one (``VP3`` -> 2) -- presumably zero-based positions in
    the premise lists of the same annotation; confirm against the caller.

  Raises:
    IndexError: if ``target`` does not contain ``"): "``.
  """
  # Split once only, so a later "): " inside the text is kept in the text.
  pieces = target.split('): ', 1)
  cur_tree = pieces[0]
  vps, cps = track_tree(cur_tree, reasoning_steps)
  text = pieces[1]

  # Labels are two letters followed by digits; parse all digits so that
  # VP10 maps to 9 rather than to -1.
  vp_idx = [int(vp[2:]) - 1 for vp in vps]
  cp_idx = [int(cp[2:]) - 1 for cp in cps]

  return vp_idx, cp_idx, text


def resize_image(image_path, output_size=(64, 64)):
    """Return a resized copy of an image as a base64-encoded JPEG string.

    Args:
        image_path: Path (or file object) accepted by ``Image.open``.
        output_size: Target ``(width, height)`` passed to ``Image.resize``;
            the aspect ratio is not preserved.

    Returns:
        The JPEG bytes of the resized image, base64-encoded and decoded to
        a UTF-8 ``str``.
    """
    # Modes Pillow's JPEG writer is expected to accept unchanged. Anything
    # else (RGBA, palette "P", "LA", ...) is converted to RGB first,
    # otherwise saving fails with "cannot write mode ... as JPEG".
    jpeg_modes = ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr')
    with Image.open(image_path) as img:
        img = img.resize(output_size)
        if img.mode not in jpeg_modes:
            img = img.convert('RGB')
        buffered = io.BytesIO()
        img.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')


if __name__ == "__main__":
  # Sanity check: print the key of every stored entry whose annotation does
  # not yield all five parts (both premise lists, conclusion, reasoning
  # steps, category), then print a final marker.
  with open('./temp1/social_storage.json') as storage:
    entries = json.load(storage)
    for entry_key, entry in entries.items():
      premises_v, premises_c, conclusion, steps, category = parse_annotation(entry["annotation"])
      complete = (
        len(premises_v) != 0
        and len(premises_c) != 0
        and conclusion != ""
        and len(steps) != 0
        and category != ""
      )
      if not complete:
        print(entry_key)
  print('end')
