
import json
import os

import nltk
from tqdm import tqdm

from visarg.others.prompt_styles import PROMPT_STYLES
from visarg.others.utils import parse_annotation

# Fetch the Punkt data that nltk.tokenize.sent_tokenize needs during post-processing.
nltk.download('punkt')

EXAMPLES = [
"""(Example 1)
Visual Premises (VP):
1. A figure is draped in a shroud-like garment with a predominantly dark background.
2. The garment has a prominent crescent moon and star symbol on the upper part.
3. There is a Star of David symbol visible on the lower part of the garment.
4. The image is illuminated with a gradient of purple light from top to bottom.

Commonsense Premises (CP):
1. The crescent moon and star are widely recognized as symbols associated with Islam and several countries with a Muslim majority.
2. The Star of David is a well-known symbol of Judaism and is often associated with the Jewish people and the state of Israel.
3. The use of dark backgrounds and shrouds often symbolizes mystery, death, or mourning.
4. Purple light can symbolize spirituality, wisdom, or suffering.

Conclusion (C):
The image represents the juxtaposition of Islamic and Jewish symbols, suggesting themes of coexistence, conflict, or mutual respect between the two religions and cultures.

Reasoning Steps:
(VP1, VP2, CP1 -> IC1): The crescent moon and star on the upper part of the garment suggest the figure represents or is related to Islam.
(VP3, CP2 -> IC2): The Star of David on the lower part of the garment suggests the figure also has a connection to Judaism.
(VP4, CP3 -> IC3): The dark background and shroud-like garment imply a somber or reflective theme.
(CP4 -> IC4): The purple light adds a layer of spirituality or suffering to the image's mood.
(IC1, IC2, IC3, IC4 -> C): The image represents the juxtaposition of Islamic and Jewish symbols, suggesting themes of coexistence, conflict, or mutual respect between the two religions and cultures.

Category:
Religious and Cultural Commentary""",
"""(Example 2)
Visual Premises (VP):
1. The background features a world map painted on a brick wall.
2. Two windows are integrated into the map, one located in the region representing Africa and the other in Asia.
3. A person is leaning out of the upper window, handing a FedEx package to another person leaning out of the lower window.
4. The FedEx logo is visible on the package and in the bottom right corner of the image.

Commonsense Premises (CP):
1. FedEx is a well-known international courier delivery service.
2. Windows are commonly associated with the idea of communication and connection between separate spaces.
3. The map background implies a global context, suggesting international reach or connectivity.
4. The act of passing a package between windows suggests ease and directness in the delivery process.

Conclusion (C):
The image emphasizes FedEx's global reach and the ease of international shipping, suggesting that sending packages across the world is as simple as passing them between neighboring windows.

Reasoning Steps:
(VP1, CP3 -> IC1): The world map background indicates a global context, suggesting the image is about international connections.
(VP2, VP3, CP2 -> IC2): The presence of windows and the act of handing over a package implies ease and directness in the delivery process.
(VP4, CP1 -> IC3): The FedEx branding reinforces the association with international courier services.
(IC1, IC2, IC3 -> C): The image emphasizes FedEx's global reach and the ease of international shipping, suggesting that sending packages across the world is as simple as passing them between neighboring windows.

Category:
Advertising and Marketing""",
"""(Example 3)
Visual Premises (VP):

1. A central figure representing Uncle Sam stands with arms crossed and a stern expression.
2. Two men are kneeling on either side of Uncle Sam, one holding a "Save Israel" sign and the other holding a "Save Ukraine" sign.
3. The man holding the "Save Israel" sign is depicted blowing air into Uncle Sam's pocket, causing it to swell.
4. The man holding the "Save Ukraine" sign is depicted blowing air into Uncle Sam's other pocket, but it is not swelled enough.

Commonsense Premises (CP):
1. Uncle Sam is a common personification of the United States government or American interests.
2. The act of blowing air into someone's pocket is a symbolic gesture indicating an attempt to inflate or influence wealth or favor.
3. "Save Israel" and "Save Ukraine" signs indicate a call for support or aid from the United States for these respective countries.
4. The image suggests a power dynamic where the United States holds significant influence over the fates of Israel and Ukraine.
5. The swelling of one pocket and not the other implies a disparity in the level of influence or support being offered or received.

Conclusion (C):
The image critiques the uneven support or favoritism of the United States towards Israel over Ukraine, highlighting a disparity in the influence or aid being given to these countries.

Reasoning Steps:
(VP1, CP1 -> IC1): Uncle Sam represents the United States, indicating the central role of American influence.
(VP2, CP3 -> IC2): The men kneeling and holding signs symbolize Israel and Ukraine seeking support from the United States.
(VP3, VP4, CP2, CP4 -> IC3): The act of blowing air into Uncle Sam's pockets indicates attempts by Israel and Ukraine to influence or gain favor from the United States.
(IC2, CP5 -> IC4): The swelling of one pocket and not the other implies a disparity in the influence or support being offered or received.
(IC1, IC3, IC4 -> C): The image critiques the uneven support or favoritism of the United States towards Israel over Ukraine, highlighting a disparity in the influence or aid being given to these countries.

Category:
Political Commentary"""
]

MODEL_CLASSES = {
  'blip2': 'BLIP2',
  'instructblip': 'InstructBLIP',
  'kosmos2': 'KOSMOS2',
  'llavanext': 'LLaVANeXT',
  'llava': 'LLaVa',
  'cogvlm': 'CogVLM',
  'qwenvlchat': 'QwenVLChat',
  'minigpt_4': 'MiniGPT_4',
  'openflamingo': 'Openflamingo',
  'ofa': 'ofa',
  'idefics2': 'idefics2',
  'otter': 'Otter',
  'unifiedio2': 'unifiedio2',
  'llama3': 'LLaMA3',
  'llama2': 'LLaMA2',
  'mistral': 'Mistral',
  'zephyr': 'Zephyr',
}

DATA = {
  'cartoon': './dataset/cartoon_storage.json',
  'social': './dataset/social_storage.json',
  'ads': './dataset/ads_storage.json',
  'new_cartoon': './dataset/new_cartoon_storage.json',
}

def load_model(model_name):
  """Return the model callable for `model_name`, trying image2con first, then text2con."""
  model_name = model_name.lower()
  if model_name in MODEL_CLASSES:
    model_class = MODEL_CLASSES[model_name]
    try:
      module = __import__(f"visarg.models.image2con.{model_class}", fromlist=[model_name])
    except ImportError:
      # No image-to-conclusion implementation; fall back to the text-only one.
      module = __import__(f"visarg.models.text2con.{model_class}", fromlist=[model_name])
    return getattr(module, model_name)
  raise ValueError(f"No model found for {model_name}")

def load_prompt_func(model_name):
  """Return the `prompt` function of the module that implements `model_name`."""
  model_name = model_name.lower()
  if model_name in MODEL_CLASSES:
    model_class = MODEL_CLASSES[model_name]
    try:
      module = __import__(f"visarg.models.image2con.{model_class}", fromlist=["prompt"])
    except ImportError:
      # No image-to-conclusion implementation; fall back to the text-only one.
      module = __import__(f"visarg.models.text2con.{model_class}", fromlist=["prompt"])
    return getattr(module, "prompt")
  raise ValueError(f"No model's prompt function found for {model_name}")
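
# --- Illustrative sketch (not used by the pipeline) ---------------------------
# load_model and load_prompt_func above share one pattern: try the image2con
# package first, then fall back to text2con. The helper below is a hypothetical
# generalisation of that pattern built on importlib.import_module; its name and
# signature are assumptions made for illustration, not part of visarg.
def _import_first_available(module_names, attr):
  """Return `attr` from the first module in `module_names` that can be imported."""
  import importlib
  last_error = None
  for name in module_names:
    try:
      module = importlib.import_module(name)
    except ImportError as err:
      # Remember the failure and try the next candidate.
      last_error = err
      continue
    return getattr(module, attr)
  raise ImportError(f"None of {list(module_names)} could be imported") from last_error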

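def conclusion_deduction(args):
  """Deduce a conclusion for every item in the selected dataset.

  Writes the ground-truth conclusions (`*_gts.json`) and the model outputs
  (`*_res.json`) to `args.out_path`.
  """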
  data_path = DATA[args.data]
  if args.with_summarized:
    with open('./dataset/gpt_summarize.json') as f:
      summarized = json.load(f)
  with open(data_path) as data_file:
    data = json.load(data_file)
    
    print(f'== Load Model : {args.deduct_model} ==')
    deduct = load_model(args.deduct_model)
    prompt_func = load_prompt_func(args.deduct_model)

    gts = {}
    res = {}
    
    for key in tqdm(data.keys(), desc='Deduct conclusion', leave=True):
      try:
        if args.with_summarized:
          for target in summarized[args.data][data[key]["image_url"]]:
            data[key]['annotation'] = data[key]['annotation'].replace(target["target"], target["gpt3.5"])
        (vps, cps, c, rs, ctg) = parse_annotation(data[key]['annotation'])
      except Exception as e:
        # Skip items whose annotation cannot be prepared or parsed.
        print(f"Skipping {key}: {e}")
        continue
      image_path = os.path.join('./dataset/images', data[key]['image_url'].split('/')[-1])
      
      # Prepend a section header to each part and number the premises from 1.
      vps = ["Visual Premises (VP):"] + [str(idx+1) + ". " + vp for idx, vp in enumerate(vps)]
      cps = ["Commonsense Premises (CP):"] + [str(idx+1) + ". " + cp for idx, cp in enumerate(cps)]
      rs = ["Reasoning Step:"] + rs
      
      need_base = args.prompt_style < 1
      prompt = ''
      if args.prompting:
        if args.prompting == 1:
          # Image and vps -> Conclusion
          description = PROMPT_STYLES[args.prompt_style]["vp_desc"]
          if args.few_shot:
            description += "\nHere are some examples"
            for i in range(args.few_shot):
              ex_vps, ex_cps, ex_c, ex_rs, ex_ctg = parse_annotation(EXAMPLES[i])
              
              ex_vps = ["Visual Premises (VP):"] + [str(idx+1) + ". " + vp for idx, vp in enumerate(ex_vps)]
              ex_cps = ["Commonsense Premises (CP):"] + [str(idx+1) + ". " + cp for idx, cp in enumerate(ex_cps)]
              ex_rs = ["Reasoning Step:"] + ex_rs

              description += f'\n\n(Example {i+1})\n' + '\n'.join(ex_vps) + '\n\n'
              description += f'Conclusion: {ex_c}\n'
              
          informations = '\n\n(Task Part)\n' + '\n'.join(vps) + '\n\n'
          prefix = description + informations
          
        elif args.prompting == 2:
          # Image and cps -> Conclusion
          description = PROMPT_STYLES[args.prompt_style]["cp_desc"]
          if args.few_shot:
            description += "\nHere are some examples"
            for i in range(args.few_shot):
              ex_vps, ex_cps, ex_c, ex_rs, ex_ctg = parse_annotation(EXAMPLES[i])
              
              ex_vps = ["Visual Premises (VP):"] + [str(idx+1) + ". " + vp for idx, vp in enumerate(ex_vps)]
              ex_cps = ["Commonsense Premises (CP):"] + [str(idx+1) + ". " + cp for idx, cp in enumerate(ex_cps)]
              ex_rs = ["Reasoning Step:"] + ex_rs

              description += f'\n\n(Example {i+1})\n' + '\n'.join(ex_cps) + '\n\n'
              description += f'Conclusion: {ex_c}\n'
          informations = '\n\n(Task Part)\n' + '\n'.join(cps) + '\n\n'
          prefix = description + informations
        
        elif args.prompting == 3:
          # Image and vps, cps -> Conclusion
          description = PROMPT_STYLES[args.prompt_style]["vp_desc"] + PROMPT_STYLES[args.prompt_style]["cp_desc"]
          if args.few_shot:
            description += "\nHere are some examples"
            for i in range(args.few_shot):
              ex_vps, ex_cps, ex_c, ex_rs, ex_ctg = parse_annotation(EXAMPLES[i])
              
              ex_vps = ["Visual Premises (VP):"] + [str(idx+1) + ". " + vp for idx, vp in enumerate(ex_vps)]
              ex_cps = ["Commonsense Premises (CP):"] + [str(idx+1) + ". " + cp for idx, cp in enumerate(ex_cps)]
              ex_rs = ["Reasoning Step:"] + ex_rs

              description += f'\n\n(Example {i+1})\n' + '\n'.join(ex_vps) + '\n\n' + '\n'.join(ex_cps) + '\n\n'
              description += f'Conclusion: {ex_c}\n'
          informations = '\n\n(Task Part)\n' + '\n'.join(vps) + '\n\n' + '\n'.join(cps) + '\n\n'
          prefix = description + informations
          
        elif args.prompting == 4:
          # Image and vps, cps, reasoning steps -> Conclusion
          description = PROMPT_STYLES[args.prompt_style]["vp_desc"] + PROMPT_STYLES[args.prompt_style]["cp_desc"] + PROMPT_STYLES[args.prompt_style]["rs_desc"]
          if args.few_shot:
            description += "\nHere are some examples"
            for i in range(args.few_shot):
              ex_vps, ex_cps, ex_c, ex_rs, ex_ctg = parse_annotation(EXAMPLES[i])
              
              ex_vps = ["Visual Premises (VP):"] + [str(idx+1) + ". " + vp for idx, vp in enumerate(ex_vps)]
              ex_cps = ["Commonsense Premises (CP):"] + [str(idx+1) + ". " + cp for idx, cp in enumerate(ex_cps)]
              ex_rs = ["Reasoning Step:"] + ex_rs

              description += f'\n\n(Example {i+1})\n' + '\n'.join(ex_vps) + '\n\n' + '\n'.join(ex_cps) + '\n\n'
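              ex_rslines = '\n'.join(ex_rs)
              ex_rslines = ex_rslines.split("-> C):")[0] + "-> C)" + '\n\n'
              # Add the example's masked reasoning steps (previously computed but never used).
              description += ex_rslines
              description += f'Conclusion: {ex_c}\n'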
          informations = '\n\n(Task Part)\n' + '\n'.join(vps) + '\n\n' + '\n'.join(cps) + '\n\n'
          rs_lines = '\n'.join(rs)
          rs_lines = rs_lines.split("-> C):")[0] + "-> C)" + "\n\n"
          prefix = description + informations + rs_lines
        
        if need_base:
          postfix = None
        else:
          postfix = PROMPT_STYLES[args.prompt_style]["task_desc"]

        prompt = prompt_func(prefix=prefix, postfix=postfix, need_base=need_base)
        
        if args.few_shot_with_images:
          image_path = [f'./dataset/images/example{i+1}.png' for i in range(args.few_shot)] + [image_path]
          

      try:
        if not args.text2con:
          if args.without_img:
            image_path=None
          
          if args.prompting:
            con = deduct(image_path, prompt)
          else:
            if args.few_shot:
              example_cs = []
              for i in range(args.few_shot):
                ex_vps, ex_cps, ex_c, ex_rs, ex_ctg = parse_annotation(EXAMPLES[i])
                example_cs.append(f'(Example {i+1})\nConclusion: {ex_c}')
              prompt = prompt_func(prefix='\n'.join(example_cs) + '\n(Task Part)')
            else:
              prompt = None
            con = deduct(image_path, prompt)
        else:
          con = deduct(prompt)
      except Exception as e:
        # Skip this item so an undefined or stale `con` is never recorded.
        print(f"Exception {e} on {key}")
        continue
        
      # Post-processing: keep only the first sentence of the model output.
      try:
        con = nltk.tokenize.sent_tokenize(con)[0]
      except Exception:
        # Tokenization failed or returned no sentence; keep the raw output.
        pass
      gts[key] = [c]
      res[key] = [con]
    
    if args.without_img:
      detail = '_wo_img'
    else:
      detail = ''
    
    if args.data == "new_cartoon":
      args.data = "newctn"
    detail += f"_{args.data}"
    
    if args.with_summarized:
      detail += '_summarized'
    
    if args.prompt_style:
      args.out_path = os.path.join(args.out_path, f'style{args.prompt_style}')
      
    if args.few_shot:
      if args.few_shot_with_images:
        args.out_path = os.path.join(args.out_path, f'fewshot{args.few_shot}_w_img')
      else:
        args.out_path = os.path.join(args.out_path, f'fewshot{args.few_shot}')
      
    
    os.makedirs(args.out_path, exist_ok=True)  
    
    gts_path = os.path.join(args.out_path, args.deduct_model.lower() + '_' + str(args.prompting) + detail + '_gts.json')
    with open(gts_path, 'w') as f:
      json.dump(gts, f)
    # res["prompt_example"] = prompt
    res_path = os.path.join(args.out_path, args.deduct_model.lower() + '_' + str(args.prompting) + detail + '_res.json')
    with open(res_path, 'w') as f:
      json.dump(res, f)
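

# --- Illustrative sketch (not used by the pipeline) ---------------------------
# With prompting == 4, conclusion_deduction cuts the joined reasoning steps at
# the first "-> C):" marker, so the prompt shows which premises lead to the
# conclusion without revealing its text. The hypothetical helper below isolates
# that step; its name is an assumption made for illustration.
def _mask_final_step(reasoning_steps):
  """Join the steps and drop everything after the first "-> C):" marker."""
  joined = '\n'.join(reasoning_steps)
  # As in the code above, "-> C)" is appended even if the marker is absent.
  return joined.split("-> C):")[0] + "-> C)" + "\n\n"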