import logging
import re

from nltk.stem import WordNetLemmatizer

from span import is_span_overlap_span

# Module-wide lemmatizer instance; get_lemma() calls its lemmatize() method.
wordnet_lemmatizer = WordNetLemmatizer()

# Patterns substituted for the plain lemmas 'bad' / 'good' so that the
# irregular comparative and superlative forms are accepted as well.
regex_bad = re.compile('bad|worse|worst')
regex_good = re.compile('good|better|best')

# Matches a run of one or more punctuation characters.
regex_not_word = re.compile(r'[.,\-"?!();#{}+]+')

# Maps a digit string to a pattern that matches the digit string itself or
# one of its spelled-out forms (the first entry of each list is the key).
card_transform = {
    forms[0]: re.compile('|'.join(forms))
    for forms in [['1', 'one', 'a'], ['2', 'two'], ['3', 'three'],
                  ['4', 'four'], ['5', 'five'],
                  ['6', 'six'], ['7', 'seven'], ['8', 'eight'],
                  ['9', 'nine'], ['10', 'ten'],
                  ['11', 'eleven'], ['12', 'twelve'],
                  ['1000000', 'million'],
                  ['1000000000', 'billion'],
                  ['1000000000000', 'trillion']]}

# Prefix labels such as '_pre-_' or '_mid_'; group(1) is the prefix text.
regex_prefix = re.compile('_([a-z]+-|mid)_')

# Labels that are treated as normal nodes even without a leading '_'.
special_labels_as_normal_node = ['named', 'card', 'yofc',
                                 'fraction', 'season',
                                 'year_range', 'mofy', 'much-many_a']
# Matches dotted abbreviations such as 'b.a.t' or 'u.s.'.
# Raw string: '\.' is not a valid escape sequence in a normal str literal.
regex_special_named = re.compile(r'^[a-z](\.[a-z])+\.?$')
# Matches strings such as 'skr200': group(1) is the 'skr' part,
# group(2) the numeric part.
# Raw string: '\w' and '\.' are not valid str escape sequences.
regex_skr = re.compile(r'^([^\w]*skr)([0-9\.]+[^\w]*)$')

# Characters stripped from both ends of a token before further processing.
punctuations = '.,"?!;() '

# Special cases: (token lemma, node lemma) pairs that count as a match even
# though neither string is a prefix of the other.
token_lemma_special_case = {
    ('ft', 'foot'),
    ('gray', 'grey'),
    ('offshore', 'off-shore'),
    ('n.m', 'new mexico'),
    ('hi', 'hawaii'),
    ('vice', 'co'),
}


def get_lemma(token, pos):
    '''
    Look up the base form of a token with the nltk lemmatizer.

    The token is first stripped of the characters in ``punctuations`` on
    both ends. If the lemmatizer raises for any reason, the stripped token
    itself is returned instead.

    :param token: surface string to lemmatize
    :param pos: part-of-speech tag handed to the lemmatizer unchanged
    :return: the lemma, or the stripped token when lemmatizing fails
    '''
    bare_token = token.strip(punctuations)
    try:
        lemma = wordnet_lemmatizer.lemmatize(bare_token, pos)
    except Exception:
        # Best effort: fall back to the stripped surface form.
        lemma = bare_token
    return lemma


def is_normal_node(node):
    '''
    Tell whether a node counts as a normal node.

    A node is normal when its label starts with '_' or is listed in
    ``special_labels_as_normal_node``.

    :param node: mapping with a 'label' entry
    :return: True or False
    '''
    node_label = node['label']
    if node_label.startswith('_'):
        return True
    return node_label in special_labels_as_normal_node


def is_prefix_node(node_label):
    '''
    Tell whether a label denotes a prefix node.

    The label must match ``regex_prefix`` at its start and the captured
    prefix must not be 'up-' (which is not treated as a prefix).

    :param node_label: label string of a node
    :return: None when the label does not match ``regex_prefix``,
             otherwise True/False depending on the 'up-' check
    '''
    found = regex_prefix.match(node_label)
    if not found:
        # Hand back the falsy match result itself (None).
        return found
    return found.group(1) != 'up-'


def preprocess_edges(edges):
    '''
    Remove parallel edges from an edge mapping, in place.

    Case 1: 'L-INDEX' and 'R-INDEX' point to the same target. 'L-INDEX' is
    dropped when only 'L-HNDL' (and no 'R-HNDL') is present; otherwise
    'R-INDEX' is dropped.
    Case 2: a handle edge ('L-HNDL' / 'R-HNDL') that points to the same
    target as the index edge on its side is dropped.

    :param edges: dict mapping edge label to target; modified in place
    :return: None
    '''
    # Parallel edges, case 1
    if 'L-INDEX' in edges and edges['L-INDEX'] == edges.get('R-INDEX'):
        only_left_handle = 'R-HNDL' not in edges and 'L-HNDL' in edges
        redundant = 'L-INDEX' if only_left_handle else 'R-INDEX'
        del edges[redundant]

    # Parallel edges, case 2 (left side first, then right side)
    for handle_key, index_key in (('L-HNDL', 'L-INDEX'),
                                  ('R-HNDL', 'R-INDEX')):
        if handle_key in edges and \
           edges[handle_key] == edges.get(index_key):
            del edges[handle_key]


def get_overlapped_nodes(uttr, nodes):
    '''
    Collect the ids of candidate nodes whose spans overlap each other.

    A node is a candidate when it is a normal node (or labelled 'neg') and
    its lower-cased, punctuation-stripped surface text contains '-' or '/',
    matches ``regex_skr`` or ``regex_special_named``, or equals 'the like'
    or 'everytime'. Only candidates overlapping at least one other
    candidate are returned.

    :param uttr: the sentence text that the node spans index into
    :param nodes: mapping of node id to node (with 'label' and 'span')
    :return: list of node ids, narrowest span first, with '_per_p' and
             '_and_c' nodes moved to the end
    '''
    hit_counts = {}
    for node_id, node in nodes.items():
        start, stop = node['span']
        text = uttr[start:stop].lower().strip(punctuations)

        if not (is_normal_node(node) or node['label'] == 'neg'):
            continue
        splittable = ('-' in text or
                      '/' in text or
                      regex_skr.match(text) or
                      regex_special_named.match(text) or
                      text in ('the like', 'everytime'))
        if not splittable:
            continue

        # Every candidate starts at 1; each overlap bumps both partners.
        hit_counts[node_id] = 1
        for other_id in hit_counts:
            if other_id == node_id:
                continue
            if is_span_overlap_span(nodes[other_id]['span'], node['span']):
                hit_counts[other_id] += 1
                hit_counts[node_id] += 1

    # Narrower spans take priority.
    ranked = [node_id for node_id, hits in hit_counts.items() if hits > 1]
    ranked.sort(
        key=lambda node_id: nodes[node_id]['span'][1] -
        nodes[node_id]['span'][0])

    # '_per_p' / '_and_c' nodes are handled last.
    deferred_labels = ('_per_p', '_and_c')
    regular = [node_id for node_id in ranked
               if nodes[node_id]['label'] not in deferred_labels]
    deferred = [node_id for node_id in ranked
                if nodes[node_id]['label'] in deferred_labels]
    return regular + deferred


def fix_eds_prefix_node_span(filename, uttr, nodes):
    '''
    Split the span shared by a prefix node (pre-, mid-, re-, ...) and the
    node it modifies.
    @@ reader.3 @@

    For every node whose label matches ``regex_prefix`` (except 'up-',
    which is not a prefix) and which has an 'ARG1' edge, the node's span is
    cut down to the prefix itself and the ARG1 node's span is moved to
    start right after the prefix. Spans are modified in place.

    :param filename: used only in debug log messages
    :param uttr: unused here
    :param nodes: mapping of node id to node ('label', 'edges', 'span')
    :return: None
    :raises AssertionError: when the prefix would reach the end of the
        ARG1 node's span
    '''
    logger = logging.getLogger('reader_fix')
    for node in nodes.values():
        found = regex_prefix.match(node['label'])
        target_id = node['edges'].get('ARG1')
        if not (found and target_id):
            continue
        prefix = found.group(1)
        if prefix == 'up-':  # 'up-' is not a prefix
            continue

        logger.debug('%s:prefix: %s', filename, found.group(0))
        target = nodes[target_id]

        # The prefix node keeps only the prefix characters ...
        prefix_start = node['span'][0]
        cut = prefix_start + len(prefix)
        node['span'] = prefix_start, cut

        # ... and the ARG1 node starts where the prefix ends.
        target_end = target['span'][1]
        assert cut < target_end
        target['span'] = cut, target_end


def fix_eds_dash_node_span(filename, uttr, nodes):
    '''
    Handle word groups joined by '-' or '/'.
    @@ reader.4 @@

    For every node returned by get_overlapped_nodes(), the surface text of
    its span is split into tokens and the node's span is narrowed to the
    first token that matches the node's lemma and is not already taken by
    another node. '_per_p' nodes and '/'-joined conjunction nodes are
    narrowed to the separator character instead. Spans are updated in
    place on the entries of ``nodes``; nothing is returned.

    :param filename: used only in debug log messages
    :param uttr: the sentence text that the node spans index into
    :param nodes: mapping of node id to node; 'label', 'pos', 'span' and
                  'edges' are read, 'carg' is read when present
    '''
    logger = logging.getLogger('reader_fix')

    overlap_nodes = get_overlapped_nodes(uttr, nodes)

    if len(overlap_nodes) > 0:
        logger.debug('%s:overlapped %s',
                     filename,
                     [nodes[x]['label'] for x in overlap_nodes])

    # Sometimes two identical words are joined (e.g. 'la-la'), so we have
    # to remember which spans have already been taken.
    covered_spans = set()
    for nid in overlap_nodes:
        node = nodes[nid]
        node_label = node['label']
        node_pos = node['pos']

        is_number = node_label in ['card', 'yofc', 'fraction']
        is_noun = node_label in ['named', 'mofy', 'season']
        is_neg = node_label == 'neg'

        # Handle the connectors '-' and '/'
        if node_label.startswith('_') or is_number or is_noun or is_neg:
            span = node['span']
            carg = node.get('carg')

            string = uttr[span[0]:span[1]].lower()

            # Strip non-alphanumeric characters on both sides, but remember
            # how many were removed: s is the start offset and e the end
            # offset of the kept part within `string`.
            s, e = 0, len(string)
            while s < e and not string[s].isalnum():
                s += 1
            while s < e and not string[e - 1].isalnum():
                e -= 1

            striped_string = string[s:e]

            if striped_string.startswith('y-mp'):  # dirty work
                sep = '/'
            else:
                sep = '-'
            tokens = list(x for x in striped_string.split(sep) if x != '')

            # Handle strings like 'skr200': split into the 'skr' part and
            # the number, with no separator between them.
            match = re.match(regex_skr, striped_string)
            if match:
                tokens = match.group(1), match.group(2)
                sep = ''
            elif striped_string == 'everytime':
                tokens = 'every', 'time'
                sep = ''

            # Handle dotted abbreviations such as U.S or B.A.T
            match = re.match(regex_special_named, striped_string)
            if match:
                tokens = list(x for x in striped_string.split('.') if x != '')
                sep = '.'

            # 'the like' appears 10+ times in the data set
            if striped_string == 'the like':
                tokens = 'the', 'like'
                sep = ' '

            # '/' is only considered after everything else failed to split
            if len(tokens) <= 1:
                tokens = list(x for x in striped_string.split('/')
                              if x != '')
                sep = '/'

            if len(tokens) >= 2 and node_label == '_per_p':
                arg2 = node['edges'].get('ARG2')
                if arg2:
                    # The span of a _per_p node changes from 'xx/yy' to the
                    # single character right before its ARG2 node.
                    beg = nodes[arg2]['span'][0]
                    node['span'] = beg - 1, beg
                continue
            elif len(tokens) >= 2 and node_pos == 'c' and sep == '/':
                # Conjunction node such as _and_c: take the single
                # character right after the left conjunct.
                arg1 = node['edges'].get('L-INDEX') or \
                    node['edges'].get('L-HNDL')
                if arg1:
                    end = nodes[arg1]['span'][1]
                    node['span'] = end, end + 1  # '/'
                continue

            # Nothing to split: leave the span untouched.
            if len(tokens) <= 1:
                continue

            # Work out the lemma the tokens are compared against.
            if is_number:
                # May yield a compiled pattern (card_transform values)
                # rather than a str.
                lemma = card_transform.get(carg, carg)
                logger.debug('%s:number: %s -> %s', filename, carg, lemma)
            elif is_noun:
                # NOTE(review): assumes carg is a non-empty str here; a
                # missing or empty 'carg' would raise - confirm upstream.
                lemma = carg.lower().replace('+', ' ').replace('_', ' ')
                # Drop a trailing separator (e.g. '-')
                if lemma[-1] == sep:
                    lemma = lemma[:-1]
            elif is_neg:
                lemma = 'non'
            else:
                # Second '_'-separated field of the label.
                lemma = node_label.split('_')[1].lower().replace('’', '\'')

            if node_pos == 'u':
                # Keep only the part of the lemma before the first '/'.
                # NOTE(review): lemma.find assumes a str; a compiled pattern
                # from card_transform has no find() - confirm that number
                # nodes never carry pos 'u'.
                pos = lemma.find('/')
                if pos != -1:
                    lemma = lemma[:pos]

            # Accept the irregular comparative/superlative forms as well.
            if lemma == 'bad':
                lemma = regex_bad
            elif lemma == 'good':
                lemma = regex_good

            # NOTE(review): when lemma is a compiled pattern this call
            # resolves to Pattern.split('+') instead of str.split - it looks
            # like that is only harmless by accident; confirm.
            lemma_splited_by_plus = lemma.split('+')
            if len(lemma_splited_by_plus) > 1:  # handle nodes like buy+out
                # If the lemma parts appear joined by `sep` in the text,
                # rewrite that stretch with '+' so it stays one token.
                beg = string.find(sep.join(lemma_splited_by_plus))
                if beg != -1:
                    end = beg + len(lemma)
                    string = string[:beg] + \
                        string[beg:end].replace(sep, '+') + string[end:]
                    # NOTE(review): this strips `punctuations`, whereas s/e
                    # above were computed with isalnum() - verify that the
                    # two agree for the inputs seen here.
                    striped_string = string.strip(punctuations)
                    tokens = list(x for x in striped_string.split(sep)
                                  if x != '')

            # Width of each token inside the original span; the stripped
            # leading/trailing characters are added back to the first and
            # last token so the offsets line up with `span`.
            token_lens = list(len(x) for x in tokens)
            token_lens[0] += s
            token_lens[-1] += len(string) - e

            logger.debug('%s:%s "%s"', filename, tokens, lemma)
            beg = span[0]
            for token, token_len in zip(tokens, token_lens):
                token_lemma = get_lemma(token, node_pos)
                # Decide whether this token matches the node's lemma
                if isinstance(lemma, str):
                    pred = token == lemma
                    pred = pred or token_lemma.startswith(lemma)
                    # NOTE(review): 'toyko' may be a typo for 'tokyo' -
                    # confirm against the data before changing it.
                    pred = pred or (lemma.startswith(token_lemma) and
                                    lemma != 'toyko')  # dirty work
                    pred = pred or \
                        (token_lemma, lemma) in token_lemma_special_case
                    # Token ending in 'ied' whose stem equals the lemma
                    # minus its last character ('studied' vs 'study').
                    pred = pred or \
                        (token_lemma.endswith('ied') and
                         token_lemma[:-3] == lemma[:-1])
                else:
                    # lemma is a compiled pattern
                    pred = re.match(lemma, token_lemma)
                new_span = beg, beg + token_len
                # Take the first matching token whose span is still free.
                if pred and \
                    not any(is_span_overlap_span(x, new_span)
                            for x in covered_spans):
                    node['span'] = new_span
                    covered_spans.add(new_span)

                    logger.debug('    %s -> %s', span, node['span'])

                    real_token = uttr[beg:beg + token_len].lower()
                    if token != real_token:
                        logger.debug('    `%s` <-> `%s`',
                                     token, real_token)

                    break
                # Advance past this token and the separator after it.
                beg += token_len + len(sep)
