import argparse
from distutils.util import strtobool

from visarg.tasks.conclusion_deduction import conclusion_deduction

from visarg.tasks.new_premise_retrieval import parse_premises, premise_retrieval
from visarg.tasks.visual_grounding import visual_grounding
from visarg.tasks.ic_deduction import intermediate_conclusion_deduction


def _str2bool(value: str) -> bool:
  """Parse a command-line truth value, mirroring ``distutils.util.strtobool``.

  Case-insensitively accepts y/yes/t/true/on/1 as True and
  n/no/f/false/off/0 as False.

  Raises:
    ValueError: for any other string; argparse reports it as a usage error.
  """
  lowered = value.lower()
  if lowered in ('y', 'yes', 't', 'true', 'on', '1'):
    return True
  if lowered in ('n', 'no', 'f', 'false', 'off', '0'):
    return False
  raise ValueError('invalid truth value %r' % (value,))


def parse_args() -> argparse.Namespace:
  """Define the runner's command-line options and parse ``sys.argv``.

  ``--task`` selects which task the script's ``__main__`` block dispatches
  to. The other options are grouped below by the task they are labelled for.

  Returns:
    argparse.Namespace: the parsed options.
  """
  parser = argparse.ArgumentParser()

  parser.add_argument('--task', type=int, default=1,
                      help='task to run: 1, 2, 3 or 40')
  # parser.add_argument('--data_path', type=str, default='./dataset/cartoon_storage.json')
  parser.add_argument('--data', type=str, default='cartoon')

  # Grounding (task1)
  parser.add_argument('--ground_model', type=str, default='llava')
  parser.add_argument('--task1_result_path', type=str, default='./results/task1/')

  # Premise retrieval
  parser.add_argument('--retrieval_target', type=str, default="vp")
  parser.add_argument('--level', type=str, default="easy")
  parser.add_argument('--vlm', type=str, default='llava')
  # distutils (and its strtobool) was deprecated by PEP 632 and removed in
  # Python 3.12, so the flag is parsed with a local equivalent instead.
  parser.add_argument('--add_conclusion', type=_str2bool, default=False)
  parser.add_argument('--split_id', type=int, default=0)

  # Conclusion deduction options (task3)
  parser.add_argument('--deduct_model', type=str, default='llava')
  parser.add_argument('--text2con', action="store_true")
  parser.add_argument('--out_path', type=str, default='./results/task3/')
  parser.add_argument('--prompting', type=int, default=0)
  parser.add_argument('--prompt_style', type=int, default=0)
  parser.add_argument('--without_img', action="store_true")
  parser.add_argument('--few_shot', type=int, default=0)
  parser.add_argument('--few_shot_with_images', action="store_true")

  parser.add_argument('--with_summarized', action="store_true")

  return parser.parse_args()

if __name__ == "__main__":
  args = parse_args()
  task = args.task

  # Exactly one task runs per invocation, so the branches form one chain.
  if task == 1:
    visual_grounding(args)
  elif task == 2:
    # NOTE: only premise parsing runs here. The retrieval call below is
    # disabled, so the premise-retrieval options are not passed along.
    parse_premises()
    # premise_retrieval(target=args.retrieval_target, level=args.level, vlm=args.vlm, add_conclusion=args.add_conclusion, split_id=args.split_id)
  elif task == 3:
    conclusion_deduction(args)
  elif task == 40:
    intermediate_conclusion_deduction(args)
  else:
    # Fail loudly instead of exiting successfully without doing anything.
    raise SystemExit(f"Unknown --task {task}; expected one of 1, 2, 3, 40.")
