from google.cloud import vision
import os 
from tqdm import tqdm
import time
import sys 
sys.path.insert(1, os.path.join(sys.path[0], '..'))
from utils import *
from dataset_collection.scrape_utils import *
import argparse



def detect_web(path,how_many_queries=30):
    """
    Detects web annotations given an image, using the Google Cloud Vision
    web detection API.

    Params:
        path (str): path of the image file to send to the API.
        how_many_queries (int): forwarded as ``max_results`` to the
            web detection request.

    Returns:
        tuple: ``(page_urls, matching_image_urls)`` where ``page_urls`` is the
        list of URLs of the pages with matching images, and
        ``matching_image_urls`` maps a page URL to the list of URLs of its
        full matching images followed by its partial matching images.
        Pages without any matching image URL get no entry in the dict.

    Raises:
        Exception: if the API response carries an error message.
    """
    client = vision.ImageAnnotatorClient()

    with open(path, "rb") as image_file:
        content = image_file.read()

    image = vision.Image(content=content)

    response = client.web_detection(image=image, max_results=how_many_queries)

    # Check for an API error before looking at the payload, so that a failed
    # request is not reported as "no matching images".
    if response.error.message:
        raise Exception(
            "{}\nFor more info on error messages, check: "
            "https://cloud.google.com/apis/design/errors".format(response.error.message)
        )

    annotations = response.web_detection

    page_urls = []
    matching_image_urls = {}

    if annotations.pages_with_matching_images:
        print(
            "\n{} Pages with matching images found:".format(
                len(annotations.pages_with_matching_images)
            )
        )

        for page in annotations.pages_with_matching_images:
            page_urls.append(page.url)
            # Keep both kinds of matches: previously the partial matches
            # overwrote the full matches whenever a page had both.
            matched = [match.url for match in page.full_matching_images]
            matched.extend(match.url for match in page.partial_matching_images)
            if matched:
                matching_image_urls[page.url] = matched
    else:
        print('No matching images found for ' + path)

    return page_urls, matching_image_urls



if __name__=='__main__':
    # Pipeline: (1) optionally query Google reverse image search for every image
    # and filter the returned URLs, (2) optionally scrape the selected URLs with
    # trafilatura, (3) merge the scraped content with the dataset and store the
    # resulting evidence as a JSON file.

    # json is needed for the final dump; import it explicitly instead of
    # relying on one of the star imports to expose it.
    import json

    parser = argparse.ArgumentParser(description='Collect evidence using Google Reverse Image Search.')
    parser.add_argument('--collect_google', type=int, default=0, 
                        help='Whether to collect evidence URLs with the google API. If 0, it is assumed that a file containing URLs already exists.')
    parser.add_argument('--google_vision_api_key', type=str,  default= " ", #Provide your own key here as default value
                        help='Your key to access the Google Vision services, including the web detection API.')  
    parser.add_argument('--image_path', type=str, default='dataset/processed_img/',
                        help='The folder where the images are stored.') 
    parser.add_argument('--output_dir_urls', type=str, default='dataset/retrieval_results/ris.txt',
                        help='The txt file to store the RIS urls.') 
    parser.add_argument('--scrape_with_trafilatura', type=int, default=0, 
                        help='Whether to scrape the evidence URLs with trafilatura. If 0, it is assumed that a file containing the scraped webpages already exists.') 
    parser.add_argument('--trafilatura_path', type=str, default='dataset/retrieval_results/trafilatura_data.json',
                        help='The json file to store the scraped trafilatura  content as a json file.')
    parser.add_argument('--json_path', type=str, default='dataset/retrieval_results/evidence.json',
                        help='The json file to store the text evidence as a json file.')
    parser.add_argument('--max_results', type=int, default=50,
                        help='The maximum number of web-pages to collect with the web detection API.') 
    parser.add_argument('--sleep', type=int, default=3,
                        help='The waiting time between two web detection API calls') 

    args = parser.parse_args()
    # NOTE(review): the argument value is used as the *name* of an environment
    # variable here, and `key` is not used anywhere else in the visible script
    # -- confirm how the Vision client is expected to receive its credentials.
    key = os.getenv(args.google_vision_api_key)

    #Create directories if they do not exist yet
    # makedirs also works when 'dataset/' itself is missing, where
    # os.listdir('dataset/') would raise.
    os.makedirs('dataset/retrieval_results/', exist_ok=True)

    #Google RIS
    if args.collect_google:
        with open(args.output_dir_urls,'a',encoding='utf-8') as output_file:
            for file_name in tqdm(os.listdir(args.image_path)):
                # os.path.join keeps working when --image_path is given
                # without a trailing separator.
                image_file = os.path.join(args.image_path, file_name)
                urls, images = detect_web(image_file,args.max_results)
                output_file.write(f"{image_file} | {';'.join(urls)} | {str(images)} \n")
                # Wait between two API calls
                time.sleep(args.sleep)

        #Apply filtering to the URLs to remove content produced by FC organizations, and content that is not scrapable
        selected_data = get_filtered_retrieval_results(args.output_dir_urls)
    else:
        selected_data = load_json('dataset/retrieval_results/evidence_urls.json')

    urls = [d['raw url'] for d in selected_data]
    images = [d['image urls'] for d in selected_data]

    if args.scrape_with_trafilatura:
        #Collect results with Trafilatura
        save_every = 50  #Only store in json file every 50 evidence
        batch = []
        for url, image_urls in tqdm(zip(urls, images), total=len(urls)):
            batch.append(extract_info_trafilatura(url, image_urls))
            if len(batch) >= save_every:
                # presumably save_result appends the batch to the file -- verify in utils
                save_result(batch,args.trafilatura_path)
                batch = []
        # Flush the last, incomplete batch: without this, everything scraped
        # after the last periodic save was never written to disk.
        if batch:
            save_result(batch,args.trafilatura_path)

    #Merge all results and convert them to a list of dictionaries
    evidence_trafilatura = load_json(args.trafilatura_path)
    dataset = load_json('dataset/train.json') + load_json('dataset/val.json')  + load_json('dataset/test.json')
    evidence = merge_data(evidence_trafilatura, selected_data, dataset).fillna('').to_dict(orient='records')
    # Save the list of dictionaries as a JSON file
    with open(args.json_path, 'w', encoding='utf-8') as file:
        json.dump(evidence, file, indent=4)