"""Dump pre-training data."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import json
import argparse
from tqdm import tqdm
import random
import collections
from unilm.tokenization_utils import get_tokenizer
from unilm.dataset_utils import BinaryDataset

# Script-wide logging: timestamped, level-tagged records at INFO verbosity.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)


# Special tokens used when assembling input sequences.
CLS_TOKEN = "[CLS]"
SEP_TOKEN = "[SEP]"
MASK_TOKEN = "[MASK]"
# Tokens that must never be picked as masked-LM targets.
donot_mask = {CLS_TOKEN, SEP_TOKEN}


class TableColum(object):
    """One table column: a header name and the list of its cell values.

    ``values`` must contain strings, because ``__repr__`` joins them directly.
    """

    def __init__(self, name, values):
        self.name, self.values = name, values

    def __len__(self):
        """Return the number of cell values in the column."""
        return len(self.values)

    def __repr__(self):
        """Render the column as ``<name>: <v1> <v2> ...``."""
        joined_values = " ".join(self.values)
        return f"{self.name}: {joined_values}"

    def __str__(self):
        # The printable form is the same as the debugging form.
        return repr(self)


class TableExample(object):
    """A table plus its optional context texts.

    Attributes set by the constructor:
        table_id: identifier of the example.
        context_text_list: list of ``(context_class, context_text)`` pairs.
        columns: list of column objects (rendered with ``str`` in ``__repr__``).
        context_tokens_list / table_tokens: start empty; they are filled in
            later by the tokenization step.
    """

    def __init__(self, table_id, context_text_list, columns):
        self.table_id = table_id
        self.context_text_list = context_text_list
        self.columns = columns

        # Token-level views, populated after construction.
        self.context_tokens_list = []
        self.table_tokens = []

    def __repr__(self):
        """Multi-line dump: id, the context pairs (tuple-formatted), then one line per column."""
        rendered_contexts = " ".join(
            str((context_class, context_text))
            for context_class, context_text in self.context_text_list)
        pieces = [
            "table_id: %s" % self.table_id,
            "\n, context_text_list: %s" % rendered_contexts,
        ]
        pieces.extend('\n' + str(column) for column in self.columns)
        return "".join(pieces)

    def __str__(self):
        return repr(self)


class InputFeatures(object):
    """Plain container holding every field of one generated training instance.

    Each constructor argument is stored unchanged under an attribute of the
    same name; no validation or conversion is performed.
    """

    def __init__(self, unique_id, example_index, row_id, tokens,
                             input_ids, input_mask, segment_ids, is_random_next,
                             masked_pos, masked_ids, masked_weights, is_next_mask,
                             match_labels, val_indexs, val_index_mask,
                             col_indexs, col_index_mask, match_label_mask, is_match_mask):
        # Identification of the instance.
        self.unique_id, self.example_index, self.row_id = unique_id, example_index, row_id

        # Token sequence and the model inputs derived from it.
        self.tokens = tokens
        self.input_ids, self.input_mask, self.segment_ids = input_ids, input_mask, segment_ids

        # Masked-LM targets.
        self.masked_pos, self.masked_ids, self.masked_weights = masked_pos, masked_ids, masked_weights

        # Value/column index tables, their masks and the matching labels.
        self.val_indexs, self.val_index_mask = val_indexs, val_index_mask
        self.col_indexs, self.col_index_mask = col_indexs, col_index_mask
        self.match_labels, self.match_label_mask = match_labels, match_label_mask

        # Sequence-level label and the flags saying which objectives apply.
        self.is_random_next = is_random_next
        self.is_next_mask, self.is_match_mask = is_next_mask, is_match_mask


def create_masked_lm(
        tokens, selected_header_tokens, max_seq_len, max_column_per_seq,
        max_match_value_per_seq, max_tokens_per_cell, express_token, num_max_mask,
        mask_prob, tokenizer, unique_id, print_info=False, is_random_next=None, choose_raw_index=None,
        example_index=None, val_indexs=None, val_index_mask=None, match_labels=None, segment_ids=None):
    """Append column headers to ``tokens``, mask tokens at random and build padded model inputs.

    The caller's lists are modified in place: ``tokens`` receives the header
    tokens and the mask replacements, while ``val_indexs``, ``val_index_mask``
    and ``match_labels`` (when given) are padded to ``max_match_value_per_seq``.

    Args:
        tokens: token strings already in the sequence (e.g. ``[CLS]`` + context or values).
        selected_header_tokens: one token list per column header. A header is appended
            (followed by ``express_token``) only if the sequence stays within
            ``max_seq_len``; otherwise it is skipped.
        max_seq_len: length every returned per-token list is padded to.
        max_column_per_seq: number of rows ``col_indexs`` is padded to.
        max_match_value_per_seq: number of rows ``val_indexs`` is padded to.
        max_tokens_per_cell: width of each index row; longer cells are truncated.
        express_token: separator appended after each header; must equal ``SEP_TOKEN``.
        num_max_mask: upper bound on (and padded length of) the masked positions.
        mask_prob: fraction of maskable tokens selected for prediction.
        tokenizer: object providing ``convert_tokens_to_ids`` and ``ids_to_tokens``.
        unique_id: identifier, only used for logging.
        print_info: when true, log a human-readable dump of the instance.
        is_random_next: optional label; ``None`` yields ``is_next_mask == 0``.
        choose_raw_index: optional row index; ``None`` yields ``is_match_mask == 0``.
        example_index: accepted for interface compatibility, not used here.
        val_indexs, val_index_mask, match_labels: optional value-matching inputs
            prepared by the caller.
        segment_ids: optional segment ids for the tokens present on entry
            (defaults to all zeros).

    Returns:
        dict with the keys ``input_ids``, ``input_mask``, ``segment_ids``,
        ``is_random_next``, ``masked_pos``, ``masked_ids``, ``masked_weights``,
        ``match_labels``, ``val_indexs``, ``val_index_mask``, ``col_indexs``,
        ``col_index_mask``, ``match_label_mask``, ``is_next_mask``, ``is_match_mask``.
    """
    assert express_token == SEP_TOKEN

    val_indexs = [] if val_indexs is None else val_indexs
    val_index_mask = [] if val_index_mask is None else val_index_mask
    match_labels = [] if match_labels is None else match_labels
    # Segment ids are fixed *before* headers are appended, so that the headers
    # added below end up in segment 1.
    if segment_ids is None:
        segment_ids = [0] * len(tokens)
    match_label_mask = [1] * len(match_labels)

    # --- append headers and remember where each header's tokens live ---
    col_indexs, col_index_mask = [], []
    for header in selected_header_tokens:
        if len(tokens) + 1 + len(header) > max_seq_len:
            continue
        kept = header[:max(0, max_tokens_per_cell)]
        first = len(tokens)
        tokens.extend(kept)
        tokens.append(express_token)
        n_pad = max_tokens_per_cell - len(kept)
        col_indexs.append(list(range(first, first + len(kept))) + [0] * n_pad)
        col_index_mask.append([1] * len(kept) + [0] * n_pad)

    # --- pad the value/column tables; padding rows use index 0 with mask 1 ---
    for _ in range(max_match_value_per_seq - len(val_indexs)):
        val_indexs.append([0] * max_tokens_per_cell)
        val_index_mask.append([1] * max_tokens_per_cell)
        match_labels.append(0)
        match_label_mask.append(0)

    for _ in range(max_column_per_seq - len(col_indexs)):
        col_indexs.append([0] * max_tokens_per_cell)
        col_index_mask.append([1] * max_tokens_per_cell)

    if print_info:
        logger.info("--------------------------------------------------------------------")
        for col_label, value_index in zip(match_labels, val_indexs):
            if max(value_index) > 0:
                value_tokens = [tokens[index] for index in value_index if index > 0]
                logger.info("Use value [{}] to match col {}".format(" ".join(value_tokens), col_label))

        for column_id, column_index in enumerate(col_indexs):
            if max(column_index) > 0:
                column_tokens = [tokens[index] for index in column_index if index > 0]
                logger.info("Column {}: {}".format(column_id, " ".join(column_tokens)))

    # --- choose the positions to predict (special tokens are never chosen) ---
    cand_pos = [pos for pos, tk in enumerate(tokens) if tk not in donot_mask]
    random.shuffle(cand_pos)

    n_pred = min(num_max_mask, max(0, int(round(len(cand_pos) * mask_prob))))

    masked_pos = cand_pos[:n_pred]
    masked_tokens = [tokens[pos] for pos in masked_pos]
    masked_ids = tokenizer.convert_tokens_to_ids(masked_tokens)
    masked_weights = [1] * len(masked_tokens)

    for pos in masked_pos:
        if random.random() < 0.8:    # 80%: replace with the mask token
            tokens[pos] = MASK_TOKEN
            continue
        if random.random() < 0.5:    # 10%: replace with a random vocabulary token
            token_id = random.randint(0, len(tokenizer.ids_to_tokens) - 1)
            tokens[pos] = tokenizer.ids_to_tokens[token_id]
        # remaining 10%: keep the original token

    n_mask_pad = num_max_mask - len(masked_weights)
    masked_pos.extend([0] * n_mask_pad)
    masked_ids.extend([0] * n_mask_pad)
    masked_weights.extend([0] * n_mask_pad)

    assert len(masked_pos) == num_max_mask
    assert len(masked_ids) == num_max_mask
    assert len(masked_weights) == num_max_mask

    # --- per-token model inputs, padded to max_seq_len ---
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    segment_ids = segment_ids + [1] * (len(tokens) - len(segment_ids))
    input_mask = [1] * len(input_ids)
    n_seq_pad = max_seq_len - len(input_ids)
    input_ids.extend([0] * n_seq_pad)
    input_mask.extend([0] * n_seq_pad)
    segment_ids.extend([0] * n_seq_pad)

    assert len(input_ids) == max_seq_len
    assert len(input_mask) == max_seq_len
    assert len(segment_ids) == max_seq_len

    is_next_mask = 0 if is_random_next is None else 1
    if is_random_next is None:
        is_random_next = 0
    is_match_mask = 0 if choose_raw_index is None else 1

    if print_info:
        logger.info("unique_id: %s" % unique_id)
        logger.info("is_random_next: %d" % is_random_next)
        logger.info("tokens: %s" % " ".join(tokens))
        listed = (
            ("input_ids", input_ids),
            ("input_mask", input_mask),
            ("segment_ids", segment_ids),
            ("masked_tokens", masked_tokens),
            ("masked_pos", masked_pos),
            ("masked_ids", masked_ids),
            ("masked_weights", masked_weights),
            ("match_labels", match_labels),
            ("match_label_mask", match_label_mask),
        )
        for label, values in listed:
            logger.info("%s: %s" % (label, " ".join([str(x) for x in values])))
        logger.info("is_next_mask: %s" % str(is_next_mask))
        logger.info("is_match_mask: %s" % str(is_match_mask))
        logger.info("--------------------------------------------------------------------")

    # Features are returned as a plain dict; key order is kept stable.
    return dict(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        is_random_next=is_random_next,
        masked_pos=masked_pos,
        masked_ids=masked_ids,
        masked_weights=masked_weights,
        match_labels=match_labels,
        val_indexs=val_indexs,
        val_index_mask=val_index_mask,
        col_indexs=col_indexs,
        col_index_mask=col_index_mask,
        match_label_mask=match_label_mask,
        is_next_mask=is_next_mask,
        is_match_mask=is_match_mask,
    )


def tokenize_for_examples(examples, tokenizer, max_seq_len, max_context_length, empty_token, value2col, use_tqdm):
    """Tokenize the context and table of every example in place and log length statistics.

    For each example the longest context text (first one on ties) is tokenized,
    truncated to ``max_context_length`` and appended to
    ``example.context_tokens_list``. Every column contributes one
    ``(header_tokens, cell_tokens_list)`` pair to ``example.table_tokens``;
    cell values are tokenized only when ``value2col`` is positive. Headers or
    cells that tokenize to nothing are replaced by ``[empty_token]``.

    Args:
        examples: iterable of objects with ``context_text_list``, ``columns``,
            ``context_tokens_list`` and ``table_tokens`` attributes.
        tokenizer: object providing ``tokenize(text)`` returning a list of tokens.
        max_seq_len: only used to pre-seed the context-length histogram buckets.
        max_context_length: maximum number of context tokens kept.
        empty_token: placeholder token for empty headers/cells.
        value2col: cell values are tokenized only if this is > 0.
        use_tqdm: wrap the iteration in a tqdm progress bar.

    Returns:
        None. The examples are updated in place.
    """
    interval = 32
    # Histogram of context lengths in buckets of `interval` tokens, pre-seeded
    # so that empty buckets up to max_seq_len still show up in the log.
    stat = collections.OrderedDict(
        (bucket, 0) for bucket in range(max_seq_len // interval + 1))
    col_stat = collections.defaultdict(int)
    cell_len_count = collections.defaultdict(int)
    header_len_stat = collections.defaultdict(int)

    gen = tqdm(examples) if use_tqdm else examples
    for example in gen:
        # Longest context text wins; max() keeps the first one among ties.
        best_text = max((text for _, text in example.context_text_list), key=len, default=None)
        if best_text is not None:
            context_tokens = tokenizer.tokenize(best_text)[:max_context_length]
            example.context_tokens_list.append(context_tokens)
            bucket = len(context_tokens) // interval
            # A context may land past the pre-seeded buckets when
            # max_context_length exceeds max_seq_len; count it instead of
            # failing with KeyError.
            stat[bucket] = stat.get(bucket, 0) + 1

        for col in example.columns:
            header_tokens = tokenizer.tokenize(col.name) or [empty_token]
            header_len_stat[len(header_tokens)] += 1

            col_values = []
            if value2col > 0:
                col_values = [tokenizer.tokenize(value) or [empty_token] for value in col.values]
                # Keyed by the number of cells in the column.
                cell_len_count[len(col_values)] += 1
            example.table_tokens.append((header_tokens, col_values))

        col_stat[len(example.table_tokens)] += 1

    logger.info("Context token stat: ")
    logger.info(json.dumps(stat, indent=2))

    logger.info("Columns stat: ")
    logger.info(json.dumps(col_stat, indent=2))

    logger.info("Cell length stat: ")
    logger.info(json.dumps(cell_len_count, indent=2))

    # This histogram was collected but never reported before.
    logger.info("Header length stat: ")
    logger.info(json.dumps(header_len_stat, indent=2))


def read_table_and_question_examples_from_wiki_sql(input_file, pretraining_mode=False):
    """Load a JSON file of table/query pairs and build one ``TableExample`` per query.

    The file must hold a list of objects with a ``"table"`` entry (providing
    ``"header"`` and ``"rows"``) and a ``"queries"`` list. Cell values are
    converted to ``str``. Unless ``pretraining_mode`` is set, the query's
    ``"question"`` becomes the single context text of the example; otherwise
    the example has no context text.

    Args:
        input_file: path of the UTF-8 encoded JSON file.
        pretraining_mode: when true, examples are created without context text.

    Returns:
        list of ``TableExample`` whose ``table_id`` is the position in the list.
    """
    with open(input_file, "r", encoding='utf-8') as reader:
        all_data = json.load(reader)

    examples = []
    noname_cc = 0  # number of examples created without any context text
    for pairs in tqdm(all_data):
        raw_table = pairs["table"]

        for query in pairs["queries"]:
            # The columns are rebuilt for every query, so examples never share
            # column objects.
            columns = []
            expected_rows = None
            for col_index, col_name in enumerate(raw_table["header"]):
                raw_values = [row[col_index] for row in raw_table["rows"]]
                if expected_rows is None:
                    expected_rows = len(raw_values)
                assert len(raw_values) == expected_rows
                columns.append(TableColum(name=col_name, values=[str(val) for val in raw_values]))

            context_text_list = [] if pretraining_mode else [('question', query['question'])]
            if not context_text_list:
                noname_cc += 1

            examples.append(TableExample(
                table_id=len(examples),
                context_text_list=context_text_list,
                columns=columns))

    logger.info("Load {} examples from {}".format(len(examples), input_file))

    print("noname_cc = {}".format(noname_cc))

    return examples


def convert_examples_to_features(
        examples, tokenizer, max_seq_len, max_context_length,
        max_column_per_seq, max_match_value_per_seq, max_tokens_per_cell,
        first_epoch, mask_prob, num_max_mask, express_token, empty_token,
        use_tqdm=True, shuffle_headers=False, value2col=True,
        context2table=True, combine_mode=False):
    """Turn tokenized examples into one epoch of feature dicts.

    Depending on the flags, each example yields:
      * ``context2table`` (not combined): ``[CLS] context [SEP] headers...``
      * ``value2col`` (not combined): ``[CLS] values... headers...`` for one
        randomly chosen table row, with labels mapping each value to its column
      * ``combine_mode``: a single sequence ``[CLS] context [SEP] values... headers...``

    Examples whose value-matching part keeps fewer than two columns are
    dropped from the value2col / combined output. Tables without columns or
    without rows no longer raise: they simply produce no value-matching
    feature (the context2table feature, if requested, is still produced).

    Args:
        examples: ``TableExample`` objects that were already tokenized.
        tokenizer: forwarded to ``create_masked_lm``.
        max_seq_len, max_column_per_seq, max_match_value_per_seq,
        max_tokens_per_cell, mask_prob, num_max_mask, express_token:
            forwarded to ``create_masked_lm``; ``max_column_per_seq`` also
            caps the number of headers taken from a table.
        max_context_length, empty_token: only logged here.
        first_epoch: when true, log the settings and dump the first 20 examples.
        use_tqdm: wrap the iteration in a tqdm progress bar.
        shuffle_headers: shuffle ``example.table_tokens`` in place first.
        value2col: build value-to-column matching features.
        context2table: build context-to-table features.
        combine_mode: build the combined sequence; requires both flags above.

    Returns:
        list of feature dicts as returned by ``create_masked_lm``.
    """
    if first_epoch:
        logger.info(f"max_seq_len = {max_seq_len}")
        logger.info(f"max_context_length = {max_context_length}")
        logger.info(f"max_match_value_per_seq = {max_match_value_per_seq}")
        logger.info(f"max_column_per_seq = {max_column_per_seq}")
        logger.info(f"max_tokens_per_cell = {max_tokens_per_cell}")
        logger.info(f"mask_prob = {mask_prob}")
        logger.info(f"num_max_mask = {num_max_mask}")
        logger.info(f"express_token = {express_token}")
        logger.info(f"empty_token = {empty_token}")

    if combine_mode:
        assert value2col and context2table

    gen = tqdm(examples) if use_tqdm else examples
    epoch_features = []
    for (example_index, example) in enumerate(gen):
        assert isinstance(example, TableExample)

        # Guard against tables without columns or rows: there is no row to
        # sample, so value matching is disabled for this example instead of
        # crashing in random.randint / list indexing.
        num_rows = len(example.columns[0].values) if example.columns else 0
        use_values = bool(value2col) and num_rows > 0
        choose_raw_index = random.randint(0, num_rows - 1) if num_rows > 0 else None

        if shuffle_headers:
            random.shuffle(example.table_tokens)

        selected_header_tokens = []
        selected_value_tokens = []
        for table_header_tokens, table_value_tokens in example.table_tokens:
            selected_header_tokens.append(table_header_tokens)
            if use_values:
                selected_value_tokens.append(table_value_tokens[choose_raw_index])

            if len(selected_header_tokens) >= max_column_per_seq:
                break

        is_random_next = 0
        print_info = example_index < 20 and first_epoch

        if context2table:
            context_tokens = random.choice(example.context_tokens_list)

            tokens_c2t = [CLS_TOKEN] + context_tokens + [SEP_TOKEN]

            if not combine_mode:
                feature = create_masked_lm(
                    tokens_c2t, selected_header_tokens, max_seq_len, max_column_per_seq, max_match_value_per_seq,
                    max_tokens_per_cell, express_token, num_max_mask, mask_prob, tokenizer,
                    print_info=print_info, unique_id=str(example_index) + ':c2t',
                    is_random_next=is_random_next, example_index=example_index)
                epoch_features.append(feature)

        if not use_values:
            continue

        # Tokens already occupying the sequence before values are appended:
        # just [CLS], or the whole context part in combine mode.
        delta = len(tokens_c2t) if (combine_mode and context2table) else 1

        # Greedily keep (header, value) pairs while the full, untruncated
        # lengths (plus one separator each) still fit into max_seq_len.
        candidate_vals = []
        candidate_cols = []
        used_tokens = 0
        for col_id, (h_tokens, v_tokens) in enumerate(zip(selected_header_tokens, selected_value_tokens)):
            exp_len = (1 + len(h_tokens)) + (1 + len(v_tokens))
            if exp_len + used_tokens + delta < max_seq_len:
                used_tokens += exp_len
                candidate_cols.append((col_id, h_tokens))
                candidate_vals.append((col_id, v_tokens))

        # Matching a value against a single column is not a useful task.
        if len(candidate_cols) < 2:
            continue

        tokens_v2c = tokens_c2t if combine_mode else [CLS_TOKEN]
        assert len(candidate_cols) == len(candidate_vals)
        random.shuffle(candidate_vals)

        # The label of a value is the position of its column among the kept columns.
        label_of_col = {col_id: position for position, (col_id, _) in enumerate(candidate_cols)}

        val_indexs = []
        val_index_mask = []
        match_labels = []
        for col_id, v_tokens in candidate_vals[:max_match_value_per_seq]:
            match_labels.append(label_of_col[col_id])

            kept = v_tokens[:max(0, max_tokens_per_cell)]
            first = len(tokens_v2c)
            tokens_v2c.extend(kept)
            tokens_v2c.append(express_token)

            n_pad = max_tokens_per_cell - len(kept)
            val_indexs.append(list(range(first, first + len(kept))) + [0] * n_pad)
            val_index_mask.append([1] * len(kept) + [0] * n_pad)

        selected_header_tokens = [tokens for _, tokens in candidate_cols]

        # Only the combined sequence carries the is_random_next label.
        if combine_mode:
            mode_kwargs = dict(unique_id=str(example_index) + ':combine', is_random_next=is_random_next)
        else:
            mode_kwargs = dict(unique_id=str(example_index) + ':v2t')

        feature = create_masked_lm(
            tokens_v2c, selected_header_tokens, max_seq_len, max_column_per_seq, max_match_value_per_seq,
            max_tokens_per_cell, express_token, num_max_mask, mask_prob, tokenizer,
            print_info=print_info, choose_raw_index=choose_raw_index, example_index=example_index,
            val_indexs=val_indexs, val_index_mask=val_index_mask, match_labels=match_labels,
            **mode_kwargs)
        epoch_features.append(feature)

    return epoch_features


def get_args():
    """Parse the command-line options of the feature dumping script.

    Flag names, types and defaults are unchanged; only help texts that had
    been copy-pasted from unrelated options were corrected.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()

    # Tokenizer selection.
    parser.add_argument("--model_type", default=None, type=str,
                        help="Model type")
    parser.add_argument("--tokenizer_name", default=None, type=str,
                        help="tokenizer name")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    # Input / output locations.
    parser.add_argument("--output_dir", default=None, type=str,
                        help="Directory where the feature files and feature format configs are written.")
    parser.add_argument("--write_buffer_dir", default=None, type=str,
                        help="Optional buffer directory (e.g. for blobfuse): feature files are written "
                             "here first and then copied to output_dir.")
    parser.add_argument("--data_file", type=str,
                        help="Path to the data file")

    # Sequence layout limits.
    parser.add_argument("--max_seq_length", default=256, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--max_context_length", default=128, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--max_column_per_seq", default=20, type=int,
                        help="The maximum number of table columns (headers) placed in one sequence.")
    parser.add_argument("--max_tokens_per_cell", default=32, type=int,
                        help="The maximum number of tokens kept per header or cell value. Longer cells "
                             "will be truncated to this length.")
    parser.add_argument("--max_match_value_per_seq", default=20, type=int,
                        help="The maximum number of cell values per sequence that are matched against "
                             "the column headers.")
    parser.add_argument("--num_max_mask", default=76, type=int,
                        help="The maximum number of masked tokens per sequence.")
    parser.add_argument("--mask_prob", default=0.15, type=float,
                        help="Fraction of the maskable tokens selected for masked language modeling.")
    # NOTE(review): not referenced anywhere else in this script.
    parser.add_argument("--min_question_length", default=0, type=int,
                        help="The minimum number of tokens for the question.")

    # Which kinds of training instances to generate.
    parser.add_argument("--context2table", action='store_true', help="Context 2 tables loss")
    parser.add_argument("--value2col", action='store_true', help="value match column name loss")
    parser.add_argument("--combine_mode", action='store_true', help="[CLS] query [SEP] values [SEP] headers [SEP]")
    # NOTE(review): parsed, but main() does not forward it to the data reader.
    parser.add_argument("--pretraining_mode", action='store_true', help="Pre-training mode flag.")

    parser.add_argument("--use_tqdm", action='store_true',
                        help="Show tqdm progress bars while tokenizing and converting examples.")

    # Using multiple processes to create data
    parser.add_argument("--generation_id", default=0, type=int,
                        help="Index of this generation worker: it only writes the feature files whose index "
                             "modulo generation_size equals this value.")
    parser.add_argument("--generation_size", default=0, type=int,
                        help="Number of generation workers sharing the feature files (0 disables sharding).")
    # NOTE(review): parsed, but never applied to the random module in this script.
    parser.add_argument("--random_seed", default=0, type=int,
                        help="Random seed.")

    # for create / read training data
    parser.add_argument("--num_instances_per_file", default=32 * 10000, type=int,
                        help="Number of training instances written to each feature file.")
    parser.add_argument("--num_training_files", default=20, type=int,
                        help="Total number of training files. ")
    parser.add_argument("--data_type", default='h', type=str,
                        help="Data type for storing training features in python/array ")

    args = parser.parse_args()
    return args


class Generator(object):
    """Endless iterable of training features.

    Iterating regenerates the features of all examples epoch after epoch
    (each epoch re-runs the random row selection and masking) and yields them
    one at a time.
    """

    def __init__(self, examples, tokenizer, args):
        self.args = args
        self.examples = examples
        self.tokenizer = tokenizer

    def __iter__(self):
        """Yield feature dicts forever.

        Raises:
            RuntimeError: if a whole epoch produces no feature at all. Without
                this check the loop would spin forever without yielding.
        """
        args = self.args

        num_epoch = 0
        while True:
            logger.info("Generate epoch-%d:" % num_epoch)
            train_features = convert_examples_to_features(
                examples=self.examples, tokenizer=self.tokenizer, max_seq_len=args.max_seq_length,
                max_context_length=args.max_context_length, max_column_per_seq=args.max_column_per_seq,
                express_token='[SEP]', max_tokens_per_cell=args.max_tokens_per_cell, combine_mode=args.combine_mode,
                max_match_value_per_seq=args.max_match_value_per_seq, mask_prob=args.mask_prob,
                num_max_mask=args.num_max_mask, empty_token='[unused101]', use_tqdm=args.use_tqdm,
                first_epoch=num_epoch == 0, context2table=args.context2table, value2col=args.value2col)
            if not train_features:
                raise RuntimeError(
                    "No features were generated in epoch {}; check the input data and the "
                    "context2table / value2col options.".format(num_epoch))
            num_epoch += 1
            yield from train_features


def main():
    """Read the data file, tokenize it and dump binary feature files.

    For every feature file index assigned to this worker,
    ``num_instances_per_file`` instances are drawn from the endless feature
    generator, serialized with ``BinaryDataset`` and written to
    ``features.<idx>.bin``. A matching ``feature_format.<idx>.json`` is saved
    afterwards; an index whose format file already exists is skipped. When a
    buffer directory is given, the file is written there first and then copied
    to the output directory.
    """
    args = get_args()
    tokenizer = get_tokenizer(args.tokenizer_name, model_type=args.model_type, do_lower_case=args.do_lower_case)

    # NOTE(review): args.pretraining_mode and args.random_seed are parsed but
    # not used here - confirm whether that is intended.
    train_examples = read_table_and_question_examples_from_wiki_sql(input_file=args.data_file)

    random.shuffle(train_examples)

    tokenize_for_examples(
        train_examples, tokenizer, args.max_seq_length, args.max_context_length,
        '[unused101]', args.value2col, args.use_tqdm)

    reducer = BinaryDataset()
    generator = iter(Generator(examples=train_examples, tokenizer=tokenizer, args=args))

    os.makedirs(args.output_dir, exist_ok=True)
    use_buffer = args.write_buffer_dir is not None
    if use_buffer:
        # Writing the buffered file would fail if the directory were missing.
        os.makedirs(args.write_buffer_dir, exist_ok=True)

    for feature_file_idx in tqdm(range(args.num_training_files)):
        # With several workers, each one only handles its own share of files.
        if args.generation_size > 0 and feature_file_idx % args.generation_size != args.generation_id:
            continue

        feature_file_name = "features.{}.bin".format(feature_file_idx)
        target_output_feature_file = os.path.join(args.output_dir, feature_file_name)
        if use_buffer:
            output_feature_file = os.path.join(args.write_buffer_dir, feature_file_name)
            logger.info("Use buffer dir %s." % args.write_buffer_dir)
        else:
            output_feature_file = target_output_feature_file

        # The format file is written last, so its presence marks a finished file.
        output_feature_format_file = os.path.join(
            args.output_dir, "feature_format.{}.json".format(feature_file_idx))
        if os.path.isfile(output_feature_format_file):
            continue

        # The context manager closes the file even if generation fails midway.
        with open(output_feature_file, mode="wb") as feature_writer:
            for _ in tqdm(range(args.num_instances_per_file)):
                features = next(generator)
                try:
                    data = reducer.binary(features=features)
                    feature_writer.write(data)
                except Exception as e:
                    # Best effort: an instance that cannot be serialized or
                    # written is logged and skipped.
                    logger.info("Meet an exception {}".format(str(e)))

        if use_buffer:
            with open(output_feature_file, mode="rb") as reader:
                features_data = reader.read()

            logger.info("Read buffer feature {} from {}".format(len(features_data), output_feature_file))

            with open(target_output_feature_file, mode="wb") as writer:
                writer.write(features_data)

            logger.info("Write feature {} to {}".format(len(features_data), target_output_feature_file))
            del features_data

            # Replace the buffered copy with a tiny placeholder file.
            with open(output_feature_file, mode="w") as writer:
                writer.write("Blobfuse please do not crash")

        reducer.save_config(output_feature_format_file)

        if feature_file_idx == 0:
            # Additionally store an index-free copy of the format description.
            reducer.save_config(os.path.join(args.output_dir, "feature_format.json"))


# Run the dump only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
