"""Serialize the linked cells into a token sequence as the model inputs."""


# %% utility 

# Default serialization option for each table field.
OPTIONS = {
    'caption': 'title',   # linearize_caption also accepts 'all', 'summary', 'none'
    'top': 'match',       # linearize_header accepts 'match' or 'none'
    'left': 'match',      # linearize_header accepts 'match' or 'none'
    'corner': 'match',    # not read by any function visible in this file -- TODO confirm usage
    'data': 'match',      # linearize_data accepts 'match' or 'none'
}


def add_tag(text, tag, head=True, tail=True):
    """Surround the text with the tag string on either side.

    args:
        text: content to wrap
        tag: marker inserted verbatim (no angle brackets are added)
        head: the marker is prepended only when this compares equal to True
        tail: the marker is appended only when this compares equal to True
    rets:
        wrapped: '<prefix> <text> <suffix>' joined by single spaces; both
            spaces are kept even when a marker is left out
    """
    opening = f'{tag}' if head == True else ''
    closing = f'{tag}' if tail == True else ''
    return ' '.join([opening, f'{text}', closing])



# %% caption

import re

# Caption-title prefix of the form 'Table <digits>:' (trailing colon included).
title_head_pattern1 = re.compile('Table [0-9]+:')
# Same prefix without the trailing colon: 'Table <digits>'.
title_head_pattern2 = re.compile('Table [0-9]+')


def split_caption(caption):
    """Break a caption string into its (title, summary) pair.

    args:
        caption: str, either 'title\\nsummary' or a single line
    rets:
        (title, summary):
            - exactly two lines: both lines, each stripped
            - one line containing 'Table summary': the text before the
              marker and the text from the marker onwards, each stripped
            - anything else: the caption unchanged and an empty summary
    """
    lines = caption.split('\n')
    if len(lines) == 2:
        first, second = lines
        return first.strip(), second.strip()

    marker = 'Table summary'
    if len(lines) == 1 and marker in caption:
        cut = caption.index(marker)
        return caption[:cut].strip(), caption[cut:].strip()

    # No recognizable structure: the caption is returned as-is (not stripped).
    return caption, ''


# metadata caption
def linearize_caption(caption, option, tag='caption'):
    """Turn the table caption into a tagged string.

    args:
        caption: f'Table...\\nTable Summary...'
        option: choices = ['all', 'title', 'summary', 'none']
        tag: marker handed to add_tag for wrapping the result
    rets:
        caption_str: str, '' when option is 'none'
    raises:
        ValueError: when option is not one of the listed choices
    """
    title, summary = split_caption(caption)

    # Remove a leading 'Table N:' (or bare 'Table N') from the title. The
    # colon-terminated pattern is tried first so the colon is removed too.
    for head_pattern in (title_head_pattern1, title_head_pattern2):
        head_match = re.match(head_pattern, title)
        if head_match is not None:
            title = title[head_match.end():]
            break
    title = title.strip()

    if option == 'none':
        return ''

    if option == 'all':
        body = f'{title}\t{summary}'
    elif option == 'title':
        body = f'{title}'
    elif option == 'summary':
        body = f'{summary}'
    else:
        raise ValueError(f'[serialize >> linearize_caption] got unexpected option {option}.')

    return add_tag(body, tag)



# %% header

def linearize_header(header_cells, option, tag, sort_row_first=True):
    """Linearize a dict of header cells into an ordered list of cell strings.

    args:
        header_cells: {(irow, icol): cell}, each cell exposing `cell_string`
        option: choices = ['match', 'none']
        tag: not used by this function; kept so existing callers keep working
        sort_row_first: bool, True orders cells by (row, col),
            False orders them by (col, row)
    rets:
        header_values: List[str], empty when option is 'none'
    raises:
        ValueError: when option is neither 'match' nor 'none'
    """
    if option == 'none':
        return []
    if option != 'match':
        # The message must be an f-string, otherwise '{option}' is emitted
        # literally and the offending value is lost.
        raise ValueError(f'[serialize >> linearize_header] got unexpected option [{option}]')

    if sort_row_first:
        def sort_key(item):
            return (item[0][0], item[0][1])
    else:
        def sort_key(item):
            return (item[0][1], item[0][0])

    ordered = sorted(header_cells.items(), key=sort_key)
    return [hcell.cell_string for _, hcell in ordered]


# %% data

def linearize_data(data_cells, option, tag='data'):
    """Collect the strings of the data cells in row-major order.

    args:
        data_cells: {(irow, icol): data_cell}, each cell exposing `cell_string`
        option: choices = ['match', 'none']
        tag: not used by this function
    rets:
        data_values: List[str], ordered by (row, col); empty for 'none'
    raises:
        ValueError: when option is neither 'match' nor 'none'
    """
    if option == 'none':
        return []
    if option != 'match':
        raise ValueError(f'[serialize >> linearize_data] got unexpected option [{option}]')

    row_major = sorted(
        data_cells.items(),
        key=lambda item: (item[0][0], item[0][1]),
    )
    return [dcell.cell_string for _, dcell in row_major]


# %% operators and answers

def linearize_operators(operators, tag='operator'):
    """Join the operator names with spaces and wrap them via add_tag.

    args:
        operators: iterable of str, or None
        tag: marker handed to add_tag
    rets:
        op_str: str, '' when operators is None
    """
    if operators is None:
        return ''
    joined = ' '.join(operators)
    return add_tag(text=joined, tag=tag)


def linearize_answers(answer_texts, tag='answer'):
    """Join the answer texts with spaces and wrap them via add_tag.

    args:
        answer_texts: iterable of str, or None
        tag: marker handed to add_tag
    rets:
        ans_str: str, '' when answer_texts is None
    """
    return '' if answer_texts is None else add_tag(
        text=' '.join(answer_texts),
        tag=tag,
    )


def linearize_operation(operators, answers=None):
    """Serialize operators and answers into one space-separated string.

    args:
        operators: passed to linearize_operators
        answers: passed to linearize_answers
    rets:
        str: '<operator part> <answer part>'; the separating space is kept
            even when either part is empty
    """
    pieces = [linearize_operators(operators), linearize_answers(answers)]
    return f'{pieces[0]} {pieces[1]}'


def listify_fielded_inputs(fielded_inputs):
    """Flatten the field values into one list, dropping empty strings.

    args:
        fielded_inputs: {field_name: str | list}; values of any other type
            are ignored
    rets:
        serial_input_list: list, in field order, without '' items
    """
    collected = []
    for value in fielded_inputs.values():
        if isinstance(value, list):
            collected += value
        elif isinstance(value, str):
            collected.append(value)
    return [item for item in collected if item != '']



# %% row-wise traversal

def linearize_table(structure, linked_cells, 
    operators=None, options=OPTIONS, 
    return_type='list',   # ['dict', 'list', 'str']
):
    """[Version#1] linearization of linked cells.

    args:
        structure: dict, only structure['caption'] is read here
        linked_cells: dict with 'top', 'left' and 'data' entries,
            each {(irow, icol): cell}
        operators: not used by this version
        options: per-field option dict with 'caption', 'top', 'left', 'data'
        return_type: 'dict' returns the field dict, 'list' the flattened
            list of texts, any other value the ' <sep> '-joined string
    """
    # Corner cells and operators are not serialized in this version; the
    # 'formula' field is always empty.
    # NOTE(review): left headers are sorted with linearize_header's default
    # row-first order because sort_row_first is not passed -- confirm intended.
    table_fields = {
        'title': linearize_caption(
            caption=structure['caption'],
            option=options['caption'],
            tag='',
        ),
        'top': linearize_header(
            header_cells=linked_cells['top'],
            option=options['top'],
            tag='',
        ),
        'left': linearize_header(
            header_cells=linked_cells['left'],
            option=options['left'],
            tag='',
        ),
        'data': linearize_data(
            data_cells=linked_cells['data'],
            option=options['data'],
            tag='',
        ),
        'formula': '',
    }
    if return_type == 'dict':
        return table_fields

    flat_texts = listify_fielded_inputs(table_fields)
    if return_type == 'list':
        return flat_texts
    return ' <sep> '.join(flat_texts)





# %% data-centered phrase

def linearize_data_with_path(path_dict):
    """Serialize one data cell together with its top and left header paths.

    args:
        path_dict: {'top': [], 'left': [], 'data': 'Cell'}; header entries
            and the data cell expose `cell_string`, 'data' may be None
    rets:
        path_str: str, the non-empty parts (top, left, data) joined by ' ; ';
            header strings inside a part are joined by ' | '
    """
    segments = []

    for side in ('top', 'left'):
        header_cells = path_dict[side]
        if len(header_cells) == 0:
            continue
        joined = ' | '.join([cell.cell_string for cell in header_cells])
        # No tag is emitted; add_tag still pads the text with one space per side.
        segments.append(add_tag(text=joined, tag='', head=False, tail=False))

    data_cell = path_dict['data']
    if data_cell is not None:
        segments.append(
            add_tag(text=data_cell.cell_string, tag='', head=False, tail=False)
        )

    return ' ; '.join([segment for segment in segments if segment])


def linearize_data_paths(
    structure, path_dict_list, 
    operators=None, answers=None, 
    options=OPTIONS, 
    return_type='list',   # ['dict', 'list', 'str']
):
    """Serialize the caption, every data path and the operation string.

    args:
        structure: dict, only structure['caption'] is read here
        path_dict_list: list of path dicts accepted by linearize_data_with_path
        operators, answers: forwarded to linearize_operation
        options: per-field option dict, only options['caption'] is read here
        return_type: 'dict' returns the field dict, 'list' the flattened
            list of texts, any other value the ' <sep> '-joined string
    """
    path_fields = {
        'title': linearize_caption(
            caption=structure['caption'],
            option=options['caption'],
            tag='',
        ),
        'path': [linearize_data_with_path(entry) for entry in path_dict_list],
        'formula': linearize_operation(operators=operators, answers=answers),
    }
    if return_type == 'dict':
        return path_fields

    flat_texts = listify_fielded_inputs(path_fields)
    if return_type == 'list':
        return flat_texts
    return ' <sep> '.join(flat_texts)




# %% for the PARENT metrics

def clean_text(text):
    """Collapse every run of whitespace into a single blank space.

    args:
        text: str
    rets:
        cleaned_text: str, whitespace-free tokens joined by one space,
            with no leading or trailing whitespace
    """
    # str.split() without an argument already splits on tabs and newlines
    # (and any other whitespace), so no further splitting is required.
    return ' '.join(text.split())


def get_entry(attr, text):
    """Build one table-parent entry string of the form attr|||value.

    The text is whitespace-normalized with clean_text and every '|' in it is
    replaced by '-', so the value cannot contain the '|||' delimiter.
    """
    sanitized = clean_text(text).replace('|', '-')
    return f'{attr}|||{sanitized}'

def get_tuple(attr, text):
    """Build one table-parent entry as an (attr, value) tuple.

    The text is whitespace-normalized with clean_text and every '|' in it is
    replaced by '-'.
    """
    sanitized = clean_text(text).replace('|', '-')
    return attr, sanitized


def get_table_parent_str(linked_cells, structure):
    """Return the table as one string of attr|||value entries.

    args:
        linked_cells: {'corner', 'top', 'left', 'data'}; each value is either
            a {(irow, icol): cell} dict or a plain iterable of cells, where
            every cell exposes `cell_string`
        structure: dict, only structure['caption'] is read here
    rets:
        table_parent_str: '\\t'-separated entries, one per linked cell,
            followed by the 'title' and 'summary' entries
    """
    table_parent_array = []

    # linked headers and datas (corner, top, left, data)
    for attr, cells in linked_cells.items():
        # Iterating a dict directly yields its (irow, icol) keys, which have
        # no `cell_string`; take the values in that case. Plain iterables of
        # cells are still accepted unchanged.
        cell_iter = cells.values() if isinstance(cells, dict) else cells
        table_parent_array.extend(
            get_entry(attr, cell.cell_string) for cell in cell_iter
        )

    # caption and summary
    title, summary = split_caption(structure['caption'])
    table_parent_array.append(get_entry('title', title))
    table_parent_array.append(get_entry('summary', summary))

    return '\t'.join(table_parent_array)


def get_table_parent_list(linked_cells, structure):
    """Return the table as a list of (attribute, value) tuples.

    args:
        linked_cells: {'corner', 'top', 'left', 'data'}, each value a
            {(irow, icol): cell} dict whose cells expose `cell_string`
        structure: dict, only structure['caption'] is read here
    rets:
        table_parent_array: List[Tuple(attribute, value)], one tuple per
            linked cell, followed by the 'title' and 'summary' tuples
    """
    # linked headers and datas (corner, top, left, data)
    table_parent_array = [
        get_tuple(attr, cell.cell_string)
        for attr, cells_dict in linked_cells.items()
        for cell in cells_dict.values()
    ]

    # caption and summary
    title, summary = split_caption(structure['caption'])
    table_parent_array += [
        get_tuple('title', title),
        get_tuple('summary', summary),
    ]
    return table_parent_array

