"""Align Table and Sentence Annotations."""

import enum
import string
from nltk.stem.porter import PorterStemmer
from preprocess import text_utils, number_utils


class PieceTypes(enum.Enum):
    """Kind tag attached to each piece of a split text.

    TEXT tags a textual word; NUMBER tags a piece that was parsed as a number.
    """
    TEXT = 0
    NUMBER = 1

# English stop words, kept as a mutable set for O(1) membership tests.
STOP_WORDS = set(
    """
    i me my myself we our ours ourselves
    you your yours yourself yourselves
    he him his himself
    she her hers herself
    it its itself
    they them their theirs themselves
    what which who whom
    this that these those
    am is are was were be been being
    have has had having
    do does did doing
    a an the and but if or because as until while
    of at by for with about against between into through
    during before after above below
    to from up down in out on off over under
    again further then once here there when where why how
    all any both each few more most other some such
    no nor not only own same so than too very
    s t can will just don should now
    """.split()
)

# Region names used as the top-level keys of every linked-cells dict.
_LINK_KEYS = {'top', 'left', 'corner', 'data'}   # a set of str

# %% String Match Utility

class StringMatchUtil(object):
    """A tool for examining string-pair alignments.

    "Pieces" are the whitespace-separated tokens produced by `split_text`,
    "spans" are sequences of pieces, and "n-grams" are fixed-length slices
    of any sliceable sequence (plain strings in `compare_similarity`).
    """

    @staticmethod
    def all_stop_words(pieces):
        """Return True if every piece is a stop word or punctuation.

        The punctuation test is substring membership in
        `string.punctuation`. An empty `pieces` yields True.
        """
        return all(
            (p in STOP_WORDS or p in string.punctuation)
            for p in pieces
        )

    def contains_span(self, long_span, short_span):
        """Check if the longer span contains the shorter one contiguously.

        The arguments may be given in either order: when `long_span` is
        actually the shorter of the two, they are swapped.
        """
        nlong, nshort = len(long_span), len(short_span)
        if nlong < nshort:
            # Swap and retry (the original called a non-existent method here).
            return self.contains_span(short_span, long_span)

        return any(
            (short_span == long_span[i: i + nshort])
            for i in range(nlong - nshort + 1)
        )

    def contains_ngram(self, text, ngram):
        """Find if a piece of text contains the given ngram."""
        text_length, ngram_length = len(text), len(ngram)
        return any(
            ngram == text[i: i + ngram_length]
            for i in range(text_length - ngram_length + 1)
        )

    @staticmethod
    def _stem_pieces(pieces, types, stemmer):
        """Return the set of pieces, stemming only those tagged as TEXT."""
        return {
            stemmer.stem(piece) if kind == PieceTypes.TEXT else piece
            for piece, kind in zip(pieces, types)
        }

    def compute_span_similarity(self, span_a, types_a, span_b, types_b, epsilon=1e-6, stemmer=PorterStemmer()):
        """Compute the similarity score between two spans, using Jaccard coefficient.

        Args:
            span_a, span_b: sequences of pieces.
            types_a, types_b: PieceTypes tags aligned with the spans.
            epsilon: added to the denominator so two empty spans give 0.0
                instead of dividing by zero.
            stemmer: object with a `stem(str)` method applied to TEXT
                pieces; pass None to compare the raw pieces.
        Returns:
            float: |A & B| / (|A | B| + epsilon).
        """
        if stemmer is not None:
            set_a = self._stem_pieces(span_a, types_a, stemmer)
            set_b = self._stem_pieces(span_b, types_b, stemmer)
        else:
            set_a = set(span_a)
            set_b = set(span_b)

        nominator = set_a & set_b
        denominator = set_a | set_b
        return len(nominator) / (len(denominator) + epsilon)

    def get_all_ngrmas(self, text, ngram_length):
        """Return every contiguous slice of `text` of length `ngram_length`.

        The misspelled name is kept for existing callers; `get_all_ngrams`
        is an alias.
        """
        text_length = len(text)
        return [
            text[i: i + ngram_length]
            for i in range(text_length - ngram_length + 1)
        ]

    # Correctly spelled alias of `get_all_ngrmas`.
    get_all_ngrams = get_all_ngrmas

    def compute_ngram_similarity(self, str_a, str_b, ngram_length, epsilon=1e-6):
        """Return the fraction of `str_b`'s n-grams that occur in `str_a`.

        `epsilon` keeps the result at 0.0 when `str_b` is shorter than
        `ngram_length` and therefore has no n-grams.
        """
        scores = [
            int(self.contains_ngram(str_a, ngram))
            for ngram in self.get_all_ngrmas(str_b, ngram_length)
        ]
        return float(sum(scores)) / (len(scores) + epsilon)

    def split_text(self, text, 
        text_normalize_fn=text_utils.normalize_text, 
        number_normalize_fn=number_utils.normalize_number, 
        return_exact=False
    ):
        """Split a piece of text into a listed mixture of textual words and numeric entities (numbers). 
        
        Args:
            text: str or None (None is treated as the empty string), e.g. 
                "As a consequence of the Chinese ban on Canadian pork, 
                Canada had only a 0.55% increase in total Canadian exports between 2018 and 2019. "
            text_normalize_fn: function to normalize a piece of text into a fairly-comparable format.
            number_normalize_fn: function to normalize a numeric value.
            return_exact: when truthy, also return `str_n_num`.
        Returns:
            pieces: list[str], the original string for numbers and the
                normalized string for words.
            types: List[PieceTypes]
            str_n_num: List[Union[str, float/int]], only when `return_exact`;
                the normalized number for numeric pieces and the normalized
                text for textual pieces.
        """
        if text is None:
            text = ''
        pieces, types, str_n_num = [], [], []

        # str.split() without a separator already trims and drops empty parts.
        for pstr in text.split():
            pnum = number_utils.parse_number(pstr)
            if pnum is not None:
                pieces.append(pstr)
                types.append(PieceTypes.NUMBER)
                str_n_num.append(number_normalize_fn(pnum))
            else:
                normalized = text_normalize_fn(pstr)
                pieces.append(normalized)
                types.append(PieceTypes.TEXT)
                # A textual piece has no parsed number, so record its
                # normalized text instead of number_normalize_fn(None).
                str_n_num.append(normalized)

        if return_exact:
            return pieces, types, str_n_num
        return pieces, types

    def compare_similarity(self, text_a, text_b, ngram_length=3, epsilon=1e-6):
        """Compute span and ngram similarity scores.

        Returns:
            (span_score, ngram_score): the Jaccard score over the split
            pieces, and the n-gram score over the space-joined pieces.
        """
        pieces_a, types_a = self.split_text(text_a)
        pieces_b, types_b = self.split_text(text_b)
        span_score = self.compute_span_similarity(pieces_a, types_a, pieces_b, types_b, epsilon)
    
        str_a = ' '.join(pieces_a)
        str_b = ' '.join(pieces_b)
        ngram_score = self.compute_ngram_similarity(str_a, str_b, ngram_length, epsilon)

        return span_score, ngram_score

# Shared matcher instance used by the search_* helpers in this module.
aligner = StringMatchUtil()

# Module-level Porter stemmer instance.
stemmer = PorterStemmer()

def find_corecell(cell_matrix, irow, icol):
    """Return the core cell of the span that covers position (irow, icol).

    The cell at (irow, icol) names its span's starting row and column in
    `span`; the cell stored at that starting position is returned.

    Raises:
        AssertionError: if the cell at the span start is not flagged
            `iscorecell`.
    """
    span = cell_matrix[irow][icol].span
    top, left = span['start_row'], span['start_col']

    anchor = cell_matrix[top][left]
    assert anchor.iscorecell
    return anchor


# %% Basic Search and Align

def search_corner(cell_matrix, item, nrows, ncols, similarity_threshold=0.9):
    """Traverse the top-left corner sub-matrix to find text-matching cells.
    Note: 'item' may be of str/int/float types; non-str items are compared
    through their `str()` form.

    Args:
        cell_matrix: 2-D list of cells.
        item: value to look for.
        nrows, ncols: size of the corner region, scanned from (0, 0).
        similarity_threshold: a cell matches when both its span and n-gram
            similarity scores are strictly above this value.

    Returns: {(irow,icol): str}, keyed by the scanned position and holding
        the matching core cell's string.
    """

    linked_corners = {}

    if not isinstance(item, str):
        item = str(item)

    for irow in range(nrows):
        for icol in range(ncols):
            cell = cell_matrix[irow][icol]
            if not cell.iscorecell:
                # Resolve a merged-away position to the cell that owns its span.
                cell = find_corecell(cell_matrix, irow, icol)

            span_score, ngram_score = aligner.compare_similarity(
                cell.cell_string, item)
            if min(span_score, ngram_score) > similarity_threshold:
                linked_corners[(irow, icol)] = cell.cell_string

    return linked_corners


def search_tree(node, item, similarity_threshold=0.9):
    """Traverse the tree rooted at `node` to find text-matching nodes.
    Note: 'item' may be of str/int/float types; non-str items are compared
    through their `str()` form.

    Args:
        node: tree node exposing `cell_string`, `span` and `child_cells`.
        item: value to look for.
        similarity_threshold: a node matches when both its span and n-gram
            similarity scores are strictly above this value; the same
            threshold is applied to every descendant.

    Returns: {(start_row, start_col): node} for every matching node.
    """

    tree_dict = {}

    # NOTE(review): the original called find_corecell(node) here, which
    # cannot work without a cell matrix; the node is used as given.
    if not isinstance(item, str):
        item = str(item)

    span_score, ngram_score = aligner.compare_similarity(
        node.cell_string, item)
    if min(span_score, ngram_score) > similarity_threshold:
        irow = node.span['start_row']
        icol = node.span['start_col']
        tree_dict[(irow, icol)] = node

    for child in node.child_cells:
        # Pass the threshold down so descendants are judged like the root.
        tree_dict.update(search_tree(child, item, similarity_threshold))

    return tree_dict


def search_data(cell_matrix, item, srow, scol, similarity_threshold=0.9): 
    """Traverse the data region to find matching data cells.

    Matching depends on the type of `item`:
      * float: the cell string parses as a float equal to `item` after
        both are rounded to 2 decimals;
      * int: the cell string parses as an int equal to `item`;
      * str: both similarity scores are strictly above the threshold.
    Cells whose string cannot be parsed as the required number are skipped.

    Args:
        cell_matrix: 2-D list of cells; only core cells are examined.
        item: str, int, float
        srow: starting row index of the data region
        scol: starting column index of the data region
        similarity_threshold: used for str items only.

    Returns: {(irow, icol): str} mapping each match to its cell string.

    Raises:
        ValueError: if `item` has an unsupported type and a core cell is
            reached in the data region.
    """

    data_dict = {}

    if not cell_matrix:
        # No rows at all: nothing to scan (and no first row to measure).
        return data_dict

    mrow = len(cell_matrix)
    mcol = len(cell_matrix[0])

    for irow in range(srow, mrow):
        for icol in range(scol, mcol): 
            cell = cell_matrix[irow][icol]
            if not cell.iscorecell:
                continue
            cellstr = cell.cell_string
            assert isinstance(cellstr, str)

            if isinstance(item, float):
                try:
                    cell_float = float(cellstr)
                except ValueError:
                    continue
                # Round both sides to the same precision before comparing.
                if round(item, 2) == round(cell_float, 2):
                    data_dict[(irow, icol)] = cellstr
            elif isinstance(item, int):
                try:
                    cell_int = int(cellstr)
                except ValueError:
                    continue
                if item == cell_int:
                    data_dict[(irow, icol)] = cellstr
            elif isinstance(item, str):
                span_score, ngram_score = aligner.compare_similarity(item, cellstr)
                if min(span_score, ngram_score) > similarity_threshold:
                    data_dict[(irow, icol)] = cellstr
            else:
                raise ValueError(f"Linked Item {item} has Unexpected Data Type {type(item)}...")
    
    return data_dict
                


# %% Link by Reading Annotations

def update_linked_dict(input_dict, linked_cells,
                       cell_matrix, hdr_nrows, hdr_ncols):
    """Add the cells referenced by 'input_dict' to 'linked_cells'.

    Each ((irow, icol), item) entry is resolved to the core cell of its
    span and filed under the region the position falls in:

        row <  hdr_nrows, col <  hdr_ncols  -> 'corner'
        row <  hdr_nrows, col >= hdr_ncols  -> 'top'
        row >= hdr_nrows, col <  hdr_ncols  -> 'left'
        row >= hdr_nrows, col >= hdr_ncols  -> 'data'

    Entries whose item is None, or whose position lies beyond the matrix
    bounds, are ignored. Stored keys are the core cell's
    (start_row, start_col).

    Returns:
        The same 'linked_cells' object, updated in place.
    """
    total_rows, total_cols = len(cell_matrix), len(cell_matrix[0])

    for (irow, icol), item in input_dict.items():
        if item is None or irow >= total_rows or icol >= total_cols:
            continue

        # The position alone decides the region.
        if irow < hdr_nrows:
            region = 'corner' if icol < hdr_ncols else 'top'
        else:
            region = 'left' if icol < hdr_ncols else 'data'

        core = find_corecell(cell_matrix, irow, icol)
        anchor = (core.span['start_row'], core.span['start_col'])
        linked_cells[region].update({anchor: core})

    return linked_cells


def link_schema(structure, subdict, hdr_nrows, hdr_ncols):
    """Collect the cells named by the schema-link annotation.

    `subdict['schema_link']` maps each phrase to {(irow, icol): item};
    every phrase's positions are resolved against
    `structure['cell_matrix']` and gathered into one linked-cells dict.
    """
    linked = {region: {} for region in _LINK_KEYS}
    matrix = structure['cell_matrix']

    # The phrases themselves are not needed, only their position dicts.
    for positions in subdict['schema_link'].values():
        linked = update_linked_dict(
            positions, linked, matrix, hdr_nrows, hdr_ncols)

    return linked


def link_answer(structure, subdict, hdr_nrows, hdr_ncols):
    """Collect the cells listed under `subdict['answer_cells']`.

    The positions are resolved against `structure['cell_matrix']` and
    returned as a fresh linked-cells dict.

    @USER TODO: fix 'top' and 'left' positions, from (irow, icol) to (ilevel, index).
    """
    fresh_links = {region: {} for region in _LINK_KEYS}
    return update_linked_dict(
        subdict['answer_cells'],
        fresh_links,
        structure['cell_matrix'],
        hdr_nrows,
        hdr_ncols,
    )


def merge_linked_cells(linked_cells_list):
    """Fold several linked-cells dicts into a single new one.

    Entries sharing a position within a region collapse to one; the value
    from the later dict in the list wins.
    """
    combined = {region: {} for region in _LINK_KEYS}

    for partial in linked_cells_list:
        for region in partial:
            combined[region].update(partial[region])

    return combined



# %% Cross-Linking between Header and Data

# Add New Header Nodes by Reading Data

def match_node_to_cell(node, cell):
    """Depth-first search under `node` for the core node spanning `cell`.

    Returns the first node (in pre-order) whose span equals `cell.span`
    and which is flagged `iscorecell`, or None if no such node exists.
    """
    if (node.span == cell.span) and node.iscorecell:
        return node

    # Lazily probe each subtree; stop at the first one that yields a hit.
    subtree_hits = (
        match_node_to_cell(kid, cell) for kid in node.child_cells
    )
    return next((hit for hit in subtree_hits if hit is not None), None)


def map_cell_to_leaf(root, cell):
    """Find the tree node matching `cell` under `root['virtual_root']`.

    Each node listed under 'virtual_root' is searched in order; the first
    match is returned, or None when no subtree contains one.
    """
    subtree_hits = (
        match_node_to_cell(subtree, cell) for subtree in root['virtual_root']
    )
    return next((hit for hit in subtree_hits if hit is not None), None)


def _link_header_leaf(selected, header_tree, header_cell):
    """Add the tree node matching `header_cell` to `selected`, if any.

    Existing entries are left untouched.
    """
    leaf = map_cell_to_leaf(header_tree, header_cell)
    if leaf is not None:
        key = (leaf.span['start_row'], leaf.span['start_col'])
        selected.setdefault(key, leaf)


def update_header(linked_cells, structure, hdr_nrows, hdr_ncols):
    """Retrieve the indexing headers based on the linked data cells.

    For every linked data cell, the header cell in the last header row of
    the same column and the header cell in the last header column of the
    same row are mapped to their header-tree nodes and added to
    'top' / 'left' when not already selected.

    Args:
        linked_cells: {
            'corner': {(irow, icol): cell}, 'data': {(irow, icol): cell}, 
            'top': {(irow, icol): cell}, 'left': {(irow, icol): cell}
        }
        structure: dict with 'cell_matrix', 'top_header', 'left_header'.
        hdr_nrows: number of top header rows.
        hdr_ncols: number of left header columns.
    Returns: 
        ('top'&'left' updated) linked_cells: {}
    """

    cell_matrix = structure['cell_matrix']

    for (irow, icol) in linked_cells['data']:
        # Data keys are matrix indices, so the indexing headers share the
        # data cell's own column / row (no -1 shift).
        if hdr_nrows > 0:
            top_cell = find_corecell(cell_matrix, hdr_nrows - 1, icol)
            _link_header_leaf(
                linked_cells['top'], structure['top_header'], top_cell)

        if hdr_ncols > 0:
            left_cell = find_corecell(cell_matrix, irow, hdr_ncols - 1)
            _link_header_leaf(
                linked_cells['left'], structure['left_header'], left_cell)

    return linked_cells 






def expand_header(header_dict, hdr_nrows, hdr_ncols):
    """Expand the selected header region to a connected region.
    For leaf nodes, make sure all of their ascendents are selected.
    For non-leaf nodes, NOT selecting their descendents FOR NOW.

    A node counts as a leaf when it starts in the last header row
    (`hdr_nrows - 1`, top header) or in the last header column
    (`hdr_ncols - 1`, left header).

    Args:
        header_dict: {(irow, icol): node, ...}; each node exposes `span`
            and `parent_cell`, whose first element is the parent node or
            the string 'virtual_root'.
        hdr_nrows: number of top header rows.
        hdr_ncols: number of left header columns.
    Return: 
        (updated) header_dict, modified in place.
    """

    added_nodes = {}

    for (irow, icol), node in header_dict.items():
        is_leaf = (irow == hdr_nrows - 1) or (icol == hdr_ncols - 1)
        if not is_leaf:
            # Non-leaf nodes contribute only themselves (already present).
            continue

        # Walk up until the sentinel that marks the top of the tree.
        pnode = node.parent_cell[0]
        while pnode != 'virtual_root':
            prow = pnode.span['start_row']
            pcol = pnode.span['start_col']
            added_nodes[(prow, pcol)] = pnode
            pnode = pnode.parent_cell[0]

    # Applied after the loop so the dict is not resized while iterating.
    header_dict.update(added_nodes)
    return header_dict


# Add New Data Cells by Reading Headers

def map_node_to_cell(node, cell_matrix):
    """Look up the matrix cell sitting at a tree node's span start.

    Returns:
        (cell, (irow, icol)): the matrix cell and its position.

    Raises:
        AssertionError: if that cell is not a core cell, or if its
            `cell_string` differs from the node's.
    """
    position = (node.span['start_row'], node.span['start_col'])
    row, col = position
    cell = cell_matrix[row][col]

    assert cell.iscorecell
    assert node.cell_string == cell.cell_string

    return cell, position


def update_data(linked_cells, structure):
    """Add the data cells cross-indexed by the selected headers.

    Every selected 'left' node contributes its start row and every
    selected 'top' node its start column; the cell at each (row, column)
    pair is added to 'data' unless it is already there or its
    `cell_string` is empty.

    Args:
        linked_cells: {
            'top': {key: node}, 'left': {key: node},
            'corner': {(irow, icol): cell}, 'data': {(irow, icol): cell}
        }
        structure: dict providing 'cell_matrix'.
    Returns:
        ('data' updated) linked_cells, modified in place.

    Notes: added cells may fall into the top-left corner region.
    """
    matrix = structure['cell_matrix']

    # Position is the second item returned; column for top, row for left.
    top_icols = [
        map_node_to_cell(top_node, matrix)[1][1]
        for top_node in linked_cells['top'].values()
    ]
    left_irows = [
        map_node_to_cell(left_node, matrix)[1][0]
        for left_node in linked_cells['left'].values()
    ]

    selected = linked_cells['data']
    for row in left_irows:
        for col in top_icols:
            if (row, col) in selected:
                continue
            candidate = matrix[row][col]
            if candidate.cell_string:
                selected[(row, col)] = candidate

    return linked_cells




# %% Link a Sub-Sentence Dict to Table Cell Dicts

def link_table_cells(structure, subdict):
    """Select the table cells relevant to one sub-sentence annotation.

    Steps:
        1. cells named by the schema-link annotation (`link_schema`);
        2. cells named as answer cells (`link_answer`);
        3. both results merged into one dict (`merge_linked_cells`).

    The header/data cross-linking helpers (`update_header`,
    `expand_header`, `update_data`) are NOT applied here yet.

    Args:
        structure: dict providing 'num_top_header_rows',
            'num_left_header_cols' and whatever the link_* helpers read.
        subdict: one sub-sentence annotation dict.
    Returns:
        linked_cells_dict: {
            'top'/'left'/'corner'/'data': {(irow, icol): cell}
        }
    """
    hdr_nrows = structure['num_top_header_rows']
    hdr_ncols = structure['num_left_header_cols']

    # One linked-cells dict per annotation source, merged at the end.
    per_source = [
        link_schema(structure, subdict, hdr_nrows, hdr_ncols),
        link_answer(structure, subdict, hdr_nrows, hdr_ncols),
    ]
    return merge_linked_cells(per_source)


# TODO maybe?
def limit_linked_shape(linked_dict):
    """Placeholder: check the linked sub-table shape against the aggregation type.

    Not implemented; always returns None. Intended rules:
        UNARY (none, opp): single data cell
        BINARY (diff, div): a pair of cells
        AGGR (sum, avg; count; max, min): 1. strict: a row/col of cells, 2. loose: multiple
        ARG argrank(headerS, dataS, return_index)
    """
    return None

# %% General Pipeline

# generate unique id
def get_subsent_id(table_id, sent_desc_id):
    """Join a table id and a sub-sentence id as '<table_id>-<sent_desc_id>'."""
    return "{}-{}".format(table_id, sent_desc_id)

def get_subsent_desc_id(sent_id, isub):
    """Build '<sent_id>-<k>' where k is the 1-based form of index `isub`."""
    ordinal = isub + 1
    return "{}-{}".format(sent_id, ordinal)



def align_table_annotation(table_id, structure, sentences):
    """Link the table to every sub-sentence annotation, one at a time.

    Args:
        table_id: identifier used as the prefix of each yielded id.
        structure: table structure dict, passed to `link_table_cells`.
        sentences: iterable of sentence dicts; each is read for
            'table_desc_sent', 'table_desc_sent_id' and 'sub_samples'
            (a list of sub-sentence annotation dicts).

    Yields:
        (table_subsent_id, linked_dict) per sub-sentence, where
        table_subsent_id is '<table_id>-<sent_id>-<1-based sub index>' and
        linked_dict is {
            'structure': structure,
            'subdict': the sub-sentence dict,
            'linked_cells': result of `link_table_cells`,
        }

    Raises:
        AssertionError: if a sub-sample's 'sub_sent_comp' is None.
    """
    for sent_record in sentences:
        # Unused, but the lookup is kept so a record without it fails here.
        _sentence_text = sent_record['table_desc_sent']
        sent_id = sent_record['table_desc_sent_id']

        for isub, subdict in enumerate(sent_record['sub_samples']):
            # NOTE(review): confirm this key spelling against the annotation data.
            assert subdict['sub_sent_comp'] is not None

            linked_cells = link_table_cells(structure, subdict)

            desc_id = get_subsent_desc_id(sent_id, isub)
            table_subsent_id = get_subsent_id(
                table_id=table_id, sent_desc_id=desc_id)

            yield table_subsent_id, {
                'structure': structure,
                'subdict': subdict,
                'linked_cells': linked_cells,
            }
