#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <limits>
#include <random>

#include "include/basic.h"
#include "edsgraph.h"
#include "token_manager.h"


// Reads one graph record from `is` into `graph`.
//
// Fields are consumed in this order: a whitespace-delimited file name (the
// rest of that line is discarded), one full line holding the sentence, the
// node count, and then one entry per node:
//   label, lemma (everything from the first '/' onwards is dropped),
//   POS tag (read as an integer), sense, carg, five property values,
//   out-edge count followed by that many (edge label, target node index)
//   pairs, and finally the rule index.
//
// Every arc is also recorded in graph.edges under MAKE_ARC(source, target),
// each node's out-edges are sorted by edge-label text, and in_edges is rebuilt
// from the out-edges once all nodes have been read.
//
// If the header cannot be read, or the sentence line is empty, the function
// returns immediately and leaves nodes/edges untouched (the stream state is
// whatever the extractions produced). If the record is malformed further in
// (negative counts, failed extraction inside a node, or an edge target
// outside [0, node_count)), failbit is set on `is` and graph.nodes and
// graph.edges are emptied so the caller never receives dangling indices.
std::istream& operator>>(std::istream& is, EdsGraph& graph) {
    int node_count = 0;
    std::string token;

    is >> graph.filename;
    is.ignore(std::numeric_limits< std::streamsize >::max(), '\n');
    std::getline(is, graph.sentence);
    is >> node_count;

    if (!is || graph.sentence.empty()) return is;

    // Flags the stream as failed and discards the half-built graph.
    const auto reject = [&is, &graph]() -> std::istream& {
        graph.nodes.clear();
        graph.edges.clear();
        is.setstate(std::ios_base::failbit);
        return is;
    };

    if (node_count < 0) return reject();

    // The same EdsGraph object may be filled more than once; drop arcs left
    // over from a previous record before adding the new ones.
    graph.edges.clear();
    graph.nodes.resize(node_count);
    for (int i = 0; i < node_count; ++i) {
        auto& node = graph.nodes[i];
        // Zero-initialised: a failed stream leaves extraction targets
        // unmodified, and these values are used right after being read.
        int edge_count = 0;
        int pos_tag = 0;

        is >> token;
        node.label_index = TokenManager::indexOfNodeLabel(token);

        // Labels that do not start with '_' are flagged as special.
        node.is_special = token.empty() || token[0] != '_';

        is >> token;
        auto pos = token.find('/');
        if (pos != std::string::npos)
            token.erase(pos);
        node.lemma_index = TokenManager::indexOfLemma(token);
        is >> pos_tag;
        node.pos_tag = pos_tag;
        is >> token;
        node.sense_index = TokenManager::indexOfSense(token);
        is >> node.carg;

        // Separate index name: the outer `i` is still needed for MAKE_ARC.
        for (int k = 0; k < 5; ++k)
            is >> node.properties[k];

        is >> edge_count;
        if (!is || edge_count < 0) return reject();

        node.out_edges.resize(edge_count);
        for (int j = 0; j < edge_count; ++j) {
            auto& edge = node.out_edges[j];
            is >> token;
            edge.label_index = TokenManager::indexOfEdgeLabel(token);
            is >> edge.target_index;

            // The target is later used to index graph.nodes, so it must name
            // an existing node.
            const long long target = static_cast< long long >(edge.target_index);
            if (!is || target < 0 || target >= node_count) return reject();

            graph.edges[MAKE_ARC(i, edge.target_index)] = edge.label_index;
        }

        // Sort the out-edges by label in lexicographic order.
        std::sort(node.out_edges.begin(), node.out_edges.end(),
                  [](const EdsGraph::Edge& e1, const EdsGraph::Edge& e2) {
                      return TokenManager::edgeLabelAt(e1.label_index) <
                             TokenManager::edgeLabelAt(e2.label_index);
                  });

        is >> node.rule_index;
        if (!is) return reject();

        node.in_edges.clear();
    }

    // Second pass: derive in_edges by inverting the out-edges. Every target
    // index was range-checked while reading.
    for (int i = 0; i < node_count; ++i) {
        auto& node = graph.nodes[i];
        for (auto& edge : node.out_edges)
            graph.nodes[edge.target_index].in_edges.push_back(i);
    }
    return is;
}

// Writes a multi-line text dump of `graph` to `os`: first the node count, then
// four lines per node (attributes, out-edges, in-edges, rule index).
std::ostream& operator<<(std::ostream& os, const EdsGraph& graph) {
    os << graph.nodes.size() << '\n';

    for (const auto& current : graph.nodes) {
        // Attribute line: label, lemma, POS tag (printed as a character), sense.
        os << "Label = '" << TokenManager::nodeLabelAt(current.label_index) << "' ";
        os << "Lemma = '" << TokenManager::lemmaAt(current.lemma_index) << "' ";
        os << "POSTag = '" << static_cast< char >(current.pos_tag) << "' ";
        os << "Sense = '" << TokenManager::senseAt(current.sense_index) << "' " << '\n';

        // Out-edge line: count, then one "<label, target> " entry per edge.
        os << "OutEdges: " << current.out_edges.size() << ' ';
        for (const auto& arc : current.out_edges) {
            os << '<' << TokenManager::edgeLabelAt(arc.label_index);
            os << ", " << arc.target_index << "> ";
        }
        os << '\n';

        // In-edge line: count, then the source node indices.
        os << "InEdges: " << current.in_edges.size() << ' ';
        for (const auto& from : current.in_edges) {
            os << from << ' ';
        }
        os << '\n';

        os << "RuleIndex = " << current.rule_index << '\n';
    }

    return os;
}

// Builds the text form of a node: its lemma immediately followed by its node
// label, then the carg wrapped in parentheses.
std::string EdsGraph::Node::toString() const {
    std::string text(TokenManager::lemmaAt(lemma_index));
    text += TokenManager::nodeLabelAt(label_index);
    text += '(';
    text += carg;
    text += ')';
    return text;
}

// Streams a node using the text produced by Node::toString().
std::ostream& operator<<(std::ostream& os, const EdsGraph::Node& node) {
    os << node.toString();
    return os;
}

// Loads graphs from `input_file` and appends the usable ones to `graphs`.
//
// The file starts with the number of graph records, followed by the records
// themselves (parsed with operator>> for EdsGraph). A graph containing any
// node whose rule_index is -1 is reported and skipped. For every graph that is
// kept, its position in `graphs` is appended to `random`; afterwards `random`
// is shuffled and then stably sorted by node count, so it lists graphs from
// smallest to largest with a random order among graphs of equal size.
//
// Nothing is appended if the file cannot be opened or the leading count is
// missing or negative. Reading stops at the first record that leaves the
// stream in a failed state; graphs read before that point are kept.
void loadGraphsFromFile(const std::string& input_file,
                        std::vector< EdsGraph >& graphs,
                        std::vector< int >& random) {
    LOG_INFO(<< "Loading Graphs ...");
    std::ifstream is(input_file);
    if (!is) {
        LOG_ERROR(<< "Can't find file " << input_file);
        return;
    }

    // Zero-initialised and validated: on an empty file the extraction leaves
    // the variable unmodified, and a negative value must not reach reserve().
    int graph_count = 0;
    if (!(is >> graph_count) || graph_count < 0) {
        LOG_ERROR(<< "Invalid graph count in " << input_file);
        return;
    }

    graphs.reserve(graph_count);
    random.reserve(graph_count);

    int loaded = 0;
    for (int i = 0; i < graph_count; ++i) {
        // A fresh object per record, so nothing from a rejected or moved-from
        // graph can carry over into the next one.
        EdsGraph graph;
        is >> graph;
        if (!is) {
            // Every further extraction would fail as well; stop instead of
            // appending graphs that were never read.
            LOG_ERROR(<< "Failed to read graph " << i << " of " << graph_count
                      << " from " << input_file);
            break;
        }

        if (std::any_of(graph.nodes.begin(), graph.nodes.end(),
                        [](const auto& node) { return node.rule_index == -1; }))
            LOG_ERROR(<< "Invalid training graph");
        else {
            random.push_back(static_cast< int >(graphs.size()));
            graphs.push_back(std::move(graph));
            ++loaded;
        }
    }

    is.close();

    // std::random_shuffle was removed in C++17. Seeding the engine from
    // std::rand() keeps the order under the control of std::srand().
    std::mt19937 rng(static_cast< std::mt19937::result_type >(std::rand()));
    std::shuffle(random.begin(), random.end(), rng);
    std::stable_sort(random.begin(), random.end(),
                     [&graphs](int i, int j) {
                         return graphs[i].nodes.size() < graphs[j].nodes.size();
                     });
    // Report what was actually kept, not the count announced by the file.
    LOG_INFO(<< "Loaded " << loaded << " Graphs");
}
