import argparse
import json
import os


def index_train(train_dataset, overlap_mode):
    """Collect the evidence keys that occur in the training set.

    Only the first entry of each example's ``"positive_ctxs"`` list is
    considered. Examples whose ``"positive_ctxs"`` list is empty contribute
    nothing to the index and are skipped (previously they raised
    ``IndexError``).

    Args:
        train_dataset: Iterable of example dicts, each with a
            ``"positive_ctxs"`` list of context dicts.
        overlap_mode: ``"article"`` to key on the context's ``"title"``;
            any other value keys on its ``"passage_id"``.

    Returns:
        A dict whose keys are the unique titles / passage ids found. All
        values are ``None``; the dict is only used for membership tests
        and for counting unique keys.

    Raises:
        KeyError: If an example lacks ``"positive_ctxs"`` or its first
            context lacks the field selected by ``overlap_mode``.
    """
    # The mode is constant for the whole pass, so pick the field once.
    key_field = "title" if overlap_mode == "article" else "passage_id"

    index = {}
    for question_dict in train_dataset:
        positive_ctxs = question_dict["positive_ctxs"]
        if not positive_ctxs:
            # No positive evidence -> nothing to index for this example.
            continue
        index[positive_ctxs[0][key_field]] = None

    return index


def filter_dataset(train_index, overlap_mode, dev_dataset):
    """Return the dev examples whose evidence is not present in the train index.

    Only the first entry of each example's ``"positive_ctxs"`` list is
    checked. An example with an empty ``"positive_ctxs"`` list has no
    evidence that could overlap with the training set, so it is kept
    (previously it raised ``IndexError``).

    Args:
        train_index: Container supporting ``in`` that holds the train-set
            keys (titles or passage ids, matching ``overlap_mode``).
        overlap_mode: ``"article"`` to compare the context's ``"title"``;
            any other value compares its ``"passage_id"``.
        dev_dataset: Iterable of example dicts, each with a
            ``"positive_ctxs"`` list of context dicts.

    Returns:
        A new list with the non-overlapping examples, in their original
        order. The example dicts themselves are not copied.

    Raises:
        KeyError: If an example lacks ``"positive_ctxs"`` or its first
            context lacks the field selected by ``overlap_mode``.
    """
    # The mode is constant for the whole pass, so pick the field once.
    key_field = "title" if overlap_mode == "article" else "passage_id"

    new_dev = []
    for dev_example in dev_dataset:
        positive_ctxs = dev_example["positive_ctxs"]
        if positive_ctxs and positive_ctxs[0][key_field] in train_index:
            # Evidence also appears in train: drop this example.
            continue
        new_dev.append(dev_example)
    return new_dev


if __name__ == '__main__':
    # Command-line entry point: load the train and dev JSON files, drop every
    # dev example whose first positive context also appears in train, and
    # write the surviving examples next to --output_file as ".json" and ".csv".
    parser = argparse.ArgumentParser()

    parser.add_argument("--train_file", type=str, required=True)
    parser.add_argument("--dev_file", type=str, required=True)
    # Used as a path prefix: ".json" and ".csv" are appended to it.
    parser.add_argument("--output_file", type=str, required=True)
    parser.add_argument("--overlap_mode", type=str, choices=["article", "passage"], default="article")

    args = parser.parse_args()

    # Encoding is given explicitly so the script does not depend on the
    # platform's default locale encoding.
    print("Reading training data..")
    with open(args.train_file, "r", encoding="utf-8") as f:
        train_dataset = json.load(f)
    print("Reading development data..")
    with open(args.dev_file, "r", encoding="utf-8") as f:
        dev_dataset = json.load(f)

    overlap_mode = args.overlap_mode
    print(f"Indexing train by {overlap_mode}...")
    train_index = index_train(train_dataset, overlap_mode)
    # Drop the reference so the loaded training data can be reclaimed;
    # only the index is needed from here on.
    train_dataset = None
    print(f"Finished indexing. Total {len(train_index)} unique {overlap_mode}s.")

    print("Filtering dev set using the index...")
    new_dev_set = filter_dataset(train_index, overlap_mode, dev_dataset)
    old_size = len(dev_dataset)
    new_size = len(new_dev_set)
    # Guard against an empty dev set, which would otherwise divide by zero.
    kept_percent = new_size / old_size * 100 if old_size else 0.0
    print(f"After filtering, {new_size} examples were kept out of {old_size}, which are {kept_percent:.1f}%")

    with open(args.output_file + ".json", "w", encoding="utf-8") as f:
        json.dump(new_dev_set, f, indent=4)
    # NOTE: despite the ".csv" suffix, each line is "<question>\t<answers>"
    # (tab-separated, answers rendered with str(), no quoting or escaping).
    with open(args.output_file + ".csv", "w", encoding="utf-8") as f:
        for example in new_dev_set:
            f.write(f"{example['question']}\t{example['answers']}\n")