import argparse
import csv
import json
import multiprocessing
import time
import html
from multiprocessing import Pool


def split_document_to_chunks(text, title, chunk_size, pad_last_passage=True):
    """Split *text* into consecutive passages of ``chunk_size`` whitespace tokens.

    Args:
        text: Document body; it is tokenized with ``str.split()`` so any run of
            whitespace acts as a single separator.
        title: Value paired with every produced passage.
        chunk_size: Number of tokens per passage.
        pad_last_passage: When true and the document has more than
            ``chunk_size`` tokens, the final passage is filled up to
            ``chunk_size`` tokens by wrapping around to the start of the
            document (the DPR convention).

    Returns:
        A list of ``(passage_text, title)`` tuples in document order; empty
        when *text* contains no tokens.
    """
    words = text.split()
    num_words = len(words)

    # Only multi-passage documents are padded; a short document stays a single
    # (possibly short) passage.
    needs_padding = pad_last_passage and num_words > chunk_size
    source = words + words[:chunk_size] if needs_padding else words

    # Passage starts are computed from the unpadded length so the padding only
    # ever extends the last passage and never creates an extra one.
    return [
        (" ".join(source[start:start + chunk_size]), title)
        for start in range(0, num_words, chunk_size)
    ]


def process_chunk(args, documents, rank):
    """Convert a slice of JSON-lines Wikipedia documents into passages.

    Args:
        args: Parsed CLI namespace; reads ``include_sections``,
            ``include_lists``, ``bullet_str`` and ``chunk_size``.
        documents: Raw JSON lines. Each must decode to an object with a
            ``"wikipedia_title"`` key and a ``"text"`` list of strings.
        rank: Worker identifier; only the worker with ``rank == 0`` prints
            progress messages.

    Returns:
        A list of ``(passage_text, title)`` tuples for all documents in the
        slice, in input order. Documents with an empty ``"text"`` list
        contribute nothing.
    """
    all_psgs = []
    start_time = time.time()
    for doc_idx, line in enumerate(documents):
        if rank == 0 and doc_idx > 0 and doc_idx % 2000 == 0:
            print(f"Finished processing {doc_idx}/{len(documents)} documents in thread #0. "
                  f"Took {(time.time() - start_time)/60:.1f} minutes")
        doc = json.loads(line)
        title = doc["wikipedia_title"]
        paragraphs = doc["text"]
        if not paragraphs:
            continue

        sents = []
        # A separate index name keeps the outer document counter (used for
        # progress reporting above) from being shadowed.
        for sent_idx, raw_sent in enumerate(paragraphs):
            # Strip first so the marker filters and the bullet replacement
            # below all inspect the same text; otherwise a line with leading
            # whitespace would bypass the filters.
            sent = raw_sent.strip()
            if not args.include_sections and sent.startswith("Section::::"):
                continue
            if not args.include_lists and sent.startswith("BULLET::::"):
                continue

            # Only the very first entry of "text" is HTML-unescaped
            # (presumably the title line — TODO confirm against the dump format).
            if sent_idx == 0:
                sent = html.unescape(sent)
            if sent.startswith("BULLET::::"):
                sent = sent.replace("BULLET::::", args.bullet_str)
            sents.append(sent)

        all_text = " ".join(sents)
        all_psgs.extend(split_document_to_chunks(all_text, title, args.chunk_size, pad_last_passage=True))

    return all_psgs


def main(args):
    """Convert a JSON-lines Wikipedia dump into a tab-separated passage file.

    Reads ``args.input_file`` (one JSON document per line) and splits the lines
    into contiguous slices, one per worker. Each slice is converted with
    ``process_chunk`` in a process pool. All passages are written to
    ``args.output_file`` as TSV with the header ``id, text, title``; ids are
    numbered from 1 in input order.

    Raises:
        ValueError: If ``args.include_sections`` is set (not supported) or the
            effective number of worker processes is less than 1.
    """
    # Explicit check instead of ``assert`` so the guard survives ``python -O``.
    if args.include_sections:
        raise ValueError("--include_sections is not supported")

    # The input is JSON (UTF-8 for interchange); don't depend on the locale's
    # default encoding.
    with open(args.input_file, "r", encoding="utf-8") as f:
        all_docs = f.readlines()
    print(f"Finished reading {len(all_docs)} documents")

    num_threads = args.num_threads if args.num_threads is not None else multiprocessing.cpu_count()
    if num_threads < 1:
        raise ValueError(f"num_threads must be >= 1, got {num_threads}")
    print(f"Processing with {num_threads} processes")
    docs_in_thread = len(all_docs) // num_threads + 1

    # One contiguous slice per worker. The slice ordinal (0, 1, 2, ...) is
    # passed as the rank, so rank 0 is the first slice.
    params = [
        (args, all_docs[start:start + docs_in_thread], rank)
        for rank, start in enumerate(range(0, len(all_docs), docs_in_thread))
    ]
    with Pool(num_threads) as p:
        results = p.starmap(process_chunk, params)
    print(f"Finished processing, now writing to {args.output_file}")

    # The csv module requires newline="" so it controls line endings itself
    # (otherwise "\r\n" becomes "\r\r\n" on Windows).
    with open(args.output_file, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(["id", "text", "title"])
        curr_index = 1
        for res in results:
            for psg_txt, title in res:
                writer.writerow([curr_index, psg_txt, title])
                curr_index += 1
    print(f"Done! Wrote {curr_index-1} passages")


def positive_int(value):
    """argparse ``type=`` callback that accepts only integers >= 1."""
    number = int(value)
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value!r}")
    return number


def build_arg_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Split a JSON-lines Wikipedia dump into fixed-length "
                    "passages and write them as a TSV file.")

    parser.add_argument("--input_file", required=True, type=str,
                        help="JSON-lines input; each line needs 'wikipedia_title' "
                             "and a 'text' list of strings.")
    parser.add_argument("--output_file", required=True, type=str,
                        help="Destination TSV file with columns id, text, title.")
    # Positive values only: chunk_size=0 crashes range() and a negative value
    # would silently produce no passages.
    parser.add_argument("--chunk_size", type=positive_int, default=100,
                        help="Whitespace-separated words per passage (default: 100).")
    parser.add_argument("--num_threads", type=positive_int, default=None,
                        help="Number of worker processes (default: multiprocessing.cpu_count()).")
    parser.add_argument("--include_sections", action="store_true",
                        help="Keep 'Section::::' heading lines (not supported; main rejects it).")
    parser.add_argument("--include_lists", action="store_true",
                        help="Keep 'BULLET::::' list-item lines instead of dropping them.")
    parser.add_argument("--bullet_str", type=str, default="",
                        help="Replacement for the 'BULLET::::' marker of kept list items "
                             "(default: empty string).")
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    main(args)