"""Read and Parse each sub-sentence as a sub-block."""

import re

from typing import Tuple
from openpyxl.formula import Tokenizer
import Levenshtein

from processing import _OPERATORS
from processing.annotation import (
    _SUB_SENT_COMP, _SUB_SENT_AFTER, _KEY_PART, _SCHEMA_LINKING_PHRASES, 
    _SCHEMA_LINKING_POSITIONS, _QUESTION_REWRITE, _ANSWER, _AGGREGATION_TYPE,
    _COLOR_WHITE
)
from processing.annotation.block import strip_block
from processing.message_utils import (AnnoMessage, _LOSS, _WRONG, 
    _ANNO_ANSWER, _ANNO_SENTENCE, _ANNO_KEYPART, 
)
from processing.position_utils import LinkedPosition, stretch_range


# %% Main 

def get_sub_samples(block_start, worksheet, worksheet_literal, operator=None):
    """Parse every sub-sentence block below a block header and attach the results.

    The sub-sentence region is taken to start two rows below
    ``block_start['start_row']`` and to end at ``block_start['end_row']``.

    Args:
        block_start: dict with at least 'start_row' and 'end_row'; it is
            updated in place with 'sub_samples' and 'sub_messages'.
        worksheet: worksheet read through ``cell(row, col)``.
        worksheet_literal: second view of the same sheet, forwarded to the parser.
        operator: optional aggregation operator forwarded to `parse_sub_sample`.

    Returns:
        The same `block_start` dict. Both added lists have one entry per
        detected sub-block; entries may be None.
    """
    first_row = block_start['start_row'] + 2
    last_row = block_start['end_row']
    sub_blocks = get_subsent_blocks(block=(first_row, last_row), worksheet=worksheet)

    # Only the first row of each sub-block is needed by the parser.
    parsed = [
        parse_sub_sample(sub_start, worksheet, worksheet_literal, operator=operator)
        for sub_start, _ in sub_blocks
    ]

    block_start.update(
        sub_samples=[sample for sample, _ in parsed],
        sub_messages=[message for _, message in parsed],
    )
    return block_start



# %% 1. split a block into several sub-blocks
def assert_subblock(block, worksheet):
    """Check that a row range has the layout of a sub-sentence block.

    A valid block spans exactly eight rows, whose first-column labels are (in
    order) the sub-sentence, sub-sentence-after, key-part, schema-linking
    phrases, schema-linking positions, question-rewrite, answer and
    aggregation-type labels, and none of the cells in its first two columns
    has a fill colour other than `_COLOR_WHITE`.

    Args:
        block: (first_row, last_row) pair, inclusive.
        worksheet: worksheet read through ``cell(row, col)``.

    Returns:
        True when the range matches the expected layout, False otherwise.
    """
    s, e = block

    if s + 7 != e:
        return False

    expected_labels = (
        _SUB_SENT_COMP, _SUB_SENT_AFTER, _KEY_PART, _SCHEMA_LINKING_PHRASES,
        _SCHEMA_LINKING_POSITIONS, _QUESTION_REWRITE, _ANSWER, _AGGREGATION_TYPE,
    )
    for offset, label in enumerate(expected_labels):
        value = worksheet.cell(s + offset, 1).value
        # Convert before stripping: an empty (None) or numeric label cell must
        # fail the format check rather than raise AttributeError.
        if str(value).strip() != label:
            return False

    # Any non-white fill in the first two columns disqualifies the block.
    for offset in range(8):
        for col in (1, 2):
            if worksheet.cell(s + offset, col).fill.fgColor.rgb != _COLOR_WHITE:
                return False

    return True
        

def get_subsent_blocks(block: Tuple[int, int], worksheet):
    """Get the list of sub-sentence ranges of the same desc sentence.

    The row range is cut at every row (at least 7 rows after the current
    sub-block start) whose first-column value equals `_SUB_SENT_COMP`. Each
    candidate range is then stripped with `strip_block` and kept only when it
    passes `assert_subblock`.

    Args:
        block: (start_row, end_row) of the region to split.
        worksheet: worksheet read through ``cell(row, col)``.

    Returns:
        List of (first_row, last_row) pairs for the valid sub-blocks.
    """
    sub_blocks = []

    start, end = block
    s = start

    while s < end:
        # The next header cannot appear earlier than 7 rows below this one.
        e = s + 7
        while e < end and worksheet.cell(e, 1).value != _SUB_SENT_COMP:
            e += 1

        if e >= end:
            # No further header. This also covers a trailing region shorter
            # than a full sub-block (s + 7 > end), which previously left the
            # loop spinning forever because neither `s` nor `e` could change.
            sub_blocks.append((s, end))
            break

        sub_blocks.append((s, e - 1))
        s = e

    stripped_sub_blocks = []
    for sb in sub_blocks:
        ssb = strip_block(sb, worksheet)
        if ssb is None:
            continue
        if not assert_subblock(ssb, worksheet):
            continue
        stripped_sub_blocks.append(ssb)

    return stripped_sub_blocks


# %% 2. parse

def parse_sub_sample(start, worksheet, worksheet_literal, 
    operator=None, 
    schema_offset=3, answer_offset=6, aggregation_offset=7, 
):
    """Read the annotations of one sub-sentence block.

    Args:
        start: first row of the sub-block.
        worksheet, worksheet_literal: the two worksheet views passed to the
            extraction helpers.
        operator: when given, the sample is dropped unless this operator is
            among the extracted aggregation types.
        schema_offset, answer_offset, aggregation_offset: row offsets (from
            `start`) of the schema-link, answer and aggregation rows.

    Returns:
        (sub_dict, msg). `sub_dict` is None when the basic sentence/key-part
        checks fail, when no answer cell is found, or when the operator filter
        rejects the sample. `msg` carries the message describing a failed
        check, and is otherwise whatever `start_subdict` returned.
    """
    sub_dict, msg = start_subdict(start, worksheet, worksheet_literal)
    if sub_dict is None and msg is not None:
        return None, msg

    links = extract_schema_links(start + schema_offset, worksheet, worksheet_literal)

    cells, texts = extract_answer(start + answer_offset, worksheet, worksheet_literal)
    if not cells:
        missing_answer = AnnoMessage(
            msg_type=_LOSS, msg_object=_ANNO_ANSWER, sub_block_start=start)
        return None, missing_answer

    aggr_types = extract_aggregation(start + aggregation_offset, worksheet)
    filtered_out = operator is not None and operator not in aggr_types
    if filtered_out:
        return None, msg

    sub_dict.update(
        schema_link=links,
        answer_cells=cells,
        answer=texts,
        aggregation=aggr_types,
    )
    return sub_dict, msg



# %% 2.1 Start a sub-block dict with sentence and key part (normalized)

def restore_key_part(key_part_cell):
    """Restore the key-part text as written, given the worksheet cell.

    A float stored under a percent number format is converted back to its
    percent spelling; every other value is returned as ``str(value)``.

    * format ``'0%'``    -> rounded to a whole percent, e.g. 0.19   -> '19%'
    * format ``'0.00%'`` -> shortest exact percent with at least one decimal,
      e.g. 0.19 -> '19.0%', 0.1234 -> '12.34%', 0.5 -> '50.0%'
    * any other format   -> unchanged string of the value

    Args:
        key_part_cell: object exposing ``value`` and ``number_format``.

    Returns:
        The restored text, or None when the cell is empty.
    """
    from decimal import Decimal

    key_value = key_part_cell.value
    if key_value is None:
        return None
    key_str = str(key_value)

    number_format = key_part_cell.number_format
    if not (isinstance(key_value, float) and number_format.endswith('%')):
        return key_str

    if number_format == '0%':
        return f'{(key_value * 100):.0f}%'

    if number_format == '0.00%':
        # Shift by two decimal places on the decimal repr of the float. This is
        # exact, keeps sign and integer part, and copes with exponent notation
        # such as '1e-05' (splitting the string on '.' did none of these).
        percent = Decimal(key_str) * 100
        if not percent.is_finite():   # nan / inf: nothing sensible to restore
            return key_str

        text = format(percent, 'f')   # plain notation, never an exponent
        if '.' in text:
            text = text.rstrip('0')
            if text.endswith('.'):
                text += '0'
        else:
            text += '.0'
        return f'{text}%'

    return key_str


def check_key_part(key_part, sub_sent_list):
    """Locate the key part inside the first sentence that contains it.

    Matching is case-insensitive. When the key part consists of digits only,
    a comma-grouped spelling (e.g. '1234567' -> '1,234,567') is tried as a
    second chance on each sentence.

    Args:
        key_part: annotated key-part text, or None.
        sub_sent_list: candidate sentences; non-string entries are skipped.

    Returns:
        (index, sentence, matched_text) for the first hit, where `index` is
        the position in the lower-cased sentence and `matched_text` is either
        the original key part or its comma-grouped form.
        (-1, None, None) when nothing matches or `key_part` is None.
    """
    not_found = (-1, None, None)
    if key_part is None:
        return not_found

    needle = key_part.lower()

    # Comma-grouped variant, only meaningful for pure digit strings.
    grouped = None
    if all(ch.isdigit() for ch in needle):
        head_len = len(needle) % 3 or 3
        chunks = [needle[:head_len]]
        chunks += [needle[i:i + 3] for i in range(head_len, len(needle), 3)]
        grouped = ','.join(chunks)

    for sentence in sub_sent_list:
        if not isinstance(sentence, str):
            continue
        haystack = sentence.lower()

        position = haystack.find(needle)
        if position != -1:
            return position, sentence, key_part

        if grouped is not None:
            position = haystack.find(grouped)
            if position != -1:
                return position, sentence, grouped

    return not_found


def start_subdict(start, worksheet, worksheet_literal):
    """Build the base dict of a sub-sentence sample, or report why it is unusable.

    Column 2 is read at rows `start` (sentence), `start + 1` (sentence after),
    `start + 2` (key part) and `start + 5` (question).

    Returns:
        (subdict, None) on success, where `subdict` holds 'start_row',
        'sub_sent', 'key_part', 'question' and 'key_index'.
        (None, AnnoMessage) when both sentences and the question are all
        missing (_LOSS / _ANNO_SENTENCE), or when the key part cannot be found
        in either sentence (_WRONG / _ANNO_KEYPART).
    """
    comp_sentence, after_sentence, question = (
        worksheet.cell(start + offset, 2).value for offset in (0, 1, 5)
    )
    if all(text is None for text in (comp_sentence, after_sentence, question)):
        return None, AnnoMessage(
            msg_type=_LOSS,
            msg_object=_ANNO_SENTENCE,
            sub_block_start=start,
        )

    restored = restore_key_part(worksheet.cell(start + 2, 2))
    key_index, sub_sent, key_part = check_key_part(
        restored, sub_sent_list=[comp_sentence, after_sentence])
    if sub_sent is None:
        return None, AnnoMessage(
            msg_type=_WRONG,
            msg_object=_ANNO_KEYPART,
            sub_block_start=start,
        )

    return {
        'start_row': start,        # position within the worksheet
        'sub_sent': sub_sent,      # sentence in which the key part was found
        'key_part': key_part,
        'question': question,      # copied as-is, may be None
        'key_index': key_index,    # offset of the key part in the sentence
    }, None



# %% 2.2 Schema Link

def extract_linked_cell(phrase, position, worksheet):
    """Extract {coordinate: item} pairs from one 'schema linking positions' cell.

    `position` is expected to be a comma-separated list of cell references,
    optionally prefixed with '='. Each reference is looked up in `worksheet`
    to get its value. If the lookup fails for any reason, the whole position
    is kept as a single entry paired with `phrase`.

    Args:
        phrase: the phrase annotated for this column (fallback item).
        position: raw value of the positions cell.
        worksheet: worksheet supporting ``worksheet[ref].value``.

    Returns:
        Dict mapping `LinkedPosition.cm_coords` to the linked item. Positions
        that cannot be parsed, or whose first worksheet coordinate is below 3,
        are skipped.
    """
    if isinstance(position, str) and position.startswith('='):
        position = position[1:]

    try:
        positions = position.split(',')
        items = [worksheet[pos].value for pos in positions]
    except Exception:
        # Best effort: non-string positions or unreadable references fall back
        # to the annotated phrase. `Exception` (not a bare except) so that
        # KeyboardInterrupt / SystemExit still propagate.
        items = [phrase]
        positions = [position]

    phrase_link = {}
    for pos, itm in zip(positions, items):
        try:
            linked_pos = LinkedPosition.from_ws_str(pos)
            if linked_pos.ws_coords[0] < 3:
                continue
            phrase_link[linked_pos.cm_coords] = itm
        except Exception:
            # Unparseable position strings are ignored.
            continue

    return phrase_link


def extract_schema_links(row_index, worksheet, worksheet_literal):
    """Collect schema links from the phrase row and the position row below it.

    Columns are scanned from column 2 rightwards until a column has an empty
    position cell. A column with an empty phrase cell is attributed to the
    most recent non-empty phrase (initially the empty string).

    Returns:
        Dict mapping each phrase to the merged dict produced by
        `extract_linked_cell` for its columns.
    """
    links_by_phrase = {}
    current_phrase = ''
    column = 2

    while True:
        phrase = worksheet.cell(row_index, column).value
        position = worksheet.cell(row_index + 1, column).value
        if not position:
            break

        if phrase:
            current_phrase = phrase

        bucket = links_by_phrase.setdefault(current_phrase, {})
        bucket.update(extract_linked_cell(phrase, position, worksheet))

        column += 1

    return links_by_phrase



# %% 2.3 Answer Region

# One or more letters followed by one or more digits, e.g. 'A1' or 'AB12'.
_REGION = re.compile('[a-zA-Z]+[0-9]+')

def split_formula_items(formula):   # deprecated
    """Pull every cell-like reference out of a formula string.

    e.g. ' A1, B1:B2, C1-D1' -> ['=A1', '=B1', '=B2', '=C1', '=D1'].
    Notes: the input is expected without a leading '='; each returned item
    is prefixed with '='. Ranges are split into their individual endpoints.
    """
    return ['=' + reference for reference in _REGION.findall(formula)]


# parse answer cell
def parse_formula_item(item):   # @USER: TODO
    """Collect the cell coordinates referenced by a formula.

    The formula is tokenized with openpyxl's `Tokenizer`; every operand token
    of subtype 'RANGE' is split on ':' and expanded with `stretch_range`.
    ONLY the cell regions are parsed out for now.
    @USER TODO: further extract the OPERATORS.

    Returns:
        List of (coords, None) pairs, one per distinct coordinate.
    """
    tokenizer = Tokenizer(item)

    coords = set()
    for tok in tokenizer.items:
        is_range_operand = (tok.type == 'OPERAND') and (tok.subtype == 'RANGE')
        if not is_range_operand:
            continue
        endpoints = tok.value.split(':')   # 'A1:B2' -> ['A1', 'B2']
        coords = coords | set(stretch_range(endpoints))

    return [(coord, None) for coord in coords]


def parse_answer_cell(item):
    """Parse a single answer cell to get the table regions it refers to.

    Args:
        item: raw value of an answer cell. A string starting with '=' is
            treated as a formula; anything else (plain text, numbers, empty
            string, None) is treated as a literal answer.

    Returns:
        For a literal: ``[(None, item)]``.
        For a formula: the (coords, None) pairs from `parse_formula_item`;
        when that yields nothing, the formula is split into cell-like
        references with `split_formula_items` and each one is parsed separately.
    """
    # `startswith` instead of `item[0]` so that an empty string is handled as
    # a literal rather than raising IndexError.
    if not (isinstance(item, str) and item.startswith('=')):
        return [(None, item)]

    pairs = parse_formula_item(item)

    if not pairs:
        # Could not parse the formula as a whole; retry reference by reference.
        for fitem in split_formula_items(item[1:]):
            pairs.extend(parse_formula_item(fitem))

    return pairs
    

def parse_answer_literal(answer_literal, worksheet_literal):
    """Turn the literal value of an answer cell into a list of answers.

    * float -> single-element list with the value rounded to 6 decimals
    * str containing cell-like references (e.g. 'A13,A8') -> the values of
      those cells in `worksheet_literal`
    * anything else (e.g. 'Manitoba', ints, None) -> ``[answer_literal]``
    """
    if isinstance(answer_literal, float):
        return [round(answer_literal, 6)]

    if isinstance(answer_literal, str):
        references = split_formula_items(answer_literal)
        if references:
            # Each reference comes back as '=A1'; drop the '=' for the lookup.
            return [worksheet_literal[ref[1:]].value for ref in references]

    return [answer_literal]


# post-process answer pairs
def find_match_cell(item, worksheet, worksheet_literal):
    """Find the first cell whose content matches `item`.

    Cells are scanned row by row, left to right, over
    ``worksheet.max_row`` x ``worksheet.max_column``. A cell matches when its
    value in `worksheet` equals `item`, or its value in `worksheet_literal`
    equals ``str(item)``.

    Returns:
        1-indexed (row, col) of the first match, or (-1, -1) if there is none.
    """
    item_text = str(item)
    last_row, last_col = worksheet.max_row, worksheet.max_column

    for row in range(1, last_row + 1):
        for col in range(1, last_col + 1):
            raw_value = worksheet.cell(row, col).value
            literal_value = worksheet_literal.cell(row, col).value
            if raw_value == item:
                return (row, col)
            if literal_value == item_text:
                return (row, col)

    return (-1, -1)


def post_complete_pairs(answer_pairs, worksheet, worksheet_literal):
    """Fill in the missing half of each (coords, item) pair using the worksheet.

    * coords only: the item is read from the worksheet at those coords.
    * item only: the coords are found with `find_match_cell`.
    * both: the pair is kept only if the matched cell is exactly at coords.

    Pairs that are entirely empty, that resolve to an empty cell, that cannot
    be matched, or that land on a row below 3 are dropped.

    Returns:
        Dict mapping `LinkedPosition.cm_coords` to the item.
    """
    resolved = {}

    for coords, value in answer_pairs:
        if value is None:
            if coords is None:
                continue

            # Position known: read the item from the worksheet.
            row, col = coords
            if row < 3:
                continue
            value = worksheet.cell(row, col).value
            if value is None:
                continue
            target = LinkedPosition.from_ws_coords(coords)
        else:
            # Item known: look for the cell holding it. A miss comes back as
            # (-1, -1), which the row check below rejects as well.
            match_row, match_col = find_match_cell(value, worksheet, worksheet_literal)
            if match_row < 3:
                continue

            if coords is not None:
                row, col = coords   # 1-indexed coordinate
                if not (row == match_row and col == match_col):
                    continue

            target = LinkedPosition.from_ws_coords(ws_coords=(match_row, match_col))

        resolved[target.cm_coords] = value

    return resolved


# extract answer
def extract_answer(row_index, worksheet, worksheet_literal):
    """Collect the answer cells and answer texts annotated in one row.

    Columns are scanned from column 2 rightwards until the first empty cell
    (None or ''). Each cell is parsed with `parse_answer_cell` (from
    `worksheet`) and `parse_answer_literal` (from `worksheet_literal`); the
    parsed pairs are then completed with `post_complete_pairs`.

    Returns:
        (answer_dict, answer_texts):
        answer_dict: dict returned by `post_complete_pairs` for all cells.
        answer_texts: flat list of the literal answers, in column order.
    """
    answer_pairs = []
    answer_texts = []

    col_index = 2
    while True:
        answer_cell = worksheet.cell(row_index, col_index).value
        # Test for emptiness explicitly: a plain truthiness check would treat
        # a legitimate numeric answer of 0 as the end of the row.
        if answer_cell is None or answer_cell == '':
            break
        answer_literal = worksheet_literal.cell(row_index, col_index).value

        answer_pairs.extend(parse_answer_cell(answer_cell))
        answer_texts.extend(parse_answer_literal(answer_literal, worksheet_literal))

        col_index += 1

    answer_dict = post_complete_pairs(answer_pairs, worksheet, worksheet_literal)

    return answer_dict, answer_texts


# %% 2.4 Aggregation Types

def split_aggregation(cell_value):
    """Split an aggregation cell on ';' and ',' into clean type names.

    Surrounding whitespace is removed from every name and empty names are
    dropped, so 'sum, opposite; ' yields ['sum', 'opposite']. Without the
    stripping, a name such as ' opposite' would miss both the exact and the
    alias comparison performed by the verification step.

    Args:
        cell_value: raw cell content; non-string values are converted with str().

    Returns:
        List of aggregation type names in their original order.
    """
    tokens = str(cell_value).replace(';', ',').split(',')
    stripped = (token.strip() for token in tokens)
    return [token for token in stripped if token]


# heuristic match (to mitigate labeling error)
def get_char_dict(text, lower_case=True):
    """Count how often each character occurs in `text`.

    Args:
        text: iterable of characters.
        lower_case: when equal to True, each character is lower-cased
            individually before counting.

    Returns:
        Dict mapping character -> number of occurrences.
    """
    fold = (lower_case == True)

    counts = {}
    for symbol in text:
        key = symbol.lower() if fold else symbol
        counts[key] = counts.get(key, 0) + 1
    return counts


def absdiff_char_dict(agg, op):
    """Sum of absolute per-character count differences between two texts.

    Both texts are counted with `get_char_dict` (default lower-casing); a
    character missing from one side counts as zero occurrences there.
    """
    agg_counts = get_char_dict(agg)
    op_counts = get_char_dict(op)

    every_char = set(op_counts) | set(agg_counts)
    return sum(
        abs(agg_counts.get(char, 0) - op_counts.get(char, 0))
        for char in every_char
    )


def get_heuristic_operator(agg, verbose=True):
    """Map an annotated aggregation name onto a known operator, if possible.

    An exact member of `_OPERATORS` is returned as-is. Otherwise the first
    operator whose character-count difference to `agg` is below 3 is returned.

    Args:
        agg: annotated aggregation name.
        verbose: when equal to True, the outcome of the heuristic search is printed.

    Returns:
        The matched operator, or None when nothing is close enough.
    """
    if agg in _OPERATORS:
        return agg

    chatty = (verbose == True)

    # Lazily scan the operators; stop at the first one that is close enough.
    close_enough = (
        candidate for candidate in _OPERATORS
        if absdiff_char_dict(agg, candidate) < 3
    )
    op = next(close_enough, None)

    if op is None:
        if chatty:
            print(f'cannot find any heuristic match for agg {agg}')
        return None

    if chatty:
        print(f"Heuristic Match between Agg [{agg}] & Op [{op}]")
    return op


def verify_aggregation_types(aggregation_types, verbose=False):
    """Replace each parsed aggregation type with its heuristically matched operator.

    Types for which `get_heuristic_operator` finds no operator are dropped.
    """
    matched = (get_heuristic_operator(agg, verbose) for agg in aggregation_types)
    return [operator for operator in matched if operator is not None]


def verify_aggregation_types_levenshtein(aggregation_types, verbose=False):
    """Normalize parsed aggregation types to known operators using Levenshtein ratio.

    For each type, in order:
      * a member of `_OPERATORS` is kept unchanged;
      * the alias 'opposite' becomes 'opp';
      * otherwise the operator with the highest `Levenshtein.ratio` (first one
        on ties) is used, provided the ratio exceeds 0.7; else the type is dropped.

    `verbose` is accepted but not used.
    """
    accepted = []

    for agg in aggregation_types:
        if agg in _OPERATORS:
            accepted.append(agg)
            continue
        if agg == 'opposite':   # a common alias
            accepted.append('opp')
            continue

        scored = [(Levenshtein.ratio(agg, op), op) for op in _OPERATORS]
        # max() keeps the first of equally scored entries.
        best_score, best_op = max(
            scored, key=lambda pair: pair[0], default=(0, None))
        if best_score > 0.7 and best_op is not None:
            accepted.append(best_op)

    return accepted


# extract agg types
def extract_aggregation(row_index, worksheet):
    """Read and validate the aggregation types annotated in one row.

    Columns are scanned from column 2 rightwards until the first empty cell;
    each cell is split with `split_aggregation`. When nothing is collected
    the single type 'none' is assumed. The collected names are then
    normalized with `verify_aggregation_types_levenshtein`.
    """
    raw_types = []

    column = 2
    while True:
        value = worksheet.cell(row_index, column).value
        if not value:
            break
        raw_types += split_aggregation(value)
        column += 1

    if not raw_types:
        raw_types = ['none']

    return verify_aggregation_types_levenshtein(raw_types)
