# -*- coding: utf-8 -*-
import gzip
import logging
import os
import pickle
import re

from delphin.mrs import simplemrs

import eds_modified as eds
from CFGNode import CFGNode
from reader_fix import fix_eds_dash_node_span, fix_eds_prefix_node_span

# Root directory scanned by collect_files(); dump_data() joins every relative
# file name onto it before reading the gzip file.
DATA_DIRECTORY = 'data/export'
# Pickle caches written by dump_data() and read back by load_data(), one per
# data split.
TRAIN_FILES = 'data/train.pki'
TEST_FILES = 'data/test.pki'
DEV_FILES = 'data/dev.pki'


def collect_files(return_full_name=False):
    '''
    Split the files found under ``DATA_DIRECTORY`` into train/dev/test lists.

    Only sub-directories whose name starts with ``wsj``, two digits and at
    least one further character are scanned. The two digits select the
    split: 20 goes to dev, 21 goes to test, anything else goes to train.

    :param return_full_name: when true, every path is prefixed with
        ``DATA_DIRECTORY``; otherwise paths are relative to it
        (``<directory>/<file>``).
    :return: tuple ``(train_files, dev_files, test_files)`` of path lists.
    '''
    held_out = {20: [], 21: []}
    train_files = []
    for directory in os.listdir(DATA_DIRECTORY):
        full_dir = os.path.join(DATA_DIRECTORY, directory)
        match = re.match('wsj([0-9]{2}).', directory)
        if not (os.path.isdir(full_dir) and match):
            continue
        # Prefix decides whether the caller gets full or relative names.
        prefix = full_dir if return_full_name else directory
        bucket = held_out.get(int(match.group(1)), train_files)
        bucket.extend(os.path.join(prefix, f) for f in os.listdir(full_dir))
    return train_files, held_out[20], held_out[21]


def load_data(force_new=False):
    '''
    Return the train/dev/test data sets, building the pickle caches on demand.

    The caches are rebuilt from the raw export files when ``force_new`` is
    true or when any of the three cache files is missing; otherwise they are
    simply unpickled.

    :param force_new: rebuild the caches even if they already exist.
    :return: tuple ``(train_data, dev_data, test_data)``; each element is the
        dict produced by ``dump_data`` for that split.
    '''
    cache_files = (TRAIN_FILES, DEV_FILES, TEST_FILES)
    # Require every cache file, not just the training one, so that a
    # partially written cache triggers a rebuild instead of a crash.
    if force_new or not all(os.path.exists(path) for path in cache_files):
        logger = logging.getLogger('reader')
        train_files, dev_files, test_files = collect_files()
        train_data, failed_count = dump_data(train_files, TRAIN_FILES)
        dev_data, count = dump_data(dev_files, DEV_FILES)
        failed_count += count
        test_data, count = dump_data(test_files, TEST_FILES)
        failed_count += count
        logger.info('train=%s, dev=%s, test=%s', len(train_data),
                    len(dev_data), len(test_data))
        # Report failures against the number of files attempted; the data
        # dicts only contain the files that were parsed successfully.
        total = len(train_files) + len(dev_files) + len(test_files)
        logger.warning('failed count %s/%s', failed_count, total)
        return train_data, dev_data, test_data

    # NOTE: pickle must only be used on trusted files; these caches are
    # expected to be the ones written by dump_data().
    loaded = []
    for path in cache_files:
        with open(path, 'rb') as cache:
            loaded.append(pickle.load(cache))
    return tuple(loaded)


def dump_data(files, output_file):
    '''
    Parse every file in ``files`` and pickle the results to ``output_file``.

    Each name is joined onto ``DATA_DIRECTORY`` and read with
    ``read_gzip_file`` (validity check and strict mode both enabled). Files
    that raise are logged and counted, but do not abort the run.

    :param files: file names relative to ``DATA_DIRECTORY``.
    :param output_file: path of the pickle file to write.
    :return: tuple ``(data, failed_count)`` where ``data`` maps each
        successfully parsed file name to its parse result.
    '''
    log = logging.getLogger('reader')
    parsed = {}
    failures = 0
    for done, fn in enumerate(files, start=1):
        path = os.path.join(DATA_DIRECTORY, fn)
        try:
            parsed[fn] = read_gzip_file(path, True, True)
        except Exception as err:
            log.error('%s:%s: %s', fn, type(err).__name__, err)
            failures += 1

        # Periodic progress report, counting failed files as processed too.
        if done % 1000 == 0:
            log.info('progress %s: %s/%s', output_file, done, len(files))

    with open(output_file, 'wb') as out:
        pickle.dump(parsed, out)

    return parsed, failures


def is_graph_a_dag(nodes):
    '''
    Check that the graph contains no directed cycle.

    Nodes are repeatedly removed once they have no remaining incoming
    edges; the graph is acyclic exactly when every node can be removed.

    :param nodes: dict mapping node id to a dict whose ``'edges'`` entry maps
        edge label to target node id.
    :return: True if the graph is a DAG, False if it has a cycle.
    '''
    indegree = dict.fromkeys(nodes, 0)
    for node in nodes.values():
        for target in node['edges'].values():
            indegree[target] += 1

    # Start from the nodes nothing points to.
    ready = [nodeid for nodeid, degree in indegree.items() if degree == 0]
    processed = 0
    while ready:
        current = ready.pop()
        processed += 1
        for target in nodes[current]['edges'].values():
            indegree[target] -= 1
            if not indegree[target]:
                ready.append(target)
    # Nodes on a cycle never reach in-degree zero and stay unprocessed.
    return processed == len(nodes)


def make_graph_connected(xmrs, top, nodes):
    '''
    Try to make the graph connected, using MRS label (LBL) information.

    Some disconnected nodes can be attached by linking them to a node that
    shares the same MRS label.
    @@ reader.1 @@

    Nodes reachable from ``top`` (ignoring edge direction) are collected
    first. Every unreached node without outgoing edges then gets an ``ARG1``
    edge to a reached node whose elementary predication has the same label,
    if one exists. ``nodes`` is modified in place.

    :param xmrs: MRS object providing ``eps()``, ``ep()`` and ``nodeid()``.
    :param top: id of the top node; must be a key of ``nodes``.
    :param nodes: dict mapping node id to a dict whose ``'edges'`` entry maps
        edge label to target node id.
    :return: True if every node is reached after the repair, else False.
    '''
    # Undirected adjacency: connectivity ignores edge direction.
    neighbours = {x: set() for x in nodes}
    for src in nodes:
        for tar in nodes[src]['edges'].values():
            neighbours[src].add(tar)
            neighbours[tar].add(src)

    # The top node is reachable by definition. Seeding it here matters when
    # it has no edges at all (e.g. a single-node graph); otherwise it would
    # never be counted as visited.
    visited = {top}
    seeds = [top]
    while seeds:
        src = seeds.pop()
        for tar in neighbours[src]:
            if tar not in visited:
                visited.add(tar)
                seeds.append(tar)

    # Candidate attachment points: predications already in the main component.
    eps = [ep for ep in xmrs.eps() if ep.iv in visited]
    newly_visited = set()
    connected = set()

    def get_ep_order(ep):
        '''Prefer nodes with fewer outgoing edges that are not yet used.
        '''
        order = len(nodes[ep.iv]['edges'])
        if ep.iv in connected:
            order += 1
        return order

    for src in nodes:
        # Only isolated leaves (unreached, no outgoing edges) are repaired.
        if src in visited or nodes[src]['edges']:
            continue
        try:
            cur_ep = xmrs.ep(xmrs.nodeid(src))
        except Exception:
            # Best effort: a node without a matching predication is skipped.
            continue
        new_eps = [
            ep for ep in eps if ep.label == cur_ep.label and ep.iv != cur_ep.iv
        ]
        if not new_eps:
            continue
        # min() returns the first of equally ranked candidates, matching a
        # stable sort followed by taking the first element.
        ep = min(new_eps, key=get_ep_order)
        nodes[src]['edges']['ARG1'] = ep.iv
        newly_visited.add(src)
        connected.add(ep.iv)

    visited.update(newly_visited)
    return len(visited) == len(nodes)


def read_gzip_file(filename, check_valid=True, strict=False):
    '''
    Read a gzip-compressed export file and parse it with ``read_string``.

    :param filename: path of the gzip file; also passed on as the name used
        by ``read_string``.
    :param check_valid: forwarded to ``read_string``.
    :param strict: forwarded to ``read_string``.
    :return: the result of ``read_string`` without its first element (the
        file name), i.e. ``(uttr, cfg_root, cfg_leaves, graph)``.
    '''
    # Close the gzip handle deterministically instead of leaking it.
    with gzip.open(filename) as stream:
        data = stream.read().decode()
    return read_string(data, filename, check_valid, strict)[1:]


def read_string(data, filename=None, check_valid=True, strict=False):
    '''
    Parse the text of one export file into a CFG tree and an EDS graph.

    :param data: full text of the export file.
    :param filename: name used in log messages and passed to the span-fixing
        helpers; when None, the bracketed id from the header is used.
    :param check_valid: raise if the EDS graph contains a cycle.
    :param strict: raise if the graph cannot be made connected; otherwise
        only a warning is logged.
    :return: tuple ``(filename, uttr, cfg_root, cfg_leaves, graph)``.
    :raises ValueError: if the header with the utterance cannot be found.
    :raises Exception: ``'not-a-dag'`` or ``'not-connected'`` when the
        corresponding check fails.
    '''
    logger = logging.getLogger('reader')
    # Use a regex to match the item id and the utterance in the header.
    matcher = re.search(r"\[(.*?)\].*?`(.*)'", data)
    if matcher is None:
        raise ValueError('cannot find utterance header')
    if filename is None:
        filename = matcher.group(1)
    uttr = matcher.group(2).strip()
    data = data[matcher.end() + 1:]

    sections = []
    # Sections are separated by a blank line.
    for section in data.split('\n\n'):
        section = section.strip()
        # Skip useless sections
        if section == '' or section[0] == '<':
            continue
        sections.append(section)
    sections = sections[1:]  # Skip title
    # First section is UDF Tree
    # udf_tree = derivation.Derivation.from_string(sections[0]).to_dict()
    # Next is CFG Tree
    cfg_root, cfg_leaves = CFGNode.from_sexp(sections[1])
    # Third is MRS
    xmrs = simplemrs.loads_one(sections[2], errors='ignore')
    # Fourth is EDS; fall back to converting the MRS when it cannot be read.
    try:
        eds_graph = eds.loads_one(sections[3])
    except Exception:
        eds_graph = eds.Eds.from_xmrs(xmrs)

    graph = eds_graph.to_dict(properties=False)

    if check_valid and not is_graph_a_dag(graph['nodes']):
        raise Exception('not-a-dag')

    # The keys of ``graph`` are passed as keyword arguments, so they have to
    # match the ``top`` and ``nodes`` parameters of make_graph_connected.
    if not make_graph_connected(xmrs, **graph):
        if strict:
            raise Exception('not-connected')
        else:
            logger.warning('%s:not-connected', filename)

    nodes = graph['nodes']
    for nodeid in eds_graph.nodeids():
        eds_node = eds_graph.node(nodeid)
        node = nodes[nodeid]
        # Replace the 'lnk' entry with a plain (from, to) span tuple.
        node['span'] = node['lnk']['from'], node['lnk']['to']
        del node['lnk']
        node.update(pos=eds_node.pred.pos, sense=eds_node.pred.sense)

        try:
            props = xmrs.properties(nodeid)
        except Exception:
            # Best effort: nodes without MRS properties get placeholders.
            props = {}
        # '#' marks a missing value; UNTENSED is treated as missing too.
        tense = props.get('TENSE', '#')
        if tense == 'UNTENSED':
            tense = '#'
        node['properties'] = [
            tense,
            props.get('NUM', '#'),
            props.get('PERS', '#'),
            props.get('PROG', '#'),
            props.get('PERF', '#')
        ]

        # Remove parallel edges: among edges to the same target, keep only
        # the one whose label sorts first.
        edges = node['edges']
        tars = set()
        for elabel in sorted(edges.keys()):
            tar = edges[elabel]
            if tar in tars:
                del edges[elabel]
            else:
                tars.add(tar)

    # Compute span for every node
    CFGNode.compute_span(cfg_root, cfg_leaves, uttr.lower())

    fix_eds_prefix_node_span(filename, uttr, nodes)
    fix_eds_dash_node_span(filename, uttr, nodes)

    return filename, uttr, cfg_root, cfg_leaves, graph


if __name__ == '__main__':
    # Run as a script: rebuild all pickle caches from scratch with INFO-level
    # logging enabled.
    logging.basicConfig(level=logging.INFO)
    load_data(force_new=True)
