import json
from typing import Callable, List, Tuple
import sys
sys.path.append('..')
from multiprocessing import Pool
from functools import partial

import hydra

from config_lib.base_config import BaseConfig
from evaluation.common import BaseInstance, Statistics, SingleFileDataset
from evaluation.tasks.natural_questions_task import NaturalQuestionsTask

"""
This script outputs a list of Natural questions filenames. These are the filenames 
of instances that have annotations. 
We filter out some annotations in  natural_questions_task.NaturalQuestionsDataset._create_instance_from_document()
if they 
1. only have a long answer, but no short / boolean answer
2. have the long answer in a table
3. if they have a short or boolean answer, but no long answer
"""

def has_annotations(
        instance: BaseInstance
) -> bool:
    """Tell whether ``instance`` carries a non-empty ``free_text_answer``.

    :param instance: object exposing a ``free_text_answer`` attribute.
    :return: ``True`` if ``free_text_answer`` is truthy, ``False`` otherwise
        (e.g. ``None`` or an empty string / container).
    """
    answer = instance.free_text_answer
    return True if answer else False


def extraction_nodes_in_document_nodes(
        instance: BaseInstance
) -> bool:
    """Check that every extraction node of ``instance`` is a document node.

    Nodes are compared by their ``ix`` attribute. When at least one
    extraction node is missing from the document, the instance's
    ``example_id`` is printed so the offending instance can be traced.

    :param instance: object exposing ``document.nodes`` (iterable of nodes),
        ``extraction_nodes`` (iterable of node lists) and ``example_id``.
    :return: ``True`` if no extraction node is missing (also when there are
        no extraction nodes at all), ``False`` otherwise.
    """
    # A set gives constant-time membership tests; with a list every
    # extraction node would trigger a scan over all document nodes.
    # NOTE(review): assumes ``ix`` values are hashable (e.g. int) -- confirm.
    document_node_ixs = {node.ix for node in instance.document.nodes}

    all_present = all(
        node.ix in document_node_ixs
        for nodes_list in instance.extraction_nodes
        for node in nodes_list
    )
    if not all_present:
        print(instance.example_id)
    return all_present


def load_config(
        location: str
) -> BaseConfig:
    """Compose the Hydra config named ``config`` for the natural_questions task.

    :param location: value passed to Hydra as the ``location=...`` override.
    :return: the object returned by ``hydra.compose`` with the ``location``
        and ``task=natural_questions`` overrides applied.
    """
    overrides = [f'location={location}', 'task=natural_questions']

    # Hydra is initialised only for the duration of the compose call.
    with hydra.initialize(version_base=None, config_path='../config', job_name=''):
        return hydra.compose('config', overrides=overrides)

def filter_instances_single_process(
        filters: List[Callable],
        instances: List[BaseInstance] | SingleFileDataset,
        start_end: Tuple[int, int] = None
) -> List[str]:
    """Return the example ids of the instances that pass every filter.

    :param filters: callables taking an instance and returning a truthy or
        falsy value; an instance is kept only if all of them are truthy.
    :param instances: indexable collection supporting ``len`` and integer
        indexing; items must expose ``example_id``.
    :param start_end: optional ``(start, end)`` pair selecting the half-open
        index range to process; ``None`` means the whole collection.
    :return: ``str(example_id)`` of every kept instance, in index order.
    """
    first, stop = (0, len(instances)) if start_end is None else start_end

    index_range = range(first, stop)
    print(index_range)

    kept_ids = []
    for position in index_range:
        candidate = instances[position]
        # Every filter is evaluated (no short-circuit), so filters that
        # print diagnostics still run after an earlier filter rejected.
        verdicts = [predicate(candidate) for predicate in filters]
        if not all(verdicts):
            continue
        kept_ids.append(str(candidate.example_id))
    return kept_ids


def filter_instances_multi_process(
        filters: List[Callable],
        instances: List[BaseInstance] | SingleFileDataset,
        n_processes: int
) -> List[str]:
    """Filter ``instances`` in parallel with filter_instances_single_process.

    The index range ``[0, len(instances))`` is cut into contiguous,
    near-equal ``(start, end)`` ranges, one per worker process. Each worker
    runs ``filter_instances_single_process`` on its range and the per-range
    results are concatenated in range order.

    :param filters: callables forwarded to ``filter_instances_single_process``.
    :param instances: collection supporting ``len``; it is forwarded to every
        worker together with ``filters``.
    :param n_processes: upper bound on the number of worker processes; fewer
        are started when there are fewer instances than processes.
    :return: example ids of all kept instances; an empty list when
        ``instances`` is empty (no pool is started in that case).
    """
    n_instances = len(instances)
    if n_instances == 0:
        # Nothing to filter: avoid spawning a pool for no work.
        return []

    # Never start more workers than there are instances; with plain floor
    # division every range but the last would be empty in that situation.
    n_workers = min(n_processes, n_instances)

    # Spread the remainder over the first ``extra`` ranges instead of
    # handing all of it to the last worker.
    base_size, extra = divmod(n_instances, n_workers)
    start_end = []
    start = 0
    for worker_ix in range(n_workers):
        end = start + base_size + (1 if worker_ix < extra else 0)
        start_end.append((start, end))
        start = end

    worker = partial(
        filter_instances_single_process,
        filters,
        instances
    )
    with Pool(n_workers) as pool:
        # Results are collected in the order of ``start_end``.
        results = pool.map(worker, start_end)

    return [
        example_id
        for range_result in results
        for example_id in range_result
    ]

def get_filenames(
        example_ids: List[str]
) -> List[str]:
    """Map each example id to its ``<example_id>.json`` filename.

    :param example_ids: ids to convert.
    :return: one filename per id, in the same order.
    """
    return [f'{example_id}.json' for example_id in example_ids]

def main(
        location: str,
        output_suffix: str = '_no_missing_extraction_nodes',
        n_processes: int = 64
) -> None:
    """Write the filtered filename list of each Natural Questions split.

    For each of the train, dev and test splits (in that order), the split's
    instances are filtered with ``extraction_nodes_in_document_nodes`` and
    the resulting ``<example_id>.json`` names are dumped as a JSON list to
    ``<datasets>/natural_questions/natural_questions_itg/<split><output_suffix>.json``.

    :param location: forwarded to ``load_config`` as the location override.
    :param output_suffix: text inserted between the split name and ``.json``
        in each output filename.
    :param n_processes: number of worker processes used for filtering.
    """
    config = load_config(location)
    stats = Statistics(
        'none',
        'natural_questions',
        'none',
        config
    )
    task = NaturalQuestionsTask(
        config,
        stats
    )

    output_dir = config.location.datasets / 'natural_questions' / 'natural_questions_itg'

    for split in ('train', 'dev', 'test'):
        # Each split's instances are fetched inside the loop, so they are
        # accessed one split at a time and in the original order.
        split_instances = getattr(task, f'{split}_instances')
        filtered_example_ids = filter_instances_multi_process(
            [extraction_nodes_in_document_nodes],
            split_instances,
            n_processes
        )
        filenames = get_filenames(filtered_example_ids)

        out_filepath = output_dir / f'{split}{output_suffix}.json'
        with open(out_filepath, 'w') as out_file:
            json.dump(filenames, out_file, indent=None)

if __name__ == '__main__':

    import argparse

    # The only command-line switch is an opt-in flag for remote debugging.
    cli = argparse.ArgumentParser()
    cli.add_argument('--remote_debug', action='store_true')
    cli_args = cli.parse_args()

    if cli_args.remote_debug:
        # Imported lazily so pydevd_pycharm is only required when debugging.
        import pydevd_pycharm
        pydevd_pycharm.settrace(
            '10.167.11.14',
            port=3851,
            stdoutToServer=True,
            stderrToServer=True
        )

    # Run with the 'shared' location and the default suffix / process count.
    main('shared')
