import argparse
import json
import os
import sys

sys.path.insert(1, os.path.join(sys.path[0], '..'))
from utils import *
from dataset_collection.scrape_utils import *


# Directory holding the images downloaded from the RIS evidence; files are
# named after the evidence record's position in the evidence list.
ORIGINAL_IMG_DIR = 'dataset/manipulated_original_img/'


def first_image_url(raw_urls):
    """Return the first URL stored in an evidence record's ``'image url'`` field.

    The slicing assumes the field is a bracketed, ``;``-separated list of
    quoted URLs such as ``"['http://a';'http://b']"``. That format is inferred
    from the parsing itself; confirm it against the evidence file. Missing or
    non-string values (e.g. ``None``) yield ``''``.
    """
    if not isinstance(raw_urls, str):
        return ''
    # Strip the outer brackets, keep the first entry, strip its quotes.
    return raw_urls[1:-1].split(';')[0][1:-1]


def evidence_indices_by_date(evidence, image_paths):
    """Group evidence positions by image path, each group sorted by ``'date'``.

    Args:
        evidence: List of evidence records (dicts with an ``'image path'`` key
            and, optionally, a ``'date'`` key).
        image_paths: Image paths to keep; records for other images are ignored.

    Returns:
        Dict mapping each image path that has evidence to the positions of its
        records in ``evidence``, earliest date first. Dates are compared as
        stored (chronological only if e.g. ISO-formatted strings -- confirm).
        Records without a date go last; ties keep their original order.
    """
    wanted = set(image_paths)
    grouped = {}
    for idx, ev in enumerate(evidence):
        if ev['image path'] in wanted:
            grouped.setdefault(ev['image path'], []).append(idx)

    def date_key(idx):
        date = evidence[idx].get('date')
        # (False, date) sorts before (True, None): undated records go last and
        # None is never ordered against a real date.
        return (date is None, date)

    for indices in grouped.values():
        indices.sort(key=date_key)  # list.sort is stable
    return grouped


def identify_originals(evidence, image_paths, downloaded_files, img_dir=ORIGINAL_IMG_DIR):
    """Map each image to its earliest-dated evidence image that was downloaded.

    Args:
        evidence: List of evidence records (see ``evidence_indices_by_date``).
        image_paths: Images predicted as manipulated, in output order.
        downloaded_files: Set of file names present in ``img_dir``.
        img_dir: Directory prefix used to build the returned paths.

    Returns:
        Dict ``{image path: img_dir + '<evidence index>.png'}``. Images with no
        downloaded candidate are omitted.
    """
    grouped = evidence_indices_by_date(evidence, image_paths)
    originals = {}
    for img_path in image_paths:
        for idx in grouped.get(img_path, ()):
            filename = str(idx) + '.png'
            if filename in downloaded_files:
                originals[img_path] = img_dir + filename
                # Stop at the earliest available candidate for THIS image only;
                # the original ``break`` exited the whole image loop.
                break
    return originals


def main():
    """Download RIS evidence images (optional) and write the manipulated->original map."""
    parser = argparse.ArgumentParser(description='Heuristic to identify the unaltered, original image among the RIS results.')
    parser.add_argument('--json_path', type=str, default='baseline/manipulation_detection_test.json',
                        help='Path to the manipulation detection predictions')
    parser.add_argument('--download_image', type=int, default=0,
                        help='If True, download the images retrieved by RIS for images predicted as manipulated.')
    parser.add_argument('--map_json_path', type=str, default='baseline/map_manipulated_original.json',
                        help='Path to the file that maps manipulated images to their identified original version.')
    args = parser.parse_args()

    os.makedirs(ORIGINAL_IMG_DIR, exist_ok=True)

    # Load data. Both the download and the identification steps need it, so it
    # must not be loaded only inside the --download_image branch.
    # NOTE(review): load_json presumably returns a list of dicts -- confirm in utils.
    evidence = load_json('dataset/retrieval_results/evidence.json')
    manipulated_paths = [im['image path'] for im in load_json(args.json_path)
                         if im['manipulation detection'] == 'manipulated']

    if args.download_image:
        wanted = set(manipulated_paths)
        for idx, ev in enumerate(evidence):
            if ev['image path'] not in wanted:
                continue
            # Take for each evidence the first version of the image.
            url = first_image_url(ev.get('image url'))
            if not url:
                continue
            # The identification step looks for '<idx>.png', so download_image
            # presumably appends the extension -- confirm in scrape_utils.
            download_image(url, ORIGINAL_IMG_DIR + str(idx))

    # Identify originals with the publication date heuristic.
    downloaded_files = set(os.listdir(ORIGINAL_IMG_DIR))
    dict_original_image = identify_originals(evidence, manipulated_paths, downloaded_files)

    # Save results.
    with open(args.map_json_path, 'w') as file:
        json.dump(dict_original_image, file, indent=4)


if __name__ == '__main__':
    main()
