from nltk.tokenize.moses import MosesTokenizer

from mark_edge_with_label import get_aligned_nodes
from reader_fix import punctuations
from rule import be_1_forms, do_forms, v_tense_forms, w_forms

# Graph labels presumably meant to be preserved as-is; this set is not
# referenced by the code in this module's visible functions — verify users.
# NOTE(review): the misspelled name is kept since other modules may import it.
special_labels_to_be_presvered = {'pron', 'poss'}
# Module-wide Moses tokenizer instance (used by generate_reference).
tokenizer = MosesTokenizer()


def node_label(node):
    '''
    Render one graph node as a lemma token.

    The base form is picked as follows:
      * a label starting with '_' contributes the field between its first
        and second underscore, truncated at the first '/' if any;
      * otherwise a truthy 'carg' value is used; a single trailing '-' is
        removed from it and re-emitted as a ' \\@-\\@' marker at the very end;
      * anything else becomes '<label>'.
    An underscore followed by the node's 'properties' joined with '_' is
    then appended (so a node without properties ends in a bare '_').
    '''
    label = node['label']
    suffix = ''
    if label[0] == '_':
        base = label.split('_')[1].partition('/')[0]
    elif node.get('carg'):
        base = node['carg']
        if base.endswith('-'):
            base, suffix = base[:-1], ' \\@-\\@'
    else:
        base = '<{}>'.format(label)
    return base + '_' + '_'.join(node['properties']) + suffix


def process_sentence(fn, uttr, nodes,
                     cfg_leaves,
                     quote_literals=True,
                     return_list=False):
    '''
    Convert one sentence into its lemma token sequence.

    CFG leaves are visited left to right.

    A leaf with graph nodes aligned to it (according to ``get_aligned_nodes``)
    is replaced by the ``node_label`` of each such node, ordered by the start
    of the node's span. A node whose ``sense`` is ``'modal'`` and which is
    the target of a ``neg`` node with the same span is additionally followed
    by a ``' <NEG>'`` token.

    A leaf with no aligned node falls back to its surface text, stripped of
    ``punctuations``, and normalised as follows:

      * words in ``w_forms`` become ``<THAT>``;
      * words in ``be_1_forms`` become ``be``. A directly following
        ``being`` leaf is consumed. The ``be`` is dropped entirely when the
        next leaf ends in ``ing`` and is aligned to exactly one node whose
        ``pos`` is ``'v'`` or ``'u'`` (progressive aspect);
      * words in ``do_forms`` become ``do``;
      * words in ``v_tense_forms``, and tokens that are empty or made only
        of hyphens, are dropped.

    Surface tokens that do not start with ``<`` are wrapped in double quotes
    when ``quote_literals`` is true.

    Side effect: ``nodes`` is modified in place. Every node that is the
    target of an edge gets an ``'in'`` set holding the ids of its source
    nodes.

    :param fn: sentence identifier (not used here).
    :param uttr: raw sentence text; leaves index into it via
        ``leaf.beg`` / ``leaf.end``.
    :param nodes: mapping from node id to node dict.
    :param cfg_leaves: sequence of leaves in sentence order.
    :param quote_literals: quote surface-word tokens.
    :param return_list: return the token list instead of a joined string.
    :return: the list of tokens, or the tokens joined by single spaces.
    '''
    # Record incoming edges: nodes[target]['in'] collects the ids of every
    # node with an edge into target. This happens before get_aligned_nodes,
    # which presumably relies on it — verify before moving.
    for nodeid in nodes:
        for targetid in nodes[nodeid]['edges'].values():
            nodes[targetid].setdefault('in', set()).add(nodeid)

    aligned_node, _ = get_aligned_nodes(nodes, cfg_leaves)
    # leaf index -> ids of the nodes aligned to that leaf
    reversed_aligned_node = {}
    for nodeid, index in aligned_node.items():
        reversed_aligned_node.setdefault(index, []).append(nodeid)
    new_uttr = []
    index = 0
    while index < len(cfg_leaves):
        leaf = cfg_leaves[index]
        nodeids = reversed_aligned_node.get(index, [])
        # From here on `index` refers to the *next* leaf (used for lookahead).
        index += 1
        if nodeids:
            nodeids.sort(key=lambda nid: nodes[nid]['span'][0])
            for nodeid in nodeids:
                node = nodes[nodeid]
                token = node_label(node)
                new_uttr.append(token)
                if node['sense'] == 'modal':
                    for source_id in node.get('in', []):
                        source = nodes[source_id]
                        if source['label'] == 'neg' and \
                           source['span'] == node['span']:
                            # NOTE(review): the leading space yields a double
                            # space in the joined string; kept unchanged
                            # because existing output files may depend on it.
                            new_uttr.append(' <NEG>')
                            break
        else:
            token = uttr[leaf.beg:leaf.end].strip(punctuations)
            if token in w_forms:
                token = '<THAT>'
            elif token in be_1_forms:
                if index < len(cfg_leaves):
                    next_ids = reversed_aligned_node.get(index, [])
                    next_leaf = cfg_leaves[index]
                    next_token = uttr[next_leaf.beg:next_leaf.end]
                    next_token = next_token.strip(punctuations)
                    if next_token == 'being':
                        # Consume "being" so that only "be" is emitted.
                        index += 1
                    elif len(next_ids) == 1 and \
                            nodes[next_ids[0]]['pos'] in ['v', 'u'] and \
                            next_token.endswith('ing'):
                        # Progressive aspect: drop the auxiliary "be".
                        continue
                token = 'be'
            elif token in do_forms:
                token = 'do'
            elif token in v_tense_forms:
                continue
            elif len(token.strip('-')) == 0:
                # Empty (all punctuation) or hyphen-only leaf.
                continue

            if quote_literals and token[0] != '<':
                token = '"' + token + '"'
            new_uttr.append(token)

    if return_list:
        return new_uttr
    return ' '.join(new_uttr)


def generate_reference(data, lemmas_file, sentences_file):
    '''
    Write the lemma and tokenized-sentence reference files for one split.

    Sentence ids in ``data`` are visited in sorted order. For each one, a
    line ``"<id>: <text>"`` is written to each file:

      * ``lemmas_file`` gets the lemma sequence from ``process_sentence``
        (with ``quote_literals=False``);
      * ``sentences_file`` gets the Moses-tokenized raw sentence.

    :param data: mapping from sentence id to a
        ``(uttr, cfg_tree, cfg_leaves, graph)`` tuple; ``graph['nodes']``
        is handed to ``process_sentence``.
    :param lemmas_file: output path for the lemma sequences.
    :param sentences_file: output path for the tokenized sentences.
    '''
    # Explicit UTF-8 so the output does not depend on the platform locale,
    # which may be unable to encode non-ASCII sentence text.
    with open(lemmas_file, 'w', encoding='utf-8') as lout, \
            open(sentences_file, 'w', encoding='utf-8') as sout:
        for fn in sorted(data):
            uttr, _cfg_tree, cfg_leaves, graph = data[fn]
            lemmas = process_sentence(fn, uttr, graph['nodes'],
                                      cfg_leaves, quote_literals=False)
            sentence = tokenizer.tokenize(
                uttr, agressive_dash_splits=True, return_str=True)
            # Each line is written at once, so an exception while processing
            # a sentence leaves no half-written "fn: " prefix behind.
            lout.write(fn + ': ' + lemmas + '\n')
            sout.write(fn + ': ' + sentence + '\n')


if __name__ == '__main__':
    from reader import load_data

    # load_data() yields the train, dev and test splits, in that order.
    train_data, dev_data, test_data = load_data()
    jobs = (
        (train_data, 'data/train-lemmas.txt', 'data/train-sentences.txt'),
        (dev_data, 'data/dev-lemmas.txt', 'data/dev-sentences.txt'),
        (test_data, 'data/test-lemmas.txt', 'data/test-sentences.txt'),
    )
    for split_data, lemmas_path, sentences_path in jobs:
        generate_reference(split_data, lemmas_path, sentences_path)
