import argparse
import csv
import logging
import os
import pickle
import time
from multiprocessing.pool import Pool

from transformers import BertTokenizer

# Console handler that will be the root logger's only handler after this setup.
console = logging.StreamHandler()

# Configure the root logger: discard whatever handlers are already attached,
# route records through `console`, and emit INFO and above.
logger = logging.getLogger()
if logger.hasHandlers():
    del logger.handlers[:]
logger.addHandler(console)
logger.setLevel(logging.INFO)


def process_single_shard(args, rows, tokenizer, output_file):
    """Tokenize one shard of passages and pickle the result to ``output_file``.

    Args:
        args: Object exposing ``sequence_length``, used as the ``max_length``
            for truncation.
        rows: Iterable of 3-item rows unpacked as ``(idx, text, title)``.
        tokenizer: Object with an ``encode`` method; the title is passed as the
            first segment and the text as ``text_pair``.
        output_file: Path of the pickle file to write (opened in binary mode,
            overwriting any existing content).

    The pickle contains a list of ``(idx, encoded)`` tuples in input order,
    where ``encoded`` is whatever ``tokenizer.encode`` returned
    (presumably a list of token ids -- confirm against the tokenizer used).
    """

    def encode_passage(title, text):
        # Title first, passage text second; truncate but never pad.
        return tokenizer.encode(
            title,
            text_pair=text,
            add_special_tokens=True,
            max_length=args.sequence_length,
            pad_to_max_length=False,
            truncation=True,
        )

    encoded_passages = [
        (passage_id, encode_passage(title, text))
        for passage_id, text, title in rows
    ]

    with open(output_file, "wb") as sink:
        pickle.dump(encoded_passages, sink)


def prepare_params_for_shard(args, tokenizer, rows, num_shards, shard_idx):
    """Build the argument tuple for ``process_single_shard`` for one shard.

    ``rows`` is split into ``num_shards`` contiguous slices of
    ``len(rows) // num_shards`` rows each; the last shard additionally takes
    whatever remainder is left over.

    Args:
        args: Object exposing ``output_dir``; passed through unchanged.
        tokenizer: Passed through unchanged.
        rows: Sliceable sequence of all rows.
        num_shards: Total number of shards.
        shard_idx: Zero-based index of the shard to prepare.

    Returns:
        A tuple ``(args, shard_rows, tokenizer, output_file)`` where
        ``output_file`` is ``<args.output_dir>/tokenized_passages_<shard_idx>.pkl``.

    Raises:
        ValueError: If ``shard_idx`` is not in ``[0, num_shards)``.
    """
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if not 0 <= shard_idx < num_shards:
        raise ValueError(
            f"shard_idx must satisfy 0 <= shard_idx < num_shards, "
            f"got shard_idx={shard_idx}, num_shards={num_shards}"
        )

    # Integer floor division: exact for any size, unlike float division + int().
    shard_size = len(rows) // num_shards
    shard_start = shard_idx * shard_size
    # The final shard is open-ended so it absorbs the remainder rows.
    is_last_shard = shard_idx == num_shards - 1
    shard_end = None if is_last_shard else shard_start + shard_size
    shard_rows = rows[shard_start:shard_end]

    output_file = os.path.join(args.output_dir, f"tokenized_passages_{shard_idx}.pkl")
    return args, shard_rows, tokenizer, output_file


def main(args):
    """Tokenize every passage of a TSV file and write the result as pickled shards.

    Reads ``args.ctx_file`` as tab-separated rows, keeps the rows that have
    exactly three columns and whose first column is not the literal ``"id"``
    (header line), splits them into shards and tokenizes each shard in its own
    worker process. One ``tokenized_passages_<i>.pkl`` file per shard is
    written into ``args.output_dir``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``tokenizer_name``, ``do_lower_case``, ``ctx_file``, ``output_dir``
            and ``num_processes_and_shards``. When the latter is ``None`` or
            ``0`` the machine's CPU count is used for both the number of
            shards and the number of worker processes.

    Raises:
        FileExistsError: If ``args.output_dir`` already exists.
    """
    # The CLI default for the shard count is None; resolve it up front so that
    # both the shard list and the pool size use the same concrete number
    # (range(None) would otherwise raise TypeError).
    num_shards = args.num_processes_and_shards or os.cpu_count() or 1

    tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)

    logger.info("reading data from file=%s", args.ctx_file)
    start_time = time.time()
    rows = []
    error_counter = 0
    # newline="" lets the csv module do its own line-ending handling, as the
    # csv documentation requires for files passed to csv.reader.
    with open(args.ctx_file, newline="") as tsvfile:
        reader = csv.reader(tsvfile, delimiter="\t")
        for row in reader:
            if len(row) != 3:
                logger.info(f"Detected a row of len {len(row)}, so skipping. Row: {row}")
                error_counter += 1
                continue
            # Skip the header line (first column is the literal "id").
            if row[0] == "id":
                continue
            rows.append(row)

    logger.info(f"Done. Took {(time.time() - start_time) / 60:.1f} minutes")
    logger.info(f"Overall there were {error_counter} invalid passages")

    # Creating the output directory. If exists, crash.
    os.makedirs(args.output_dir, exist_ok=False)

    # One parameter tuple per shard; each is unpacked by starmap into
    # process_single_shard(args, shard_rows, tokenizer, output_file).
    params = [
        prepare_params_for_shard(args, tokenizer, rows, num_shards, shard_idx)
        for shard_idx in range(num_shards)
    ]
    with Pool(num_shards) as p:
        p.starmap(process_single_shard, params)


if __name__ == "__main__":
    # Script entry point: declare the command-line options, parse them and
    # hand the resulting namespace to main().
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--tokenizer_name", type=str, required=True)
    arg_parser.add_argument("--do_lower_case", action="store_true")
    arg_parser.add_argument(
        "--ctx_file",
        type=str,
        required=True,
        help="Path to passages set .tsv file",
    )
    # A single value controls both the number of output shards and the number
    # of worker processes.
    arg_parser.add_argument(
        "--num_processes_and_shards",
        "-num",
        type=int,
        default=None,
    )
    arg_parser.add_argument("--output_dir", type=str, required=True)
    arg_parser.add_argument("--sequence_length", type=int, default=512)

    args = arg_parser.parse_args()
    main(args)
