import os
import json
import lmdb
import logging
import argparse
import time
from multiprocessing import Process, Queue

from unilm.tokenization_utils import get_tokenizer
from unilm.utils import write_to_lmdb, serialize_str

# Configure logging once at import time: INFO level, with a timestamp,
# the level name and the logger name prefixed to every message.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)


def get_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: with the attributes ``input_file``, ``output_dir``,
        ``model_type``, ``tokenizer_name``, ``do_lower_case`` and ``processes``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_file', help='read data from',
                        type=str, default='/home/lidong1/data/share/cnn_dailymail/')
    parser.add_argument('--output_dir', help='save data to',
                        type=str, default='/mnt/data/cnn_dailymail/')
    parser.add_argument("--model_type", default=None, type=str,
                        help="Model type")
    parser.add_argument("--tokenizer_name", default=None, type=str,
                        help="tokenizer name")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")
    # os.cpu_count() returns None when the CPU count cannot be determined;
    # treat that as a single CPU instead of failing on ``None - 1``.
    cpu_count = os.cpu_count() or 1
    # Leave one CPU free by default, but always use at least one process.
    default_process_count = max(1, cpu_count - 1)
    parser.add_argument("--processes", type=int, default=default_process_count,
                        help="Number of processes to use (default %(default)s)")
    return parser.parse_args()


def worker(input_queue, output_queue, process_id, tokenizer_name, model_type, do_lower_case):
    """Consume tasks from ``input_queue`` and push their results to ``output_queue``.

    Each task is a ``(func, args)`` pair; it is executed as
    ``func(*args, tokenizer=tokenizer)``. The loop ends when the string
    ``'STOP'`` is read from ``input_queue``.

    Args:
        input_queue: queue of ``(func, args)`` tasks, terminated by ``'STOP'``.
        output_queue: queue receiving each task result, followed by one
            final ``'STOP-#<process_id>'`` message.
        process_id: identifier used in log messages and in the final message.
        tokenizer_name: forwarded to ``get_tokenizer``.
        model_type: forwarded to ``get_tokenizer``.
        do_lower_case: forwarded to ``get_tokenizer``.
    """
    logger.info("Worker #%s start" % str(process_id))

    try:
        tokenizer = get_tokenizer(tokenizer_name, do_lower_case=do_lower_case, model_type=model_type)
        for func, args in iter(input_queue.get, 'STOP'):
            try:
                result = func(*args, tokenizer=tokenizer)
            except Exception:
                # One bad record must not take the whole worker down:
                # log the traceback and move on to the next task.
                logger.exception("Worker #%s failed on a task, skipping it", process_id)
                continue
            output_queue.put(result)
    finally:
        # Always announce termination, even after an error; the consumer of
        # output_queue counts these messages to know when it can stop.
        output_queue.put('STOP-#%s' % str(process_id))


def dump_result(output_dir, output_queue, num_process):
    """Drain ``output_queue`` into an LMDB environment opened at ``output_dir``.

    Writes ``__start__`` (serialized 0) first, then every received result
    under the key ``table_<index>`` (index starting at 0), and finally
    ``__size__`` holding the number of results written.

    Args:
        output_dir: path passed to ``lmdb.open``.
        output_queue: queue yielding result strings, plus one
            ``'STOP-#<id>'`` message per finished worker.
        num_process: number of ``STOP`` messages to wait for before finishing.
    """
    logger.info("Dump process is starting !")
    cc = 0

    logger.info(" output_queue = {}".format(str(output_queue)))

    # Prefix of the termination message; what follows it is the worker id.
    stop_prefix = 'STOP-#'

    db = lmdb.open(output_dir, readonly=False, map_async=True)
    try:
        write_to_lmdb(db, b"__start__", serialize_str(0))

        while True:
            r = output_queue.get()
            if r.startswith('STOP'):
                num_process -= 1
                # Strip the whole prefix (including '#') so that the log
                # reads "Worker #3", not "Worker ##3".
                process_id = r[len(stop_prefix):]
                logger.info("Worker #%s is done, there are %d processes still working !" % (process_id, num_process))
                if num_process <= 0:
                    break
            else:
                write_to_lmdb(
                    db, b"table_%d" % cc,
                    serialize_str(r))

                # Periodically log progress together with a sample record.
                if cc % 10000 == 0:
                    logger.info("process data {}!".format(cc))
                    logger.info("After preprocessing %s" % json.dumps(r, indent=2))

                cc += 1

        logger.info("process data {}!".format(cc))
        write_to_lmdb(db, b"__size__", serialize_str(cc))
        db.sync()
    finally:
        # Release the environment even if a write or queue read failed.
        db.close()


def do_tokenize_on_json(json_data, tokenizer):
    """Tokenize, in place, every ``"text"`` / ``"originalText"`` string in a decoded JSON value.

    Lists and dicts are walked recursively. A string stored under one of the
    two keys is replaced by ``tokenizer.tokenize(value)``. Every other scalar
    (numbers, booleans, ``None``, strings under other keys or inside lists)
    is left untouched.

    Args:
        json_data: a value as produced by ``json.loads``; mutated in place.
        tokenizer: object exposing ``tokenize(str)``.

    Returns:
        None.
    """
    if isinstance(json_data, list):
        for value in json_data:
            do_tokenize_on_json(value, tokenizer)
    elif isinstance(json_data, dict):
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        for key, value in json_data.items():
            if key in ("text", "originalText") and isinstance(value, str):
                json_data[key] = tokenizer.tokenize(value)
            elif isinstance(value, (dict, list)):
                do_tokenize_on_json(value, tokenizer)
    # Any other type is a JSON scalar: nothing to tokenize or descend into.


def process_data(json_text, tokenizer):
    """Decode one JSON document, tokenize its text fields and re-encode it.

    Args:
        json_text: a JSON document as a string.
        tokenizer: tokenizer handed on to ``do_tokenize_on_json``.

    Returns:
        str: the document serialized back to compact (non-indented) JSON.
    """
    parsed = json.loads(json_text)
    # The helper rewrites ``parsed`` in place; it has no return value.
    do_tokenize_on_json(parsed, tokenizer)
    serialized = json.dumps(parsed, indent=None)
    return serialized


def main():
    """Run the pipeline: read lines, tokenize them in workers, dump results to LMDB.

    Starts ``args.processes`` worker processes plus one writer process, feeds
    every non-blank line of ``args.input_file`` to the workers as a
    ``process_data`` task, then sends one ``'STOP'`` per worker and waits for
    all child processes to finish.
    """
    # Raised by Queue.put(timeout=...) when the queue stays full.
    from queue import Full

    args = get_args()

    task_queue = Queue(10000)
    done_queue = Queue()
    # With zero workers nobody would ever tell the writer to stop.
    num_process = max(1, args.processes)
    input_file = args.input_file

    logger.info("Warmup workers")

    workers = []
    for i in range(num_process):
        proc = Process(target=worker, args=(task_queue, done_queue, i, args.tokenizer_name, args.model_type, args.do_lower_case))
        proc.start()
        workers.append(proc)

    print("Start writer !")
    writer = Process(target=dump_result, args=(args.output_dir, done_queue, num_process))
    writer.start()

    def add_task(task_data):
        """Enqueue one line, retrying until it fits so that no line is dropped."""
        task = (process_data, (task_data,))
        while True:
            try:
                task_queue.put(task, timeout=10)
                return
            except Full:
                logger.info("Workers are too slow !")

    try:
        with open(input_file, mode="r", encoding="utf-8") as reader:
            for _line in reader:
                line = _line.strip()
                if not line:
                    # A blank line is not a JSON document; skip it.
                    continue
                add_task(line)
    finally:
        # Send the sentinels even if reading failed, so the workers (and in
        # turn the writer) terminate instead of waiting forever.
        for _ in range(num_process):
            task_queue.put('STOP')

    # Wait for the children explicitly; the writer drains done_queue, so the
    # workers can flush their results and exit.
    for proc in workers:
        proc.join()
    writer.join()


# Run the pipeline only when this file is executed as a script. The guard
# also keeps main() from running again if multiprocessing re-imports this
# module in a child process (as the "spawn" start method does).
if __name__ == '__main__':
    main()
