"""Pre-process the raw table-sentence annotations into model inputs."""

import os
import argparse

from processing.io_utils import load_tables, dump_dataset
from processing.io_utils import load_anno_and_struc, pairing, load_anno_nsf
from processing.link import get_raw_linked_cells, get_formula_linked_cells
from processing.serialize import linearize_table, linearize_data_paths
from processing.serialize import get_table_parent_list
from processing.base.structure_nsf import process


# %% Get Sample by Link and Serialization

# link
# Dispatch table keyed by the `--link_method` choice; the selected function is
# called as fn(structure, subdict) and returns (linked_cells, data_paths).
LINK_DICT = dict(
    raw=get_raw_linked_cells,
    formula=get_formula_linked_cells,
)


# serialize
def serial_concat(linked_cells, data_paths, structure, subdict):
    """Serialize the linked cells of a table with `linearize_table`.

    `data_paths` and `subdict` are not used here; they are accepted so that
    every entry of SERIAL_DICT can be called with the same four arguments.

    Returns whatever `linearize_table` produces (per the original note, the
    fields are: title, top, left, data, formula).
    """
    return linearize_table(structure, linked_cells, operators=None)

def serial_pair(linked_cells, data_path_dicts, structure, subdict):
    """Serialize the data paths of a sub-sentence with `linearize_data_paths`.

    `linked_cells` is not used here; it is accepted so that every entry of
    SERIAL_DICT can be called with the same four arguments. The annotation's
    answers are stringified and its aggregation is passed as the operators.

    Returns whatever `linearize_data_paths` produces (per the original note,
    the fields are: title, path, formula).
    """
    stringified_answers = list(map(str, subdict['answer']))
    return linearize_data_paths(
        structure,
        data_path_dicts,
        operators=subdict['aggregation'],
        answers=stringified_answers,
    )

# Dispatch table keyed by the `--serial_method` choice; the selected function
# is called as fn(linked_cells, data_paths, structure, subdict).
SERIAL_DICT = dict(
    concat=serial_concat,
    pair=serial_pair,
)



def get_input_sample(table_id, table, args):
    """Yield one sample dict per (table, sub-sentence) pair of a table.

    The annotation workbook `<args.annotated_dir>/<table_id>.xlsx` is parsed
    with `load_anno_and_struc`; when that returns None the generator ends
    without yielding anything. Every pair produced by `pairing` is linked with
    the function selected by `args.link_method` and serialized with the one
    selected by `args.serial_method`.

    Yields:
        dict with keys 'table_id', 'sub_sent_id', 'source', 'target',
        'table_parent' and 'operations'.
    """
    annotation_file = os.path.join(args.annotated_dir, f'{table_id}.xlsx')
    parsed = load_anno_and_struc(annotation_file, table_id, table, args.operator)
    if parsed is None:
        return
    table_structure, sentence_annos = parsed

    for pair_structure, sub_anno, sub_sent_id in pairing(table_structure, sentence_annos):
        # resolve the linker / serializer per pair, exactly where they are used
        linker = LINK_DICT[args.link_method]
        cells, paths = linker(pair_structure, sub_anno)
        serializer = SERIAL_DICT[args.serial_method]
        source_fields = serializer(cells, paths, pair_structure, sub_anno)
        parents = get_table_parent_list(cells, pair_structure)

        yield {
            'table_id': table_id, 
            'sub_sent_id': sub_sent_id, 
            'source': source_fields, 
            'target': sub_anno['sub_sent'], 
            'table_parent': parents, 
            'operations': sub_anno['aggregation'], 
        }

def load_canada_dataset():
    """Collect the samples of every CANADA table found under `args.html_dir`.

    Reads the module-level `args` namespace (`html_dir`, `test_count`, and
    whatever `get_input_sample` reads from it), so `args` must be defined
    before this function is called.

    Returns:
        List of the sample dicts yielded by `get_input_sample`, with tables
        visited in ascending table-id order.
    """
    # load tables
    table_dict = load_tables(args.html_dir)
    # sort tables (w.r.t. id) for better stats
    sorted_table_list = sorted(table_dict.items(), key=lambda item: item[0])
    if args.test_count is not None:
        # debugging aid: only keep the first `test_count` tables
        sorted_table_list = sorted_table_list[:args.test_count]

    dataset = []    #  List[{'source', 'target'}]
    for table_id, table in sorted_table_list:
        # get_input_sample is a generator that only ever yields sample dicts
        # (a table without usable annotation simply yields nothing), so no
        # None-check is needed on the individual samples.
        dataset.extend(get_input_sample(table_id, table, args))
    return dataset



def get_input_sample_nsf(nsf_table_id, nsf_page_id, nsf_anno_path, nsf_structure, args):
    """Yield one sample dict per (table, sub-sentence) pair of an NSF table.

    Sentences are read from `nsf_anno_path` with `load_anno_nsf`; when that
    returns None the generator ends without yielding anything. Every pair
    produced by `pairing` is linked with the function selected by
    `args.link_method` and serialized with the one selected by
    `args.serial_method`.

    Yields:
        dict with keys 'table_id', 'page_id', 'domain_id' (always 0),
        'sub_sent_id', 'source', 'target', 'table_parent' and 'operations'.
    """
    sentence_annos = load_anno_nsf(anno_path=nsf_anno_path, operator=args.operator)
    if sentence_annos is None:
        return

    for pair_structure, sub_anno, sub_sent_id in pairing(nsf_structure, sentence_annos):
        # resolve the linker / serializer per pair, exactly where they are used
        linker = LINK_DICT[args.link_method]
        cells, paths = linker(pair_structure, sub_anno)
        serializer = SERIAL_DICT[args.serial_method]
        source_fields = serializer(cells, paths, pair_structure, sub_anno)
        parents = get_table_parent_list(cells, pair_structure)

        yield {
            'table_id': nsf_table_id, 
            'page_id': nsf_page_id, 
            'domain_id': 0, 
            'sub_sent_id': sub_sent_id, 
            'source': source_fields, 
            'target': sub_anno['sub_sent'], 
            'table_parent': parents, 
            'operations': sub_anno['aggregation'], 
        }

def load_nsf_dataset():
    """Collect the samples of every NSF table found under `args.raw_nsf_dir`.

    Reads the module-level `args` namespace (`raw_nsf_dir`, `new_nsf_dir`, and
    whatever `get_input_sample_nsf` reads from it), so `args` must be defined
    before this function is called.

    Page directories whose name ends with 'N' are skipped, as are table files
    whose stem (text before the first '.') ends with 'N', whose name starts
    with '~', or whose name contains 'fig'. A table that fails to process is
    reported and skipped as a whole; it does not consume a table id.

    Returns:
        List of the sample dicts yielded by `get_input_sample_nsf`.
    """
    nsf_dataset = []

    global_table_id = 0    # only advanced for tables that were fully processed
    num_error_case = 0
    for page_dirname in sorted(os.listdir(args.raw_nsf_dir)):
        if page_dirname.endswith('N'):
            continue
        page_dir = os.path.join(args.raw_nsf_dir, page_dirname)
        # NOTE(review): os.listdir order is arbitrary, so the table ids assigned
        # below can differ between machines; left unsorted to keep existing ids.
        for tname in os.listdir(page_dir):
            if (
                tname.split('.')[0].endswith('N')
                or tname.startswith('~')
                or 'fig' in tname
            ):
                continue

            try:
                nsf_structure, nsf_anno_path = process(
                    file_path=os.path.join(page_dir, tname), 
                    file_name=tname, 
                    page_id=page_dirname, 
                    table_id=global_table_id, 
                    new_nsf_dir=args.new_nsf_dir, 
                )
                # Buffer the samples of this table: if the generator fails part
                # way through, nothing is added, so a failed table cannot leave
                # samples behind under an id that the next table will reuse.
                table_samples = list(get_input_sample_nsf(
                    nsf_table_id=10000+global_table_id, 
                    nsf_page_id=int(page_dirname), 
                    nsf_anno_path=nsf_anno_path, 
                    nsf_structure=nsf_structure, 
                    args=args
                ))
            except Exception as err:
                # Best-effort: skip the broken table but say why. Catching
                # Exception (not a bare except) lets Ctrl-C / SystemExit through.
                num_error_case += 1
                print(f"error processing {tname} on page {page_dir}.")
                print(f"  cause: {type(err).__name__}: {err}")
                continue

            nsf_dataset.extend(table_samples)
            global_table_id += 1

    if num_error_case:
        print(f"skipped {num_error_case} NSF table(s) because of processing errors.")
    return nsf_dataset




def main():
    """Build the CANADA and NSF sample lists and hand them to `dump_dataset`.

    Reads the module-level `args` namespace for the output locations, the
    split file names and the split method.
    """
    canada_samples = load_canada_dataset()
    nsf_samples = load_nsf_dataset()

    print(f'[inputs >> main] collected {len(canada_samples)} CANADA samples in the dataset.')
    print(f'[inputs >> main] collected {len(nsf_samples)} NSF samples in the dataset.')

    # file names for the train / valid / test splits, in that order
    split_filenames = [args.train_filename, args.valid_filename, args.test_filename]
    dump_dataset(
        dataset=canada_samples + nsf_samples,
        pair_path=args.table_url_pairs_path,
        data_dir=args.dataset_dir,
        sample_dir=args.sample_dir,
        names=split_filenames,
        method=args.split_method,
        shuffle=True,
        verbose=True,
    )
    print(f'Bye~~')




from processing import _OPERATORS
# Lower-cased operator names; used below as the accepted values of `--operator`.
operator_choices = [op.lower() for op in _OPERATORS]

from utils import get_dataset_path


if __name__ == "__main__":
    # Command-line entry point: parse the options, derive the input/output
    # paths from them, create the output directories, then run main().
    parser = argparse.ArgumentParser()

    # inputs (every *_subdir / *_name below is joined onto --dataset_subdir)
    parser.add_argument('--dataset_subdir', type=str, default='../data')
    parser.add_argument('--html_subdir', type=str, default='canada_html')
    parser.add_argument('--annotated_subdir', type=str, default='canada_annotations')
    parser.add_argument('--raw_nsf_subdir', type=str, default='raw_nsf')
    parser.add_argument('--new_nsf_subdir', type=str, default='new_nsf')
    parser.add_argument('--valid_subdir', type=str, default='table_filtered')
    parser.add_argument('--table_url_pairs_name', type=str, default='table_url_pairs0424.xlsx')

    # output
    parser.add_argument('--dataset_name', type=str, default='both0513', 
        help='Sub-directory to store the pre-processed data2text source data.')
    parser.add_argument('--sample_name', type=str, default='bothsample0513', 
        help='Sub-directory to store a small amount of pre-processed samples.')
    parser.add_argument('--train_filename', type=str, default='train.json')
    parser.add_argument('--valid_filename', type=str, default='valid.json')
    parser.add_argument('--test_filename', type=str, default='test.json')

    # test utilities
    parser.add_argument('--test_count', type=int, default=None, 
        help="Maximum number of testing cases.")
    parser.add_argument('--operator', type=str, default=None, choices=operator_choices, 
        help='Specified single operator.')

    # process options
    parser.add_argument('--split_method', type=str, default='domain', 
        choices=['subsent', 'table', 'page', 'domain', 'website', 'load'])
    parser.add_argument('--link_method', type=str, default='formula', choices=['raw', 'formula'])
    parser.add_argument('--serial_method', type=str, default='concat', choices=['concat', 'pair'])

    # `args` is deliberately a module-level global: load_canada_dataset,
    # load_nsf_dataset and main read it directly instead of taking a parameter.
    args = parser.parse_args()

    # derived input paths
    args.annotated_dir = os.path.join(args.dataset_subdir, args.annotated_subdir)
    args.valid_dir = os.path.join(args.dataset_subdir, args.valid_subdir)
    args.html_dir = os.path.join(args.dataset_subdir, args.html_subdir)
    args.table_url_pairs_path = os.path.join(args.dataset_subdir, args.table_url_pairs_name)

    args.raw_nsf_dir = os.path.join(args.dataset_subdir, args.raw_nsf_subdir)
    args.new_nsf_dir = os.path.join(args.dataset_subdir, args.new_nsf_subdir)

    # derived output paths; exist_ok=True avoids the check-then-create race of
    # a separate os.path.exists() test
    args.dataset_dir, args.sample_dir = get_dataset_path(args)
    os.makedirs(args.dataset_dir, exist_ok=True)
    os.makedirs(args.sample_dir, exist_ok=True)
    print(f'Output Directory: \n{args.dataset_dir}\n{args.sample_dir}\n')

    main()
