from argparse import ArgumentParser
from collections import OrderedDict
import numpy as np
from stanza.server import CoreNLPClient
import random
from tqdm import tqdm
import scipy.stats

from preprocess.annotation.read import read_table_anno
from preprocess.message_utils import AnnoMessage
from preprocess.structure import table_to_struc
from preprocess.structure_nsf import process
from preprocess.complete import *
from qa.datadump.analyze import *


# Aggregation operators for which dump_to_tapas_samples (unsupervised mode)
# tries to emit a numeric `float_answer` instead of NaN.
AGGR_FOR_NEW_ANSWER = {
    'sum',
    'average',
    'diff',
    'div',
    'max',
    'min',
    'range',
    'counta',
    'opp',
}
# Operator name -> integer written to the `aggregation_labels` column by
# dump_to_tapas_samples (supervised mode); unlisted operators fall back to 0.
# NOTE(review): presumably these ids mirror TAPAS's own label indices - confirm.
AGGR_IN_TAPAS = {
    'none': 0,
    'sum': 1,
    'average': 2,
    'counta': 3,
}


def post_process_linked_data_cells(linked_cells):
    """Stringify linked cells and heuristically trim one stray data cell.

    Two steps, both performed in place on ``linked_cells``:

    1. every cell object (in every cell-type bucket) is replaced by its
       ``cell_string`` attribute;
    2. if the rows of ``linked_cells['data']`` do not all hold the same number
       of cells, the first cell in column 1 whose text does not convert via
       ``naive_str_to_float`` (i.e. the helper hands back a ``str``) is dropped.

    Args:
        linked_cells: mapping ``cell_type -> {(row, col): cell}``; must contain
            a ``'data'`` entry.

    Returns:
        The same ``linked_cells`` object, mutated.
    """
    # Step 1: cell object -> cell string (values only, keys are untouched).
    for cells in linked_cells.values():
        for coord in cells:
            cells[coord] = cells[coord].cell_string

    data_cells = linked_cells['data']
    if not data_cells:
        return linked_cells

    # Step 2: group the data coordinates by row and by column.
    coords_by_row, coords_by_col = {}, {}
    for coord in data_cells:
        row, col = coord
        coords_by_row.setdefault(row, []).append(coord)
        coords_by_col.setdefault(col, []).append(coord)

    ordered_rows = sorted(coords_by_row.items())
    expected_width = len(ordered_rows[0][1])
    is_ragged = any(len(coords) != expected_width for _, coords in ordered_rows[1:])

    if is_ragged and 1 in coords_by_col:  # a span leaf header, like 'percent'
        for coord in coords_by_col[1]:
            if isinstance(naive_str_to_float(data_cells[coord]), str):
                data_cells.pop(coord)
                break  # only the first offending cell is removed

    return linked_cells


def check_valid_sample(sample_dict):
    """Return True when a sample passes the basic non-emptiness checks.

    A sample is rejected when:
      * any value is ``None`` or the empty string;
      * ``question`` starts with ``'='``;
      * ``answer`` is empty or its first element is ``None``;
      * ``aggregation`` or ``answer_cells`` is empty;
      * ``linked_cells`` has no corner, data, left or top cells at all.

    Fields are examined in the dict's own iteration order.
    """
    linked_regions = ('corner', 'data', 'left', 'top')
    for field, value in sample_dict.items():
        if value is None or value == '':
            return False

        if field == 'question':
            if value.startswith('='):
                return False
        elif field == 'answer':
            if len(value) < 1 or value[0] is None:
                return False
        elif field in ('aggregation', 'answer_cells'):
            if len(value) < 1:
                return False
        elif field == 'linked_cells':
            # Valid as soon as one region holds at least one cell.
            if not any(len(value[region]) for region in linked_regions):
                return False
    return True


def normalize_sample(sample_dict):
    """Normalize a sample dict in place.

    * the free-text fields go through ``normalize``;
    * string answers / cell literals go through ``naive_str_to_float``
      (non-string values are left untouched);
    * aggregation names go through ``normalize``;
    * every coordinate key is stringified, e.g. ``(0, 1) -> '(0, 1)'``.

    Only ``answer``, ``answer_cells`` and ``linked_cells`` may end up holding
    non-string values. Returns None.
    """
    for text_field in ('desc_sent', 'proc_sub_sent', 'key_part', 'question'):
        sample_dict[text_field] = normalize(sample_dict[text_field])

    # The answer / aggregation lists are updated element by element so the
    # original list objects are kept.
    answers = sample_dict['answer']
    for idx, value in enumerate(answers):
        if isinstance(value, str):
            answers[idx] = naive_str_to_float(value)

    aggregations = sample_dict['aggregation']
    for idx, aggr in enumerate(aggregations):
        aggregations[idx] = normalize(aggr)

    # phrase -> {coord string -> normalized literal}; phrases are assumed to be str-able.
    sample_dict['schema_link'] = {
        normalize(str(phrase)): {
            str(coord): normalize(str(literal))
            for coord, literal in links.items()
        }
        for phrase, links in sample_dict['schema_link'].items()
    }

    sample_dict['answer_cells'] = {
        str(coord): naive_str_to_float(literal) if isinstance(literal, str) else literal
        for coord, literal in sample_dict['answer_cells'].items()
    }

    sample_dict['linked_cells'] = {
        cell_type: {
            str(coord): naive_str_to_float(literal) if isinstance(literal, str) else literal
            for coord, literal in cells.items()
        }
        for cell_type, cells in sample_dict['linked_cells'].items()
    }


def to_qa_samples(parsed_dataset):
    """Flatten the parsed annotation dataset into a list of QA sample dicts.

    Reads paths from the module-level ``args``.

    Args:
        parsed_dataset: mapping ``table_id -> list of blocks``. Each block is
            read for ``'sub_messages'``, ``'sub_samples'``,
            ``'table_desc_sent_id'`` and ``'table_desc_sent'``.

    Returns:
        A list of ``OrderedDict`` samples. Ids are consecutive integers (as
        strings) assigned only to samples that pass ``check_valid_sample``.

    Blocks carrying any ``AnnoMessage`` are skipped entirely. A sample whose
    cell linking raises is reported and skipped: continuing would either hit an
    unbound ``linked_cells`` or silently reuse the previous sample's cells.
    """
    samples, global_id = [], 0
    table_dict = load_tables(os.path.join(args.root_dir, args.data_dir), args.html_dir)
    for table_id, blocks in parsed_dataset.items():
        ic(table_id)
        if len(table_id.split('_')) != 1:  # nsf
            nsf_table_order, page_id, file_name = table_id.split('_')
            nsf_annotation_file_path = os.path.join(args.root_dir, args.data_dir,
                                                    args.nsf_annotated_dir, page_id, f"{file_name}.xlsx")
            structure = process(nsf_annotation_file_path, file_name, page_id, nsf_table_order, structure_only=True)
        else:
            table = table_dict[table_id]
            structure = table_to_struc(table)
        for block in blocks:
            if any(isinstance(m, AnnoMessage) for m in block['sub_messages']):
                continue
            for sample in block['sub_samples']:
                try:
                    linked_cells = complete_linking(structure, sample)
                    linked_cells = post_process_linked_data_cells(linked_cells)
                except Exception as e:
                    print(f'Error when linking cells: {e}')
                    # No usable linking for this sample: drop it instead of
                    # falling through with a missing or stale `linked_cells`.
                    continue
                sample_dict = OrderedDict()
                sample_dict['id'] = str(global_id)
                sample_dict['context'] = str(table_id)
                sample_dict['desc_id'] = str(block['table_desc_sent_id'])  # TODO: debug, to remove
                sample_dict['question'] = sample['question']
                sample_dict['answer'] = sample['answer']
                sample_dict['schema_link'] = sample['schema_link']
                sample_dict['aggregation'] = sample['aggregation']
                sample_dict['answer_cells'] = sample['answer_cells']
                sample_dict['linked_cells'] = linked_cells
                sample_dict['desc_sent'] = block['table_desc_sent']  # TODO: debug, to remove
                sample_dict['proc_sub_sent'] = sample['sub_sent']  # TODO: debug, to remove
                sample_dict['key_part'] = str(sample['key_part'])  # TODO: debug, to remove
                if check_valid_sample(sample_dict):
                    normalize_sample(sample_dict)
                    global_id += 1
                    samples.append(sample_dict)
                else:
                    print(f"Found invalid cells in tid:{table_id}, did:{block['table_desc_sent_id']}")
    return samples


def extract_table_attrs(table_url_path, target='table'):
    """Read the table attribute workbook and build grouping maps for splitting.

    The maps let the caller keep all tables that share a page URL (or a
    domain) inside the same train/dev/test split.

    Only the first worksheet is read, starting at row 2, until a row whose
    first column is empty. Rows whose first column contains ``'--'`` are
    skipped as invalid. Column 2 holds the table id (stringified), column 4
    the URL and column 5 the domain.

    Args:
        table_url_path: path of the workbook to load.
        target: ``'page'`` or ``'domain'``.

    Returns:
        ``(url2table, table2url)`` for ``'page'``;
        ``(domain2table, table2domain)`` for ``'domain'``.

    Raises:
        NotImplementedError: for any other ``target`` (including the default
            ``'table'``); raised before the workbook is opened.
    """
    if target not in ('page', 'domain'):
        raise NotImplementedError("Only table2url and table2domain can be extracted.")

    wb = load_workbook(table_url_path)
    # `sheetnames` / item access replace the deprecated
    # get_sheet_names() / get_sheet_by_name() accessors.
    ws = wb[wb.sheetnames[0]]

    table2url, url2table = {}, {}
    table2domain, domain2table = {}, {}
    row = 2
    while ws.cell(row, 1).value is not None:
        if '--' not in ws.cell(row, 1).value:  # '--' marks an invalid table
            table_id = str(ws.cell(row, 2).value)
            url = ws.cell(row, 4).value
            domain = ws.cell(row, 5).value
            url2table.setdefault(url, []).append(table_id)
            table2url[table_id] = url
            domain2table.setdefault(domain, []).append(table_id)
            table2domain[table_id] = domain
        row += 1

    if target == 'page':
        return url2table, table2url
    return domain2table, table2domain


def fair_num_distribution(train_table_ids, dev_table_ids, test_table_ids,
                          cnt_train_on_target, cnt_dev_on_target, cnt_test_on_target,
                          max_diff_table_on_target=0.3, max_diff_table=50):
    """Check whether a candidate split is balanced enough.

    The split is fair when
      * no split is empty,
      * the average number of tables per target (``len(ids) / cnt``) differs
        by less than ``max_diff_table_on_target`` between every pair of
        splits, and
      * dev and test differ by fewer than ``max_diff_table`` tables.

    Args:
        train_table_ids, dev_table_ids, test_table_ids: table ids per split.
        cnt_train_on_target, cnt_dev_on_target, cnt_test_on_target: number of
            targets (tables / pages / domains) assigned to each split.
        max_diff_table_on_target: strict upper bound on the pairwise
            difference of tables-per-target ratios.
        max_diff_table: strict upper bound on ``|len(dev) - len(test)|``.

    Returns:
        True if the split is fair, False otherwise.
    """
    sizes = (len(train_table_ids), len(dev_table_ids), len(test_table_ids))
    target_counts = (cnt_train_on_target, cnt_dev_on_target, cnt_test_on_target)
    # Guard before any division: an empty split (or a zero target count) used
    # to raise ZeroDivisionError in the debug output below.
    if 0 in sizes or 0 in target_counts:
        return False

    ic(len(train_table_ids), cnt_train_on_target, len(train_table_ids) / cnt_train_on_target)
    ic(len(dev_table_ids), cnt_dev_on_target, len(dev_table_ids) / cnt_dev_on_target)
    ic(len(test_table_ids), cnt_test_on_target, len(test_table_ids) / cnt_test_on_target)

    train_ratio, dev_ratio, test_ratio = (
        size / count for size, count in zip(sizes, target_counts)
    )
    return bool(
        abs(train_ratio - dev_ratio) < max_diff_table_on_target
        and abs(train_ratio - test_ratio) < max_diff_table_on_target
        and abs(dev_ratio - test_ratio) < max_diff_table_on_target
        and abs(len(dev_table_ids) - len(test_table_ids)) < max_diff_table
    )


def fair_aggr_type_distribution(train_samples, dev_samples, test_samples,
                                max_dist=0.1):
    """Tell whether the three splits have similar aggregation-type mixes.

    Each split's aggregation distribution is laid out as a vector over
    ``preprocess._OPERATORS``; the split is fair when every pair of vectors is
    closer than ``max_dist`` according to ``L1_distance``.

    Returns:
        True if all three pairwise distances are below ``max_dist``.
    """
    distributions = [
        aggregation_type_distribution(split)[1]
        for split in (train_samples, dev_samples, test_samples)
    ]
    from preprocess import _OPERATORS
    operators = list(_OPERATORS)
    train_vec, dev_vec, test_vec = [
        get_distribution_vector(distribution, operators)
        for distribution in distributions
    ]
    pairs = ((train_vec, dev_vec), (train_vec, test_vec), (dev_vec, test_vec))
    return all(L1_distance(left, right) < max_dist for left, right in pairs)


def get_distribution_vector(distribution, op_list):
    """Lay a ``{operator: value}`` mapping out as an array ordered like ``op_list``.

    Operators absent from ``distribution`` get 0.0. An operator that is not in
    ``op_list`` makes ``list.index`` raise ``ValueError``.
    """
    slots = [0.0] * len(op_list)
    for operator, share in distribution.items():
        position = op_list.index(operator)
        slots[position] = share
    return np.array(slots)


def KL_divergence(p, q):
    """Return the relative entropy of ``p`` with respect to ``q``.

    Thin wrapper around ``scipy.stats.entropy`` called with both distributions.
    """
    divergence = scipy.stats.entropy(p, qk=q)
    return divergence


def L1_distance(p, q):
    """Return the mean absolute element-wise difference between ``p`` and ``q``.

    Note that this is the *mean*, not the sum, of ``|p - q|``.
    """
    absolute_gap = np.abs(p - q)
    return np.mean(absolute_gap)


def split_dataset(samples, method='table'):
    """Split samples into train/dev/test, roughly 70:15:15 by number of tables.

    Tables are assigned in whole groups so that every table sharing the same
    target stays in one split. The target is the table itself
    (``method='table'``), its page URL (``'page'``) or its domain
    (``'domain'``); the latter two are read from the workbook at
    ``args.table_url_file``.

    Seeds 0..9999 are tried in order (``args.seed`` is not consulted) and the
    first random assignment accepted by ``fair_num_distribution`` is returned.

    Args:
        samples: list of sample dicts; ``sample['context']`` is the table id.
        method: one of ``'page'``, ``'table'``, ``'domain'``.

    Returns:
        ``(train_samples, dev_samples, test_samples)``.

    Raises:
        NotImplementedError: for an unknown ``method``.
        ValueError: if no seed yields a fair split, or if a sample's table
            ends up in none of the splits (inconsistent target mapping).
    """
    parsed_tables = sorted({sample['context'] for sample in samples})
    prop_train_table, prop_dev_table = 0.7, 0.15  # test takes the remainder
    num_train_table = int(len(parsed_tables) * prop_train_table)
    num_dev_table = int(len(parsed_tables) * prop_dev_table)
    num_test_table = len(parsed_tables) - num_dev_table - num_train_table

    if method == 'page':
        table_url_path = os.path.join(args.root_dir, args.data_dir, args.table_url_file)
        url2table, table2url = extract_table_attrs(table_url_path, 'page')
        table2target, target2table = table2url, url2table
    elif method == 'table':
        table2target = {tid: tid for tid in parsed_tables}  # dummy, to fit 'page' and 'domain' pipeline
        target2table = {tid: [tid] for tid in parsed_tables}
    elif method == 'domain':
        table_url_path = os.path.join(args.root_dir, args.data_dir, args.table_url_file)
        domain2table, table2domain = extract_table_attrs(table_url_path, 'domain')
        table2target, target2table = table2domain, domain2table
    else:
        raise NotImplementedError("Only support split by ['page', 'table', 'domain'].")

    print(f"Split by {method}.")
    parsed_table_set = set(parsed_tables)  # built once for O(1) membership tests

    for seed in range(0, 10000):
        random.seed(seed)
        ic(seed)
        train_table_ids, dev_table_ids, test_table_ids = [], [], []
        cnt_train_table, cnt_dev_table, cnt_test_table = 0, 0, 0
        cnt_train_on_target, cnt_dev_on_target, cnt_test_on_target = 0, 0, 0
        assigned = set()  # every table id already placed in some split
        i = 0
        while i < len(parsed_tables):
            table_id = parsed_tables[i]
            if table_id in assigned:
                i += 1
                continue
            # All parsed tables sharing this table's target move together.
            same_target_table_ids = list(
                set(target2table[table2target[table_id]]) & parsed_table_set)
            # Exactly one draw per attempt; when the drawn split is already
            # full nothing is assigned and the same table is retried with a
            # fresh draw on the next loop iteration.
            dice = random.random()
            if 0 <= dice < prop_train_table and cnt_train_table < num_train_table:
                train_table_ids.extend(same_target_table_ids)
                cnt_train_table += len(same_target_table_ids)
                cnt_train_on_target += 1
            elif prop_train_table <= dice < prop_train_table + prop_dev_table and cnt_dev_table < num_dev_table:
                dev_table_ids.extend(same_target_table_ids)
                cnt_dev_table += len(same_target_table_ids)
                cnt_dev_on_target += 1
            elif prop_train_table + prop_dev_table <= dice < 1 and cnt_test_table < num_test_table:
                test_table_ids.extend(same_target_table_ids)
                cnt_test_table += len(same_target_table_ids)
                cnt_test_on_target += 1
            else:
                continue
            assigned.update(same_target_table_ids)
            i += 1

        if not fair_num_distribution(train_table_ids, dev_table_ids, test_table_ids,
                                     cnt_train_on_target, cnt_dev_on_target, cnt_test_on_target,
                                     max_diff_table_on_target=args.max_diff_table_on_target,
                                     max_diff_table=args.max_diff_table):
            print("Not fair number of tables on target, and number of samples.")
            continue

        train_samples, dev_samples, test_samples = [], [], []
        # table id -> destination list; setdefault keeps train > dev > test
        # precedence should an id ever appear in more than one split.
        split_of_table = {}
        for table_ids, bucket in ((train_table_ids, train_samples),
                                  (dev_table_ids, dev_samples),
                                  (test_table_ids, test_samples)):
            for tid in table_ids:
                split_of_table.setdefault(tid, bucket)
        for sample in samples:
            bucket = split_of_table.get(sample['context'])
            if bucket is None:
                # Previously this reused the preceding sample's split (or hit
                # an unbound local); fail loudly instead.
                raise ValueError(f"Table {sample['context']!r} was not assigned to any split.")
            bucket.append(sample)

        # NOTE: two further acceptance checks are currently disabled here: the
        # aggregation-type balance check (fair_aggr_type_distribution with
        # args.max_aggr_type_dist) and a minimum share of train samples.

        return train_samples, dev_samples, test_samples

    raise ValueError("Available split strategy not found.")


def label_dataset(samples, client):
    """Annotate each sample's question with CoreNLP token-level labels.

    For every sample, ``client.annotate(sample['question'])`` is called and the
    tokens of all sentences are flattened. Five ``'|'``-joined strings are then
    stored on the sample: ``tokens``, ``lemma_tokens``, ``pos_tags``,
    ``ner_tags`` and ``ner_vals``. Samples are modified in place; returns None.
    """
    # sample key -> token attribute it is collected from
    attribute_of_field = (
        ('tokens', 'word'),
        ('lemma_tokens', 'lemma'),
        ('pos_tags', 'pos'),
        ('ner_tags', 'ner'),
        ('ner_vals', 'normalizedNER'),
    )
    for sample in samples:
        annotation = client.annotate(sample['question'])
        all_tokens = [
            token
            for sentence in annotation.sentence
            for token in sentence.token
        ]
        # Build every joined string first so the sample is only touched once
        # all of them are available.
        labels = {
            field: '|'.join(getattr(token, attribute) for token in all_tokens)
            for field, attribute in attribute_of_field
        }
        sample.update(labels)


def dump_to_tsv(samples, dump_path):
    """Write samples to a tab-separated ``*_meta.tsv`` file.

    One row per sample, one column per field listed below (in that order),
    without an index column. Every sample must provide all listed fields.
    """
    column_names = (
        'id', 'context', 'desc_id', 'question', 'answer',
        'schema_link', 'aggregation', 'answer_cells', 'linked_cells',
        'desc_sent', 'proc_sub_sent', 'key_part',
        'tokens', 'lemma_tokens', 'pos_tags', 'ner_tags', 'ner_vals',
    )
    columns = {name: [] for name in column_names}
    for sample in samples:
        for name in column_names:
            columns[name].append(sample[name])

    frame = pd.DataFrame({name: pd.Series(values) for name, values in columns.items()})
    frame.to_csv(dump_path, sep='\t', index=False)


def dump_to_json(samples, dump_path):
    """Write samples as JSON Lines: one ``json.dumps`` record per line."""
    with open(dump_path, 'w') as out:
        out.writelines(json.dumps(sample) + '\n' for sample in samples)


def dump_to_tapas_tables(output_table_dir):
    """Convert every JSON table under ``args.valid_table_dir`` to a CSV file.

    The first row of ``table['matrix']['Texts']`` becomes the CSV header and
    the remaining rows the body. An empty header cell that lies inside a
    merged region starting on row 0 takes the text of that region's first
    column; a header still empty after that is named ``<COLUMN>_<index>``.
    The header row is patched in place on the loaded matrix.

    Args:
        output_table_dir: directory receiving ``<table file stem>.csv``.
    """
    table_dir = os.path.join(args.root_dir, args.data_dir, args.valid_table_dir)
    for file_name in os.listdir(table_dir):
        with open(os.path.join(table_dir, file_name), 'r') as f:
            table = json.load(f)
        texts = table['matrix']['Texts']  # TODO: maybe using duplicate in merge cells is better
        merged_regions = table['matrix']['MergedRegions']
        body = pd.DataFrame(texts[1:])
        headers = texts[0]
        for col, header in enumerate(headers):
            if header != '':
                continue
            # Every covering top-row region is applied; the last one wins.
            for region in merged_regions:
                if region['FirstRow'] == 0 and region['FirstColumn'] <= col <= region['LastColumn']:
                    headers[col] = headers[region['FirstColumn']]
            if headers[col] == '':
                headers[col] = '<COLUMN>_' + str(col)
        body.columns = headers
        stem = file_name.split('.')[0]
        body.to_csv(os.path.join(output_table_dir, stem + '.csv'), index=False)


def dump_to_tapas_samples(samples, dump_path, supervise=False):
    """Write samples to a tab-separated file in the TAPAS sample layout.

    Common columns: ``id``, ``question``, ``table_file`` (``<context>.csv``),
    ``answer_coordinates`` and ``answer_text``. Answer-cell rows are shifted
    up by one in the written coordinates.

    The last column depends on ``supervise``:
      * ``False``: ``float_answer`` - the first answer as a float when the
        sample uses an operator from ``AGGR_FOR_NEW_ANSWER`` and that answer
        is numeric (or a string ``naive_str_to_float`` turns into a float);
        NaN otherwise.
      * ``True``: ``aggregation_labels`` - the ``AGGR_IN_TAPAS`` id of the
        first listed operator found there, 0 if none is.
    """
    sample_ids, questions, table_files = [], [], []
    coordinate_lists, text_lists = [], []
    float_answers, aggregation_labels = [], []

    for sample in samples:
        sample_ids.append(sample['id'])
        questions.append(sample['question'])
        table_files.append(sample['context'] + '.csv')

        sample_coords, sample_texts = [], []
        for coord_str, literal in sample['answer_cells'].items():
            # NOTE(review): eval() on the coordinate key, '(0, 1)' -> (0, 1).
            # Only safe while these keys are produced by this pipeline.
            parsed = eval(coord_str)
            shifted = (parsed[0] - 1, parsed[1])
            sample_coords.append(str(shifted))
            sample_texts.append(str(literal))
        coordinate_lists.append(sample_coords)
        text_lists.append(sample_texts)

        if supervise:
            label = next(
                (AGGR_IN_TAPAS[aggr] for aggr in sample['aggregation'] if aggr in AGGR_IN_TAPAS),
                0,
            )
            aggregation_labels.append(label)
        else:
            numeric_answer = np.nan
            if set(sample['aggregation']) & AGGR_FOR_NEW_ANSWER:
                candidate = sample['answer'][0]
                if isinstance(candidate, str):
                    candidate = naive_str_to_float(candidate)
                    if isinstance(candidate, float):
                        numeric_answer = float(candidate)
                elif isinstance(candidate, (int, float)):
                    numeric_answer = float(candidate)
            float_answers.append(numeric_answer)

    columns = {
        'id': pd.Series(sample_ids),
        'question': pd.Series(questions),
        'table_file': pd.Series(table_files),
        'answer_coordinates': pd.Series(coordinate_lists),
        'answer_text': pd.Series(text_lists),
    }
    if supervise:
        columns['aggregation_labels'] = pd.Series(aggregation_labels)
    else:
        columns['float_answer'] = pd.Series(float_answers)

    pd.DataFrame(columns).to_csv(dump_path, sep='\t', index=False)


def main():
    """Build the QA dataset end to end.

    Steps (all paths come from the module-level ``args``):
      1. keep the annotation files that have a matching ``<stem>.json`` table;
      2. parse each one with ``read_table_anno`` and flatten to samples;
      3. split into train/dev/test;
      4. label questions with a CoreNLP server;
      5. dump the TAPAS tables/samples and the MAPO ``*_meta`` files.

    The CoreNLP client is used as a context manager so the server is shut
    down once labelling is finished, and the output directories are created
    if missing.
    """
    # select valid and annotated tables
    annotated_dir = os.path.join(args.root_dir, args.data_dir, args.annotated_dir)
    annotated_files = os.listdir(annotated_dir)
    valid_table_files = set(os.listdir(os.path.join(args.root_dir, args.data_dir, args.valid_table_dir)))
    valid_annotated_files = [
        af
        for af in annotated_files
        if f"{af.split('.')[0]}.json" in valid_table_files
    ]
    valid_annotated_files = sorted(valid_annotated_files)  # WARNING! It'll change current data split.
    ic(len(valid_annotated_files))

    # iterate and parse the annotated files
    dataset = {}
    for f in tqdm(valid_annotated_files):
        print(f"Parsing annotations from {f}.")
        table_id = f.split('.')[0]
        sample_dict = read_table_anno(filename=os.path.join(annotated_dir, f))
        if len(sample_dict) == 0:
            continue
        dataset[table_id] = sample_dict
    # to sample-wise format
    samples = to_qa_samples(dataset)

    # split into train/dev/test
    train_samples, dev_samples, test_samples = split_dataset(samples, method=args.split_method)
    print()
    ic(len(train_samples))
    ic(len(dev_samples))
    ic(len(test_samples))
    time.sleep(5)

    # annotate question tokens with stanford corenlp; leaving the `with`
    # block stops the server
    with CoreNLPClient(
            annotators=['tokenize', 'pos', 'ner'],
            timeout=30000,
            memory='16G') as client:
        label_dataset(train_samples, client)
        label_dataset(dev_samples, client)
        label_dataset(test_samples, client)

    # dump to tapas data
    tapas_dst_dir = os.path.join(args.root_dir, args.data_dir, args.tapas_output_dir)
    tapas_table_dir = os.path.join(tapas_dst_dir, 'tables/')
    os.makedirs(tapas_table_dir, exist_ok=True)  # also creates tapas_dst_dir
    dump_to_tapas_tables(tapas_table_dir)
    dump_to_tapas_samples(train_samples, os.path.join(tapas_dst_dir, 'train_samples.tsv'))
    if args.supervise:
        dump_to_tapas_samples(train_samples, os.path.join(tapas_dst_dir, 'train_samples_sup.tsv'), supervise=True)
    dump_to_tapas_samples(dev_samples, os.path.join(tapas_dst_dir, 'dev_samples.tsv'))
    dump_to_tapas_samples(test_samples, os.path.join(tapas_dst_dir, 'test_samples.tsv'))
    print("Dump tapas data done.")

    # dump to mapo samples
    dst_dir = os.path.join(args.root_dir, args.data_dir, args.output_dir)
    os.makedirs(dst_dir, exist_ok=True)
    dump_to_tsv(train_samples, os.path.join(dst_dir, 'train_meta.tsv'))
    dump_to_tsv(dev_samples, os.path.join(dst_dir, 'dev_meta.tsv'))
    dump_to_tsv(test_samples, os.path.join(dst_dir, 'test_meta.tsv'))
    dump_to_json(train_samples, os.path.join(dst_dir, 'train_meta.jsonl'))
    dump_to_json(dev_samples, os.path.join(dst_dir, 'dev_meta.jsonl'))
    dump_to_json(test_samples, os.path.join(dst_dir, 'test_meta.jsonl'))
    print("Dump MAPO data done.")


if __name__ == '__main__':
    parser = ArgumentParser()

    # path: plain string options; the functions above join them onto
    # root_dir / data_dir as needed.
    path_options = (
        ('--root_dir', '/data/home/hdd3000/USER/HMT/'),
        ('--data_dir', 'qa/data/'),
        ('--annotated_dir', 'annotations/'),
        ('--nsf_annotated_dir', 'annotations_nsf/'),
        ('--valid_table_dir', 'raw_input/table_filtered/'),
        ('--html_dir', 'html/'),
        ('--output_dir', 'raw_input/tagged/'),
        ('--tapas_output_dir', 'raw_input/tapas_data/'),
        ('--table_url_file', 'table_url_pairs.xlsx'),
    )
    for flag, default in path_options:
        parser.add_argument(flag, type=str, default=default)

    # setting
    parser.add_argument('--operator', type=str, default=None)
    parser.add_argument('--supervise', action='store_true')
    parser.add_argument('--split_method', type=str, default='table',
                        choices=['page', 'table', 'domain'])
    # fairness thresholds handed to fair_num_distribution by split_dataset
    parser.add_argument('--max_diff_table_on_target', type=float, default=0.4)
    parser.add_argument('--max_diff_table', type=float, default=15)
    parser.add_argument('--max_aggr_type_dist', type=float, default=0.01)
    parser.add_argument('--seed', type=int, default=7)

    # `args` is intentionally a module-level global: the functions above read it directly.
    args = parser.parse_args()

    main()
