#include <algorithm>
#include <cassert>
#include <numeric>
#include <utility>
#include <vector>

#include "common/transducer/weight_second_order.h"
#include "data_driven_transducer.h"
#include "data_driven_generator_weight.h"

#include "io/token_manager.h"
#include "io/rule.h"
#include "io/utility.h"
#include "io/edsgraph.h"

namespace data_driven {

// Non-static global (external linkage), enabled by default.
// NOTE(review): not read anywhere in this file; presumably it toggles the
// transition features consulted by transitionScore(). Confirm in the weight code.
bool g_generator_use_transition{true};

// Builds a generator whose weights are loaded through DynamicWeight from
// `feature_input` / `feature_output`.
//  - n_state:          TRAIN selects non-averaged scores; every other state
//                      selects averaged scores.
//  - generator_k_best: capacity of both Viterbi beams.
DynamicRuleGenerator::DynamicRuleGenerator(const std::string& feature_input,
                                           const std::string& feature_output,
                                           TransducerState n_state,
                                           int generator_k_best)
: n_state_(n_state),
  n_score_type_(n_state != TransducerState::TRAIN
                    ? ScoreType::eAverage
                    : ScoreType::eNonAverage),
  state_beam1_(generator_k_best),
  state_beam2_(generator_k_best) {
    // Owned through a raw pointer; released in the destructor.
    dynamic_weight_ = new DynamicWeight(feature_input, feature_output);
    // The state count must fit the per-state tables bounded by MAX_STATES_COUNT.
    assert(TokenManager::stateCount() <= MAX_STATES_COUNT);
}

// Frees the weight object allocated in the constructor.
// NOTE(review): ownership is held by a raw pointer, so copying a generator
// would double-delete; consider std::unique_ptr in the class declaration.
DynamicRuleGenerator::~DynamicRuleGenerator() {
    delete this->dynamic_weight_;
}

// Locates the head of `node`'s gold rule from the first variable of its first
// equation, and stores the result in head_index_ / head_from_in_edges_.
//  - No equations, or a sentence variable: head_index_ = -1 and
//    head_from_in_edges_ = false.
//  - Variable major < number of in-states: the head is that in-state.
//  - Otherwise: the head is out-state (major - in_count). When
//    convert_to_edge_index is true, that index is mapped through
//    states_to_edges_ to an out-edge index (computeTargetIndices must already
//    have filled the map).
void DynamicRuleGenerator::computeHeadIndex(const EdsGraph::Node& node,
                                            bool convert_to_edge_index) {
    const auto& gold_rule = rule_set_ptr_->ruleAt(node.rule_index);
    const auto& equations = gold_rule.tail.equations;

    if (equations.empty() || IS_SENTENCE_VAR(equations[0][0])) {
        head_from_in_edges_ = false;
        head_index_ = -1;
        return;
    }

    const int in_count = gold_rule.head.in_states.size();
    const int major = VAR_MAJOR(equations[0][0]);

    head_from_in_edges_ = major < in_count;
    if (head_from_in_edges_) {
        head_index_ = major;
        return;
    }

    const int state_pos = major - in_count;
    head_index_ = convert_to_edge_index ? states_to_edges_[state_pos]
                                        : state_pos;
}

// Fills source_indices_ with the source node of every non-empty in-edge of
// `node`, ordered by the in-edge's state (the order the rule head uses).
// In-edge states are looked up in *edge_states_ptr_.
// Returns false when an in-edge has no recorded state, or when there are more
// non-empty in-edges than the fixed-size buffers (MAX_NODE_IN_EDGE_COUNT) hold.
bool DynamicRuleGenerator::computeSourceIndices(const EdsGraph::Node& node,
                                                int node_index) {
    int indices_buffer[MAX_NODE_IN_EDGE_COUNT];
    int position[MAX_NODE_IN_EDGE_COUNT];
    int states_buffer[MAX_NODE_IN_EDGE_COUNT];

    int in_state_count = 0;
    for (auto source_index : node.in_edges) {
        auto arc = MAKE_ARC(source_index, node_index);
        auto iter = edge_states_ptr_->find(arc);
        if (iter == edge_states_ptr_->end())
            return false;
        auto directed_state = iter->second;
        // Edges in an EMPTY state are not treated as in-states.
        if (IS_EMPTY_STATE(directed_state))
            continue;
        // Guard the fixed-size buffers instead of writing past their end.
        if (in_state_count == MAX_NODE_IN_EDGE_COUNT)
            return false;
        states_buffer[in_state_count] = directed_state;
        indices_buffer[in_state_count++] = source_index;
    }

    // Order the in-edges by their state. Capture by reference: a by-value
    // capture copies the whole array into every comparator copy std::sort makes.
    std::iota(position, position + in_state_count, 0);
    std::sort(position, position + in_state_count,
              [&states_buffer](int i, int j) {
                  return states_buffer[i] < states_buffer[j];
              });

    for (int i = 0; i < in_state_count; ++i)
        source_indices_[i] = indices_buffer[position[i]];

#ifndef NDEBUG
    // The sorted in-edge states must equal the rule head's in_states.
    std::sort(states_buffer, states_buffer + in_state_count);
    assert(( int )rule_head_ptr_->in_states.size() == in_state_count &&
           std::equal(states_buffer, states_buffer + in_state_count,
                      rule_head_ptr_->in_states.begin()));
#endif
    return true;
}

// Matches each out-edge of `node` to the first state in rule_tail.out_states
// whose name ends with the edge label, and records the mapping in:
//  - target_indices_[state] = target node of the edge
//  - states_to_edges_[state] = edge position
//  - edges_to_states_[edge]  = state position
// Returns false as soon as an edge has no matching state; entries written for
// earlier edges are kept. `node_index` is not used.
bool DynamicRuleGenerator::computeTargetIndices(const EdsGraph::Node& node,
                                                int node_index,
                                                const RuleSet::Tail& rule_tail) {
    const int state_count = static_cast< int >(rule_tail.out_states.size());
    int edge_pos = 0;
    for (const auto& edge : node.out_edges) {
        const auto& label = TokenManager::edgeLabelAt(edge.label_index);

        int state_pos = 0;
        while (state_pos < state_count &&
               !stringEndsWith(TokenManager::stateAt(rule_tail.out_states[state_pos]),
                               label))
            ++state_pos;

        if (state_pos == state_count)
            return false;

        target_indices_[state_pos] = edge.target_index;
        states_to_edges_[state_pos] = edge_pos;
        edges_to_states_[edge_pos] = state_pos;
        ++edge_pos;
    }
    return true;
}

// Perceptron update for the out-state sequence of `node`.
// gold_states_ is in edge order. The predicted states in rule_tail_ are mapped
// to edges via edges_to_states_. For every mismatching edge, the gold state
// feature gets +1 and the predicted one -1. Transition features are updated
// whenever either adjacent pair differs. Counts correct states in
// n_state_correct_.
//
// Both state and transition features use the same MAKE_STATE_INDEX(is_head, s)
// encoding as decodeVertibi(), so the update touches exactly the features the
// decoder scored.
void DynamicRuleGenerator::updateRuleScore(const EdsGraph::Node& node,
                                           int node_index) {
    int out_state_count = node.out_edges.size();

    assert(out_state_count == ( int )rule_tail_.out_states.size());

    int last_gold_state = -1;
    int last_predict_state = -1;
    bool last_is_head = false;
    for (int i = 0; i < out_state_count; ++i) {
        auto gold_state = STATE_INDEX(gold_states_[i]);
        auto predict_state =
            STATE_INDEX(rule_tail_.out_states[edges_to_states_[i]]);
        bool is_head = !head_from_in_edges_ && head_index_ == i;

        if (gold_state != predict_state) {
            stateScore(node, node_index, i,
                       MAKE_STATE_INDEX(is_head, gold_state), 1);
            stateScore(node, node_index, i,
                       MAKE_STATE_INDEX(is_head, predict_state), -1);
        } else
            ++n_state_correct_;

        // When both adjacent pairs agree, the +1/-1 updates would cancel out.
        if (i != 0 &&
            (last_gold_state != last_predict_state ||
             gold_state != predict_state)) {
            transitionScore(node, node_index, i,
                            MAKE_STATE_INDEX(last_is_head, last_gold_state),
                            MAKE_STATE_INDEX(is_head, gold_state), 1);
            transitionScore(node, node_index, i,
                            MAKE_STATE_INDEX(last_is_head, last_predict_state),
                            MAKE_STATE_INDEX(is_head, predict_state), -1);
        }
        last_gold_state = gold_state;
        last_predict_state = predict_state;
        last_is_head = is_head;
    }
}

// Perceptron update for the variable order of the predicted equation.
// Only variables that occur in both the gold and the predicted first equation
// are compared. For each one whose rank in the prediction differs from its
// rank in gold, its variable score is raised (gold places it later) or lowered
// (gold places it earlier). Counts fully correct orders in n_equation_correct_;
// an empty prediction counts as correct.
void DynamicRuleGenerator::updateEquationScore(const EdsGraph::Node& node,
                                               int node_index) {
    if (rule_tail_.equations.empty()) {
        n_equation_correct_++;
        return;
    }

    auto& rule = rule_set_ptr_->ruleAt(node.rule_index);
    auto& gold_equation = rule.tail.equations[0];
    auto& predict_equation = rule_tail_.equations[0];

    assert(gold_equation[0] == predict_equation[0]);

    // pos_list[k] = rank (order of appearance in gold) of the k-th predicted
    // variable, or -1 if that variable does not occur in gold.
    int index = 0;
    int total_count = predict_equation.size();
    std::vector< int > pos_list(total_count, -1);
    auto beg = predict_equation.begin();
    auto end = predict_equation.end();
    for (auto var : gold_equation) {
        auto pos = std::find(beg, end, var);
        if (pos != end)
            pos_list[pos - beg] = index++;
    }

    bool correct = true;
    // Variables absent from gold are treated as if they did not exist.
    int predict_index = 0;
    for (int i = 0; i < total_count; ++i) {
        auto gold_index = pos_list[i];
        if (gold_index == -1)
            continue;
        if (predict_index != gold_index) {
            correct = false;
            auto var = predict_equation[i];
            int delta = (gold_index > predict_index) ? 1 : -1;
            // Use exactly the keys generateEquation() scores with, so the
            // update reaches the features used at prediction time.
            if (IS_LEMMA_VAR(var)) {
                variableScore(node, node_index, LEMMA_VAR, -1, delta);
            } else {
                int var_major = VAR_MAJOR(var);
                variableScore(node, node_index, var_major, VAR_MINOR(var), delta);
            }
        }
        predict_index++;
    }

    if (correct)
        n_equation_correct_++;
}

// Viterbi decoding of one state per out-edge of `node`, in edge order.
// The score of a path is the sum of state scores plus, for every edge after
// the first, the transition score from the previous state. The head edge
// (head_index_, when the head is not an in-edge) uses head-flagged state keys.
// Back-pointers go to vertibi_bp_. The best path is written to
// rule_tail_.out_states (which must already be sized to the out-state count)
// and then sorted by state. Its score is stored in states_score_.
void DynamicRuleGenerator::decodeVertibi(const EdsGraph::Node& node,
                                         int node_index) {
    int out_state_count = rule_head_ptr_->out_state_count;

    auto beam_ptr1 = &state_beam1_; // beam for the previous edge
    auto beam_ptr2 = &state_beam2_; // beam being filled for the current edge
    state_beam2_.clear();
    // Seed with an empty path: score 0 and no previous state.
    state_beam1_.shrinkToOne() = {0, -1};

    auto& node_label = TokenManager::nodeLabelAt(node.label_index);
    for (int i = 0; i < out_state_count; ++i) {
        auto& edge = node.out_edges[i];
        auto& edge_label = TokenManager::edgeLabelAt(edge.label_index);
        bool is_head = !head_from_in_edges_ && head_index_ == i;
        bool prev_is_head = !head_from_in_edges_ && head_index_ == (i - 1);
        // TODO: maybe need random access states
        // A candidate state j for edge i must:
        // 1. match the edge label;
        // 2. be allowed for this node label (special nodes only accept some
        //    labels, e.g. loc_nonsp:ARG1);
        // 3. not be an empty state when the edge is the head.
        for (auto j : TokenManager::statesOfEdge(edge_label)) {
            if (!TokenManager::stateMatch(j, node_label) ||
                (j < EMPTY_STATE_COUNT && is_head))
                continue;
            auto state_j = MAKE_STATE_INDEX(is_head, j);
            // The state score does not depend on the predecessor: compute once.
            tscore state_score = stateScore(node, node_index, i, state_j, 0);
            tscore local_best_score = INT64_MIN;
            int local_last_state = -1;
            for (auto item : *beam_ptr1) {
                tscore score = item->first + state_score;
                if (i > 0) { // transition score
                    auto last_state_j = MAKE_STATE_INDEX(prev_is_head, item->second);
                    score += transitionScore(node, node_index, i,
                                             last_state_j, state_j, 0);
                }
                if (score > local_best_score) {
                    local_best_score = score;
                    local_last_state = item->second;
                }
            }
            beam_ptr2->insert({local_best_score, j});
            vertibi_bp_[i][j] = local_last_state;
        }

        if (beam_ptr2->size() == 0) {
            LOG_DEBUG(<< "Edge label: "
                      << edge_label << " of " << node
                      << " doesn't occur in training data");
            int j = TokenManager::indexOfState("X:" + edge_label);
            // NOTE(review): assumes indexOfState yields a valid (>= 0) index
            // for the "X:<label>" fallback state.
            assert(j >= 0);
            // Pick the predecessor with the highest path score explicitly by
            // score (not by comparing beam handles), and carry that score so
            // the path total stays comparable across decodings.
            tscore best_prev_score = INT64_MIN;
            int best_prev_state = -1;
            for (auto item : *beam_ptr1) {
                if (item->first > best_prev_score) {
                    best_prev_score = item->first;
                    best_prev_state = item->second;
                }
            }
            beam_ptr2->shrinkToOne() = {best_prev_score, j};
            vertibi_bp_[i][j] = best_prev_state;
        }

        beam_ptr1->clear();
        std::swap(beam_ptr1, beam_ptr2);
    }

    auto& score_state_pair = beam_ptr1->best();
    states_score_ = score_state_pair.first;
    // Rebuild the best path by following the back-pointers.
    int state_index = score_state_pair.second;
    for (int i = out_state_count - 1; i >= 0; --i) {
        bool is_head = !head_from_in_edges_ && head_index_ == i;
        rule_tail_.out_states[i] = TokenManager::makeState(state_index, !is_head);
        state_index = vertibi_bp_[i][state_index];
    }

    // Rules store out-states in state order, not edge order.
    std::sort(rule_tail_.out_states.begin(),
              rule_tail_.out_states.end());
}

// Builds rule_tail_.equations for `node` from the in-states of the rule head
// and the out-states in rule_tail_.
// Every variable of every non-head state, plus LEMMA_VAR, is scored with
// variableScore(). The variables are then ordered by ascending score and
// appended to the first equation. That equation starts with SENTENCE_VAR when
// head_index_ == -1; otherwise there is one equation per variable of the head
// state. Nothing is produced for a sentence with no variables.
void DynamicRuleGenerator::generateEquation(const EdsGraph::Node& node,
                                            int node_index) {
    // Scratch buffers reused across calls (not thread-safe, like the rest of
    // this object). They grow as needed: a state may contribute several
    // variables, so the total is not bounded by the edge counts. Scores are
    // kept as tscore so large weights are not truncated.
    static std::vector< tscore > scores;
    static std::vector< int > indices;
    static std::vector< int > variables;
    scores.clear();
    variables.clear();

    rule_tail_.equations.clear();
    int in_state_count = rule_head_ptr_->in_states.size();
    int out_state_count = rule_tail_.out_states.size();

    for (int i = 0; i < in_state_count; ++i) {
        if (head_from_in_edges_ && head_index_ == i)
            continue; // the head's variables start the equation instead
        int var_count = STATE_VAR_COUNT(rule_head_ptr_->in_states[i]);
        for (int j = 0; j < var_count; ++j) {
            variables.push_back(MAKE_VAR(i, j));
            scores.push_back(variableScore(node, node_index, i, j, 0));
        }
    }

    for (int i = 0; i < out_state_count; ++i) {
        // head_index_ is an edge index; out-states are in state order.
        if (!head_from_in_edges_ && head_index_ == states_to_edges_[i])
            continue;
        int var_count = STATE_VAR_COUNT(rule_tail_.out_states[i]);
        for (int j = 0; j < var_count; ++j) {
            variables.push_back(MAKE_VAR(i + in_state_count, j));
            scores.push_back(variableScore(node, node_index,
                                           i + in_state_count, j, 0));
        }
    }

    if (variables.empty() && head_index_ < 0)
        return;

    variables.push_back(LEMMA_VAR);
    scores.push_back(variableScore(node, node_index, LEMMA_VAR, -1, 0));

    const int total_count = variables.size();

    // Order the variables by ascending score.
    indices.resize(total_count);
    std::iota(indices.begin(), indices.end(), 0);
    std::sort(indices.begin(), indices.end(),
              [](int i, int j) {
                  return scores[i] < scores[j];
              });

    if (head_index_ == -1) { // sentence
        rule_tail_.equations.push_back({SENTENCE_VAR});
    } else {
        int var_count, var_major;
        if (head_from_in_edges_) {
            var_major = head_index_;
            var_count = STATE_VAR_COUNT(rule_head_ptr_->in_states[var_major]);
        } else {
            var_major = edges_to_states_[head_index_];
            var_count = STATE_VAR_COUNT(rule_tail_.out_states[var_major]);
            var_major += in_state_count;
        }
        for (int var_minor = 0; var_minor < var_count; ++var_minor)
            rule_tail_.equations.push_back({MAKE_VAR(var_major, var_minor)});
    }

    // A head state without variables yields no equation to attach the body
    // to; front() on an empty container would be undefined behavior.
    if (rule_tail_.equations.empty())
        return;

    auto& equation = rule_tail_.equations.front();
    for (int i = 0; i < total_count; ++i)
        equation.push_back(variables[indices[i]]);
}

// Predicts the out-states of `node` with decodeVertibi(), using the current
// head_index_ / head_from_in_edges_. It then rebuilds the edge<->state index
// maps for the prediction. rule_tail_.out_states is left empty when the rule
// head declares no out-states.
void DynamicRuleGenerator::generateRuleTail(const EdsGraph::Node& node,
                                            int node_index) {
    auto& predicted_states = rule_tail_.out_states;
    predicted_states.clear();

    const int state_total = rule_head_ptr_->out_state_count;
    if (state_total != 0) {
        predicted_states.resize(state_total);
        assert(state_total < MAX_NODE_OUT_EDGE_COUNT);
        decodeVertibi(node, node_index);
        // Return value ignored: decoded states are drawn from each edge
        // label's states, so they presumably always match.
        computeTargetIndices(node, node_index, rule_tail_);
    }
}

// Like generateRuleTail(), but when the head is not an in-edge it also chooses
// the head. Every out-edge is tried as head, plus -1 (a sentence). The
// decoding with the highest Viterbi score is kept. On return, head_index_,
// rule_tail_.out_states and states_score_ all describe that best decoding.
void DynamicRuleGenerator::generateRuleTailAuto(const EdsGraph::Node& node,
                                                int node_index) {
    rule_tail_.out_states.clear();

    int out_state_count = rule_head_ptr_->out_state_count;
    if (out_state_count == 0)
        return;

    rule_tail_.out_states.resize(out_state_count);

    assert(out_state_count < MAX_NODE_OUT_EDGE_COUNT);

    if (!head_from_in_edges_) {
        int best_head_index = -1;
        int states_buffer[MAX_NODE_OUT_EDGE_COUNT];
        tscore best_score = INT64_MIN;
        // Try every possible head; head_index_ = -1 builds a sentence.
        for (int i = -1; i < out_state_count; ++i) {
            head_index_ = i;
            decodeVertibi(node, node_index);
            // Always accept the first candidate so states_buffer is never
            // read uninitialized.
            if (i == -1 || states_score_ > best_score) {
                best_score = states_score_;
                best_head_index = head_index_;
                std::copy_n(rule_tail_.out_states.begin(),
                            out_state_count,
                            states_buffer);
            }
        }

        head_index_ = best_head_index;
        // Keep the score of the chosen decoding, not the last one tried.
        states_score_ = best_score;
        std::copy_n(states_buffer,
                    out_state_count,
                    rule_tail_.out_states.begin());
    } else
        decodeVertibi(node, node_index);

    computeTargetIndices(node, node_index, rule_tail_);
}

// Logs training statistics for the current round:
//  - state accuracy;
//  - equation accuracy over generated equations;
//  - equation coverage over all nodes.
// Ratios with a zero denominator are printed as 0 instead of NaN/inf
// (e.g. before any equation has been generated).
void DynamicRuleGenerator::printInformation() {
    const auto ratio = [](float numerator, float denominator) -> float {
        return denominator == 0 ? 0.0f : numerator / denominator;
    };

    LOG_INFO(<< "Round " << n_round_);
    LOG_INFO(<< "States (correct/total) = "
             << n_state_correct_ << '/'
             << n_total_states_ << ' '
             << ratio(n_state_correct_, n_total_states_));
    LOG_INFO(<< "Nodes (eq-correct/eq-covered/total) = "
             << n_equation_correct_ << '/'
             << n_equation_generated_ << '/'
             << n_total_nodes_ << ' '
             << ratio(n_equation_correct_, n_equation_generated_) << ' '
             << ratio(n_equation_generated_, n_total_nodes_));
    std::cout.flush();
}

// Runs one perceptron training pass over `graph`, using `rule_set` as gold.
// 1. Every arc is mapped to the gold out-state of its source node that matches
//    the arc label. edge_states_ptr_ exposes this map for the duration of the
//    call only.
// 2. For each node:
//    - predict its out-states and update state/transition weights;
//    - then, with the gold out-states, predict the equation and update
//      variable weights (only for rules with exactly one equation).
// Statistics are printed every OUTPUT_STEP rounds.
void DynamicRuleGenerator::trainRuleGenerator(const EdsGraph& graph,
                                              const RuleSet& rule_set,
                                              int round) {
    static int source_indices[MAX_NODE_IN_EDGE_COUNT];

    rule_set_ptr_ = &rule_set;
    graph_ptr_ = &graph;
    n_round_ = round;

    EdgeStateMap edge_states;
    edge_states_ptr_ = &edge_states;
    source_indices_ = source_indices;

    LOG_DEBUG(<< "Train " << graph.filename);

    for (auto& item : graph.edges) {
        auto& source = graph.nodes[ARC_SRC(item.first)];
        auto& rule = rule_set.ruleAt(source.rule_index);
        auto& edge_label = TokenManager::edgeLabelAt(item.second);

        for (auto state : rule.tail.out_states) {
            if (stringEndsWith(TokenManager::stateAt(state), edge_label)) {
                edge_states[item.first] = state;
                // First match wins, the same policy computeTargetIndices()
                // uses for the gold states below.
                break;
            }
        }
    }

    int node_count = graph.nodes.size();
    for (int node_index = 0; node_index < node_count; ++node_index) {
        auto& node = graph.nodes[node_index];

        auto& rule = rule_set.ruleAt(node.rule_index);
        rule_head_ptr_ = &rule.head;

        if (!computeTargetIndices(node, node_index, rule.tail)) {
            LOG_ERROR(<< "Invalid training data " << graph.filename
                      << ": can not match out states and out edges");
            break;
        }

        computeHeadIndex(node, true /*convert_to_edge_index*/);

        n_total_states_ += node.out_edges.size();
        ++n_total_nodes_;

        // Gold states in edge order.
        for (int i = 0; i < rule.head.out_state_count; ++i)
            gold_states_[i] = rule.tail.out_states[edges_to_states_[i]];

        generateRuleTail(node, node_index);
        updateRuleScore(node, node_index);

        // Equations are trained on the gold out-states.
        rule_tail_.out_states = rule.tail.out_states;

        if (rule.tail.equations.size() != 1 ||
            !computeTargetIndices(node, node_index, rule_tail_))
            continue;

        generateEquation(node, node_index);
        ++n_equation_generated_;
        updateEquationScore(node, node_index);
    }

    // edge_states is about to be destroyed; do not leave a dangling pointer.
    edge_states_ptr_ = nullptr;

    if (round % OUTPUT_STEP == 0)
        printInformation();
}
}
