"""Dump pre-training data."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import json
import argparse
from tqdm import tqdm
import random
from unilm.dataset_utils import BinaryDataset

# Configure logging once at import time: every record carries a timestamp,
# its level, the emitting logger's name and the message.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
)
# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)


def create_masked_lm(
        args, context_tokens, selected_header_tokens_list, selected_value_tokens_list,
        tokenizer, header_index_list, value_index_list):
    """Build one padded, randomly masked feature and wrap it in ``InputFeatures``.

    The body appends header tokens to ``tokens``, builds fixed-size index/mask
    tables for value cells and column cells, picks random positions to mask,
    pads every list to its configured length and returns an ``InputFeatures``.

    NOTE(review): the body does not match the signature and cannot run as
    written:
      * Of the declared parameters only ``tokenizer`` is read; ``args``,
        ``context_tokens``, ``selected_header_tokens_list``,
        ``selected_value_tokens_list``, ``header_index_list`` and
        ``value_index_list`` are never used.
      * ``val_indexs``, ``val_index_mask``, ``match_labels``, ``segment_ids``
        and ``is_random_next`` are assigned inside the function, which makes
        them locals, yet they are read before any assignment; the first
        ``is None`` check therefore raises ``UnboundLocalError``.
      * ``selected_header_tokens``, ``tokens``, ``max_seq_len``,
        ``max_tokens_per_cell``, ``express_token``,
        ``max_match_value_per_seq``, ``max_column_per_seq``, ``print_info``,
        ``donot_mask``, ``num_max_mask``, ``mask_prob``, ``MASK_TOKEN``,
        ``choose_raw_index``, ``unique_id``, ``example_index`` and
        ``InputFeatures`` are free names that the visible part of this file
        never defines.
      This looks like a partially refactored signature -- TODO reconcile the
      parameters with the names the body uses.

    Args:
        tokenizer: Must provide ``convert_tokens_to_ids(tokens)`` and an
            ``ids_to_tokens`` container indexable by integer ids.

    Returns:
        The ``InputFeatures`` object built from the padded lists.
    """
    # NOTE(review): these names are locals (assigned below) that were never
    # bound, so the `is None` checks raise UnboundLocalError.
    if val_indexs is None:
        val_indexs = []
    if val_index_mask is None:
        val_index_mask = []
    col_indexs = []
    col_index_mask = []
    if match_labels is None:
        match_labels = []

    # Every label supplied so far is real; padding labels added later get 0.
    match_label_mask = [1] * len(match_labels)

    for h_tokens in selected_header_tokens:
        # Skip a header whose tokens plus one trailing separator would not fit.
        if len(tokens) + 1 + len(h_tokens) > max_seq_len:
            continue
        indexs = []
        for tk in h_tokens:
            # Header tokens beyond max_tokens_per_cell are dropped entirely
            # (neither indexed nor appended to `tokens`).
            if len(indexs) < max_tokens_per_cell:
                indexs.append(len(tokens))
                tokens.append(tk)
        # A separator token follows every kept header.
        tokens.append(express_token)
        index_mask = [1] * len(indexs)
        # Pad the per-cell index list with position 0 / mask 0.
        while len(indexs) < max_tokens_per_cell:
            indexs.append(0)
            index_mask.append(0)
        col_indexs.append(indexs)
        col_index_mask.append(index_mask)

    # Pad the value table to max_match_value_per_seq rows.
    # NOTE(review): padding rows get an index mask of 1 (not 0), unlike the
    # per-cell padding above; only match_label_mask marks them as padding --
    # confirm this is intended.
    while len(val_indexs) < max_match_value_per_seq:
        val_indexs.append([0] * max_tokens_per_cell)
        val_index_mask.append([1] * max_tokens_per_cell)
        match_labels.append(0)
        match_label_mask.append(0)

    # Pad the column table to max_column_per_seq rows (same mask-of-1 remark).
    while len(col_indexs) < max_column_per_seq:
        col_indexs.append([0] * max_tokens_per_cell)
        col_index_mask.append([1] * max_tokens_per_cell)

    if print_info:
        logger.info("--------------------------------------------------------------------")
        for col_label, value_index in zip(match_labels, val_indexs):
            # Rows whose indices are all 0 are treated as padding and skipped.
            if max(value_index) > 0:
                value_tokens = []
                for index in value_index:
                    # Index 0 is the padding value and is not printed.
                    if index > 0:
                        value_tokens.append(tokens[index])
                logger.info("Use value [{}] to match col {}".format(" ".join(value_tokens), col_label))

        for column_id, column_index in enumerate(col_indexs):
            if max(column_index) > 0:
                column_tokens = []
                for index in column_index:
                    if index > 0:
                        column_tokens.append(tokens[index])
                logger.info("Column {}: {}".format(column_id, " ".join(column_tokens)))

    # mask
    # Candidate positions are all tokens not listed in `donot_mask`.
    cand_pos = []
    # NOTE(review): special_pos is filled but never read afterwards.
    special_pos = set()
    for i, tk in enumerate(tokens):
        if tk not in donot_mask:
            cand_pos.append(i)
        else:
            special_pos.add(i)
    random.shuffle(cand_pos)

    # Mask a `mask_prob` fraction of the candidates, capped at num_max_mask.
    n_pred = min(num_max_mask, max(0, int(round(len(cand_pos) * mask_prob))))

    masked_pos = cand_pos[:n_pred]
    # Record the original tokens before they are overwritten below.
    masked_tokens = [tokens[_] for _ in masked_pos]
    masked_ids = tokenizer.convert_tokens_to_ids(masked_tokens)
    masked_weights = [1] * len(masked_tokens)

    for mask_pos in masked_pos:
        if random.random() < 0.8:    # 80%
            tokens[mask_pos] = MASK_TOKEN
        # Half of the remaining 20%: replace with a random vocabulary token.
        # The other half keeps the original token unchanged.
        elif random.random() < 0.5:    # 10%
            token_id = random.randint(0, len(tokenizer.ids_to_tokens) - 1)
            tokens[mask_pos] = tokenizer.ids_to_tokens[token_id]

    # Pad the masking targets with zeros (weight 0) up to num_max_mask.
    while len(masked_weights) < num_max_mask:
        masked_pos.append(0)
        masked_ids.append(0)
        masked_weights.append(0)

    assert len(masked_pos) == num_max_mask
    assert len(masked_ids) == num_max_mask
    assert len(masked_weights) == num_max_mask

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    # Tokens beyond the existing segment ids are assigned segment 1.
    segment_ids = segment_ids + [1] * (len(tokens) - len(segment_ids))
    input_mask = [1] * len(input_ids)
    # Zero-pad up to max_seq_len; a longer sequence trips the asserts below.
    while len(input_ids) < max_seq_len:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)

    assert len(input_ids) == max_seq_len
    assert len(input_mask) == max_seq_len
    assert len(segment_ids) == max_seq_len

    # The masks are 1 only when the corresponding label was actually provided.
    is_next_mask = 1 if is_random_next is not None else 0
    if is_random_next is None:
        is_random_next = 0
    is_match_mask = 1 if choose_raw_index is not None else 0

    if print_info:
        logger.info("unique_id: %s" % unique_id)
        logger.info("is_random_next: %d" % is_random_next)
        logger.info("tokens: %s" % " ".join(tokens))
        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        logger.info("masked_tokens: %s" % " ".join([str(x) for x in masked_tokens]))
        logger.info("masked_pos: %s" % " ".join([str(x) for x in masked_pos]))
        logger.info("masked_ids: %s" % " ".join([str(x) for x in masked_ids]))
        logger.info("masked_weights: %s" % " ".join([str(x) for x in masked_weights]))
        logger.info("match_labels: %s" % " ".join([str(x) for x in match_labels]))
        logger.info("match_label_mask: %s" % " ".join([str(x) for x in match_label_mask]))
        # logger.info("val_indexs: %s" % " ".join([str(x) for x in val_indexs]))
        # logger.info("val_index_mask: %s" % " ".join([str(x) for x in val_index_mask]))
        # logger.info("col_indexs: %s" % " ".join([str(x) for x in col_indexs]))
        # logger.info("col_index_mask: %s" % " ".join([str(x) for x in col_index_mask]))
        logger.info("is_next_mask: %s" % str(is_next_mask))
        logger.info("is_match_mask: %s" % str(is_match_mask))
        logger.info("--------------------------------------------------------------------")

    return InputFeatures(
        unique_id=unique_id, example_index=example_index, row_id=choose_raw_index,
        tokens=tokens, input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids,
        is_random_next=is_random_next, masked_pos=masked_pos, masked_ids=masked_ids,
        masked_weights=masked_weights, match_labels=match_labels, val_indexs=val_indexs,
        val_index_mask=val_index_mask, col_indexs=col_indexs, col_index_mask=col_index_mask,
        match_label_mask=match_label_mask, is_next_mask=is_next_mask, is_match_mask=is_match_mask,
    )


def select_cell_index_list(cells, num_max_cells, shuffle=False):
    """Return at most ``num_max_cells`` positions into ``cells``.

    Args:
        cells: Sized collection; only its length is used.
        num_max_cells: Upper bound on how many positions are returned.
        shuffle: When true, the positions are shuffled with ``random.shuffle``
            before the cut, so a random subset is kept instead of the first
            ``num_max_cells`` positions.

    Returns:
        A new list of integer positions in ``range(len(cells))``.
    """
    total = len(cells)
    positions = list(range(total))
    if shuffle:
        random.shuffle(positions)
    # Nothing to trim when every cell fits within the limit.
    if total <= num_max_cells:
        return positions
    return positions[:num_max_cells]


def convert_example_to_features(args, example, random_example, tokenizer, express_token, empty_token):
    """Build one masked-LM feature per selected table row of ``example``.

    Args:
        args: Namespace-like object. Reads ``sample_all_rows``,
            ``max_column_per_seq`` and ``max_value_per_seq``; the object is
            also forwarded to ``create_masked_lm``. Make sure the argument
            parser declares these attributes.
        example: Mapping with a ``"table"`` entry (holding ``"rows"`` and
            ``"columns"``) and a ``"questions"`` entry.
        random_example: Not used by this function; kept for interface
            compatibility.
        tokenizer: Forwarded to ``create_masked_lm``.
        express_token: Not used by this function; kept for interface
            compatibility.
        empty_token: Not used by this function; kept for interface
            compatibility.

    Returns:
        A list with one ``create_masked_lm`` result per selected row. The
        list is empty when the table has no rows.
    """
    table = example["table"]
    rows = table["rows"]
    columns = table["columns"]
    context_text_list = example["questions"]

    num_rows = len(rows)
    if args.sample_all_rows:
        select_row_indexs = list(range(num_rows))
    elif num_rows > 0:
        select_row_indexs = [random.randint(0, num_rows - 1)]
    else:
        # random.randint(0, -1) raises ValueError; a table without rows
        # simply yields no features, as it already does with sample_all_rows.
        select_row_indexs = []

    # Columns are capped at max_column_per_seq and kept in table order.
    selected_header_cells = select_cell_index_list(columns, args.max_column_per_seq, shuffle=False)
    selected_header_tokens_list = [columns[index] for index in selected_header_cells]

    # The questions are used in random order.
    context_tokens = list(context_text_list)
    random.shuffle(context_tokens)

    example_features = []
    for row_index in select_row_indexs:
        # Pick a random subset of the kept columns whose values are used.
        select_row_cells = list(selected_header_cells)
        random.shuffle(select_row_cells)
        select_row_cells = select_row_cells[:args.max_value_per_seq]
        select_row_tokens_list = [rows[row_index][index] for index in select_row_cells]

        # The original call passed no arguments at all. The mapping below
        # follows the parameter names of create_masked_lm's signature.
        feature = create_masked_lm(
            args=args,
            context_tokens=context_tokens,
            selected_header_tokens_list=selected_header_tokens_list,
            selected_value_tokens_list=select_row_tokens_list,
            tokenizer=tokenizer,
            header_index_list=selected_header_cells,
            value_index_list=select_row_cells,
        )
        example_features.append(feature)
    return example_features


def get_args():
    """Parse the command-line options of the feature dumping script.

    Besides the sequence-shaping and masking options, this declares every
    attribute that other functions in this file read from ``args``
    (``output_dir``, ``write_buffer_dir``, ``start_checkpoint``,
    ``permutate_mode``, ``separate_mask``, ``sample_all_rows`` and
    ``max_value_per_seq``) so that those reads do not raise
    ``AttributeError``.

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--max_seq_length", default=256, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--max_context_length", default=128, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--max_column_per_seq", default=20, type=int,
                        help="The maximum number of table columns kept per sequence.")
    parser.add_argument("--max_tokens_per_cell", default=32, type=int,
                        help="The maximum number of tokens kept per table cell.")
    parser.add_argument("--max_match_value_per_seq", default=20, type=int,
                        help="The maximum number of cell values matched against columns per sequence.")
    parser.add_argument("--max_value_per_seq", default=20, type=int,
                        help="The maximum number of cell values sampled from a row per sequence.")
    parser.add_argument("--num_max_mask", default=76, type=int,
                        help="The maximum number of masked tokens per sequence.")
    parser.add_argument("--mask_prob", default=0.15, type=float,
                        help="The fraction of maskable tokens selected for masking.")
    parser.add_argument("--min_question_length", default=0, type=int,
                        help="The minimum number of tokens for the question.")

    parser.add_argument("--context2table", action='store_true', help="Context 2 tables loss")
    parser.add_argument("--value2col", action='store_true', help="value match column name loss")
    parser.add_argument("--combine_mode", action='store_true', help="[CLS] query [SEP] values [SEP] headers [SEP]")
    parser.add_argument("--pretraining_mode", action='store_true', help="Create features for pre-training.")
    parser.add_argument("--sample_all_rows", action='store_true',
                        help="Create one feature per table row instead of one randomly chosen row.")
    parser.add_argument("--permutate_mode", action='store_true',
                        help="Also store position_ids and token_type_ids in every feature.")
    parser.add_argument("--separate_mask", action='store_true',
                        help="With --permutate_mode, store pseudo_ids and span_ids instead of token_type_ids.")

    # Using multiple processes to create data
    parser.add_argument("--generation_id", default=0, type=int,
                        help="Index of this worker; it handles the feature files whose index modulo "
                             "--generation_size equals this value.")
    parser.add_argument("--generation_size", default=0, type=int,
                        help="Number of workers sharing the feature files; 0 lets this worker handle all files.")
    parser.add_argument("--random_seed", default=0, type=int,
                        help="Random seed.")

    # for create / read training data
    parser.add_argument("--output_dir", default=None, type=str,
                        help="Directory the feature files and feature format files are written to.")
    parser.add_argument("--write_buffer_dir", default=None, type=str,
                        help="Optional directory where feature files are written first, before being "
                             "copied to --output_dir.")
    parser.add_argument("--start_checkpoint", default=0, type=int,
                        help="Feature-file index to resume from; smaller indices are reported as passed.")
    parser.add_argument("--num_instances_per_file", default=256 * 5000, type=int,
                        help="num_steps_per_feature_file")
    parser.add_argument("--num_training_files", default=200, type=int,
                        help="Total number of training files. ")
    parser.add_argument("--data_type", default='h', type=str,
                        help="Data type for storing training features in python/array ")

    args = parser.parse_args()
    return args


def main():
    """Write the binary feature files and their format descriptions.

    For each file index in ``range(args.num_training_files)`` assigned to
    this worker, ``args.num_instances_per_file`` instances are fetched,
    encoded and written to ``features.<idx>.bin``. A matching
    ``feature_format.<idx>.json`` is saved next to it.

    NOTE(review): ``train_dataset``, ``reducer`` and ``feature_writer`` are
    not defined in the visible part of this file. They must be created (and
    ``feature_writer`` opened on ``output_feature_file``) before the inner
    loop -- TODO confirm where they come from.

    NOTE(review): this function reads ``args.start_checkpoint``,
    ``args.write_buffer_dir``, ``args.output_dir``, ``args.permutate_mode``
    and ``args.separate_mask``; make sure the argument parser declares them.
    """
    args = get_args()

    for feature_file_idx in tqdm(range(args.num_training_files)):
        # With sharding enabled, handle only the files assigned to this worker.
        if args.generation_size > 0 and feature_file_idx % args.generation_size != args.generation_id:
            continue
        # NOTE(review): this only logs; the iteration is not skipped. Confirm
        # whether a `continue` is intended here.
        if args.start_checkpoint > feature_file_idx:
            logger.info("Pass {} !".format(feature_file_idx))
        # Features go straight to output_dir, or first to the buffer directory.
        if args.write_buffer_dir is None:
            output_feature_file = \
                os.path.join(args.output_dir,
                             "features.{}.bin".format(feature_file_idx))
        else:
            output_feature_file = \
                os.path.join(args.write_buffer_dir,
                             "features.{}.bin".format(feature_file_idx))
            logger.info("Use buffer dir %s." % args.write_buffer_dir)

        output_feature_format_file = os.path.join(
            args.output_dir, "feature_format.{}.json".format(feature_file_idx))

        for _ in tqdm(range(args.num_instances_per_file)):
            # Index 0 is requested every time; presumably the dataset draws a
            # fresh sample on each access -- verify against the dataset class.
            instance = train_dataset.__getitem__(0)
            features = {
                "input_ids": instance[0],
                "num_tokens_a": instance[1],
                "num_tokens_b": instance[2],
                "masked_ids": instance[3],
                "masked_pos": instance[4],
                "masked_weights": instance[5],
                "is_next": instance[6],
                "task_idx": instance[7],
            }
            # Permutation mode adds extra fields taken from later tuple slots.
            if args.permutate_mode:
                features["position_ids"] = instance[8]
                if args.separate_mask:
                    features["pseudo_ids"] = instance[9]
                    features["span_ids"] = instance[10]
                else:
                    features["token_type_ids"] = instance[9]
            # Best effort: an instance that fails to encode or write is
            # logged and dropped, so a file may hold fewer instances.
            try:
                data = reducer.binary(features=features)
                feature_writer.write(data)
            except Exception as e:
                logger.info("Meet an exception {}".format(str(e)))

        feature_writer.close()

        if args.write_buffer_dir is not None:
            # Copy the finished buffer file into output_dir in one read/write.
            with open(output_feature_file, mode="rb") as reader:
                features_data = reader.read()

            logger.info("Read buffer feature {} from {}".format(len(features_data), output_feature_file))

            target_output_feature_file = \
                os.path.join(args.output_dir,
                             "features.{}.bin".format(feature_file_idx))

            with open(target_output_feature_file, mode="wb") as writer:
                writer.write(features_data)

            logger.info("Write feature {} to {}".format(len(features_data), target_output_feature_file))
            # Release the copy of the file contents before the next file.
            del features_data

            # Overwrite the buffer file with a short placeholder text.
            with open(output_feature_file, mode="w") as writer:
                writer.write("Blobfuse please do not crash")

        reducer.save_config(output_feature_format_file)

        # The first file additionally saves the format under an un-indexed name.
        if feature_file_idx == 0:
            output_feature_format_file = os.path.join(
                args.output_dir, "feature_format.json")
            reducer.save_config(output_feature_format_file)


# Run the dump only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
