#include <cmath>
#include <stack>
#include <algorithm>
#include <ctime>
#include <cstdlib>
#include <unordered_set>
#include <numeric>

#include "common/transducer/weight_second_order.h"
#include "data_driven_transducer.h"

#include "io/token_manager.h"
#include "io/rule.h"
#include "io/utility.h"
#include "io/edsgraph.h"
#include "io/drawer.h"

namespace data_driven {
// Global switches of the transducer; both are off unless set elsewhere.
// g_predict_rule_head: when false and no head is known from the in-edges,
//   decodeDynamicPart2 tries every head position instead of letting the
//   generator choose one.
bool g_predict_rule_head{false};
// g_special_nodes_order: when true, decode() orders nodes with
//   initializeStackSpecial() instead of initializeStack().
bool g_special_nodes_order{false};

// Creates a transducer in state `n_state` whose three beams are all
// constructed with `beam_size`. The second-order feature weights are read
// from `feature_input` / written to `feature_output` by WeightSecondOrder
// (see that class for the exact semantics). The weight object is allocated
// here and deleted in the destructor.
BeamSearchTransducer::BeamSearchTransducer(const std::string& feature_input,
                                           const std::string& feature_output,
                                           TransducerState n_state,
                                           int beam_size)
    : DAGTransducerBase(n_state),
      solver(this),
      generator_ptr_(nullptr),
      running_time_(0),
      return_value_(0),
      beam1_(beam_size),
      beam2_(beam_size),
      beam3_(beam_size) {
    weight_ = new WeightSecondOrder(feature_input, feature_output);

    // The solver was bound to this transducer above; initialize it now that
    // the weights have been created.
    solver.initialize();
}

// Releases the weight object allocated with `new` in the constructor.
BeamSearchTransducer::~BeamSearchTransducer() {
    delete weight_;
}

// Applies the training update for every node recorded in nodes_update_,
// comparing the rule chosen by the current (best) status with the gold rule
// stored on the graph node:
//   - nodes without a predicted rule (-1) are skipped entirely;
//   - otherwise the node counts as covered, and either as correct or the
//     gold rule's features get +1 and the predicted rule's features get -1.
// nodes_update_ is emptied afterwards.
void BeamSearchTransducer::update() {
    if (current_status_ptr_ == nullptr)
        return;

    for (auto node_index : nodes_update_) {
        const int predicted = current_status_ptr_->rule_choices[node_index];
        if (predicted == -1)
            continue;

        ++n_covered_nodes_;
        const int gold = graph_ptr_->nodes[node_index].rule_index;
        if (gold == predicted) {
            ++n_correct_nodes_;
        } else {
            getOrUpdateScore(node_index, gold, 1);       // reward gold rule
            getOrUpdateScore(node_index, predicted, -1); // penalize prediction
        }
    }
    nodes_update_.clear();
}

// For every node of the graph, recomputes the mapping between its in-/out-edge
// states and the neighbouring nodes (source_indices_ / target_indices_), using
// the rule chosen for that node in the current status. Every node must have a
// rule assigned.
void BeamSearchTransducer::computeAllIndices() {
    const int node_count = graph_ptr_->nodes.size();
    for (int node_index = 0; node_index < node_count; ++node_index) {
        const auto& node = graph_ptr_->nodes[node_index];
        const int chosen_rule = current_status_ptr_->rule_choices[node_index];
        assert(chosen_rule != -1);

        computeSourceIndices(node, node_index);
        rule_tail_ptr_ = &ruleAt(chosen_rule).tail;
        computeTargetIndices(node, node_index);
    }
}

// Runs the solver on the current status and stores its output in result_.
// Does nothing when decoding produced no status.
void BeamSearchTransducer::solve() {
    if (current_status_ptr_ == nullptr)
        return;

    computeAllIndices();
    LOG_DEBUG(<< graph_ptr_->filename << ":");
    solver.solve(result_);
    LOG_DEBUG(<< result_);
}

// Computes the set of in-edge states of `node` and which parent node each
// state belongs to.
// The non-EMPTY states on the in-edges (read from the current status) are
// sorted ascending into rule_head_.in_states; source_indices_[node_index][i]
// is the source node of the i-th sorted state, and in_state_counts_[node_index]
// receives the number of such states.
void BeamSearchTransducer::computeSourceIndices(const EdsGraph::Node& node,
                                                int node_index) {
    int indices_buffer[MAX_NODE_IN_EDGE_COUNT];
    int position[MAX_NODE_IN_EDGE_COUNT];
    int states_buffer[MAX_NODE_IN_EDGE_COUNT];

    int in_state_count = 0;
    for (int source_index : node.in_edges) {
        auto arc = MAKE_ARC(source_index, node_index);

        assert(current_status_ptr_->edge_states.find(arc) !=
               current_status_ptr_->edge_states.end());

        auto directed_state = current_status_ptr_->edge_states[arc];
        // Edges in an EMPTY state are ignored.
        if (IS_EMPTY_STATE(directed_state))
            continue;
        // The fixed-size buffers above cannot hold more entries than this.
        assert(in_state_count < MAX_NODE_IN_EDGE_COUNT);
        states_buffer[in_state_count] = directed_state;
        indices_buffer[in_state_count++] = source_index;
    }

    in_state_counts_[node_index] = in_state_count;

    int* node_source_indices = source_indices_[node_index];
    // Sort a permutation by state so that each state stays paired with its
    // source node. The array is captured by reference: a by-value capture
    // copies the whole buffer every time std::sort copies the comparator.
    std::iota(position, position + in_state_count, 0);
    std::sort(position, position + in_state_count,
              [&states_buffer](int i, int j) {
                  return states_buffer[i] < states_buffer[j];
              });

    auto& in_states = rule_head_.in_states;
    in_states.resize(in_state_count);
    for (int i = 0; i < in_state_count; ++i) {
        node_source_indices[i] = indices_buffer[position[i]];
        in_states[i] = states_buffer[position[i]];
    }

    // Defensive: in_states was filled in ascending state order above.
    std::sort(in_states.begin(), in_states.end());
}

// Computes the set of out-edge states of `node` and which child node each
// state belongs to.
// For every out-edge, the first out-state of the current rule tail
// (*rule_tail_ptr_) whose name ends with the edge label is chosen (e.g. a
// state NP:ARG1 can only sit on an ARG1 edge). The edge's target is written to
// target_indices_[node_index] at the position of that out-state, and the state
// to local_edge_states_ at the position of the edge.
// Returns false as soon as one out-edge has no matching out-state.
// Assumes no node has two out-edges with the same label, so every edge maps
// to a single out-state.
bool BeamSearchTransducer::computeTargetIndices(const EdsGraph::Node& node,
                                                int node_index) {
    int* const targets = target_indices_[node_index];
    const int edge_count = node.out_edges.size();

    for (int edge_pos = 0; edge_pos < edge_count; ++edge_pos) {
        const auto& edge = node.out_edges[edge_pos];
        const auto& label = TokenManager::edgeLabelAt(edge.label_index);

        bool matched = false;
        int slot = 0;
        for (auto state : rule_tail_ptr_->out_states) {
            if (stringEndsWith(TokenManager::stateAt(state), label)) {
                targets[slot] = edge.target_index;
                local_edge_states_[edge_pos] = state;
                matched = true;
                break;
            }
            ++slot;
        }

        if (!matched)
            return false;
    }
    return true;
}

// Fills rule_head_ for `node` in the current status: its label, its number of
// out-edges, and (through computeSourceIndices) the sorted in-states together
// with their source nodes.
void BeamSearchTransducer::computeRuleHead(const EdsGraph::Node& node,
                                           int node_index) {
    rule_head_.label_index = node.label_index;
    rule_head_.out_state_count = node.out_edges.size();
    computeSourceIndices(node, node_index);
}

// Checks whether the current rule tail (*rule_tail_ptr_) can be applied to
// `node` in the current status, and records
//   - is_current_rule_top_: the rule produces the sentence variable;
//   - stream_target_index_: the node the rule's data stream flows into
//     (-1 when none is determined here).
// Returns false when
//   - the out-edges cannot be matched to the rule's out-states;
//   - the rule produces a sentence although a top node was already chosen
//     (sentences are assumed to be composed only at the top node);
//   - following the existing data stream from the target leads back to
//     `node`, i.e. the rule would create a cycle.
bool BeamSearchTransducer::computeRuleTail(const EdsGraph::Node& node,
                                           int node_index) {
    is_current_rule_top_ = false;
    stream_target_index_ = -1;

    if (!computeTargetIndices(node, node_index))
        return false;

    auto& equations = rule_tail_ptr_->equations;
    if (equations.empty()) // no variable edges: nothing more to check
        return true;

    // The variable defined by the first equation.
    int var = equations[0][0];

    if (IS_SENTENCE_VAR(var)) {
        if (current_status_ptr_->top_index != -1)
            return false;
        is_current_rule_top_ = true;
        return true;
    }

    // Target node of the data stream.
    int target_index = targetOrSourceIndex(node_index, VAR_MAJOR(var));
    stream_target_index_ = target_index;

    // No earlier rule streams into this node, so no cycle can be closed.
    if (!current_status_ptr_->node_reached.test(node_index))
        return true;

    // Walk the existing stream from the target. The index is checked before
    // it is used, so a negative target is never used to index stream_target.
    while (target_index >= 0) {
        if (target_index == node_index)
            return false; // cycle
        target_index = current_status_ptr_->stream_target[target_index];
    }
    return true;
}

// Writes the effect of applying the current rule tail to `node` into
// next_status_: the top index (for sentence-producing rules), the reached
// flag of the stream target, the stream target and rule choice of the node,
// and the states of all out-edges (taken from local_edge_states_).
void BeamSearchTransducer::computeNextStatus(const EdsGraph::Node& node,
                                             int node_index) {
    auto& status = next_status_;

    if (is_current_rule_top_)
        status.top_index = node_index;

    if (stream_target_index_ != -1)
        status.node_reached.set(stream_target_index_);

    // No stream target may have been recorded for this node yet.
    assert(status.stream_target[node_index] == -1);
    status.stream_target[node_index] = stream_target_index_;
    status.rule_choices[node_index] = rule_tail_ptr_->rule_index;

    const int edge_count = node.out_edges.size();
    for (int k = 0; k < edge_count; ++k) {
        const auto& edge = node.out_edges[k];
        auto arc = MAKE_ARC(node_index, edge.target_index);
        status.edge_states[arc] = local_edge_states_[k];
    }
}

int BeamSearchTransducer::computeRuleHeadIndex(int node_index,
                                               int* head_indices) {
    int in_state_count = rule_head_.in_states.size();
    int* node_source_indices = source_indices_[node_index];
    int head_count = 0;
    for (int i = 0; i < in_state_count; ++i) {
        auto arc = MAKE_ARC(node_source_indices[i], node_index);

        assert(current_status_ptr_->edge_states.find(arc) !=
               current_status_ptr_->edge_states.end());

        if (IS_EDGE_REVERSED(current_status_ptr_->edge_states[arc])) {
            if (head_indices == nullptr)
                return i;
            head_indices[head_count++] = i;
        }
    }
    if (head_indices == nullptr)
        return -1;
    return head_count;
}

// Tries every candidate rule tail for `node` on top of the current status and
// inserts each applicable result into *beam_ptr with its accumulated score.
// When rules_are_extended is true, the candidates come from the extended rule
// set and are scored with the features of their prototype rule.
void BeamSearchTransducer::nextStep(const EdsGraph::Node& node,
                                    int node_index,
                                    Beam* beam_ptr,
                                    const RuleSet::Tails& rule_tails,
                                    bool rules_are_extended) {
    for (auto rule_tail_ptr : rule_tails) {
        rule_tail_ptr_ = rule_tail_ptr;

        if (!computeRuleTail(node, node_index))
            continue;

        return_value_ = 0;

#ifdef DECODER_CHECK
        // Debug aid: only the gold rule gets a score, so it always outranks
        // the other candidates. (Previously every candidate got 10, which
        // gave the gold rule no advantage.)
        if (rule_tail_ptr_->rule_index == node.rule_index)
            return_value_ = 10;
#else
        // Accumulates the feature score of the rule into return_value_.
        const int rule_index = rules_are_extended
                                   ? prototypeIndex(rule_tail_ptr->rule_index)
                                   : rule_tail_ptr->rule_index;
        getOrUpdateScore(node_index, rule_index, 0);
#endif
        return_value_ += current_status_ptr_->score;
        next_status_.score = return_value_;

        // The beam is asked first with only the score updated; the full
        // status is built only if it would be accepted.
        if (!beam_ptr->tryInsert(next_status_))
            continue;

        next_status_ = *current_status_ptr_;
        next_status_.score = return_value_;

        computeNextStatus(node, node_index);
        beam_ptr->insert(next_status_);
    }
}

// Applies the rule tail currently held by the dynamic rule generator to
// `node`. If it is applicable, the generator saves the rule and the resulting
// status (carrying the current status's score unchanged) is inserted into
// *beam_ptr.
void BeamSearchTransducer::nextStepDynamic(const EdsGraph::Node& node,
                                           int node_index,
                                           Beam* beam_ptr) {
    rule_tail_ptr_ = generator_ptr_->ruleTailPtr();
    if (!computeRuleTail(node, node_index))
        return;

    next_status_ = *current_status_ptr_;
    generator_ptr_->saveCurrentRule();
    computeNextStatus(node, node_index);
    beam_ptr->insert(next_status_);
}

// Generates rule tails for `node` with the dynamic rule generator and inserts
// the applicable ones into *beam_ptr.
// head_index is the position (in rule_head_.in_states) of the in-edge acting
// as head, or negative when no in-edge is the head.
//   - No known head and g_predict_rule_head is false: every head position
//     from -1 to out_state_count - 1 is tried (all go into the beam).
//   - Otherwise: a single tail from generateRuleTailAuto is tried, with the
//     head marked as coming from the in-edges when head_index >= 0.
// Returns true if the beam grew.
bool BeamSearchTransducer::decodeDynamicPart2(const EdsGraph::Node& node,
                                              int node_index,
                                              int head_index,
                                              Beam* beam_ptr) {
    // Kept in the beam's own size type and compared directly, so no
    // signed/unsigned subtraction can wrap around.
    const auto size_before = beam_ptr->size();

    generator_ptr_->setGraph(*graph_ptr_);
    generator_ptr_->setRuleHead(rule_head_);
    generator_ptr_->setSourceIndices(source_indices_[node_index]);
    generator_ptr_->setEdgeStates(current_status_ptr_->edge_states);
    if (!g_predict_rule_head && head_index < 0) {
        // Put every candidate into the beam.
        for (int i = -1; i < rule_head_.out_state_count; ++i) {
            generator_ptr_->setRuleHeadIndex(i, false /*from_in_edges*/);
            generator_ptr_->generateRuleTail(node, node_index);
            generator_ptr_->generateEquation(node, node_index);
            nextStepDynamic(node, node_index, beam_ptr);
        }
    } else {
        // Head is predicted, or comes from a (known) in-edge.
        generator_ptr_->setRuleHeadIndex(head_index, head_index >= 0 /*from_in_edges*/);
        generator_ptr_->generateRuleTailAuto(node, node_index);
        generator_ptr_->generateEquation(node, node_index);
        nextStepDynamic(node, node_index, beam_ptr);
    }
    return beam_ptr->size() > size_before;
}

// Dynamic rule generation for `node`. decode() calls this in TRANSDUCE mode
// when the rule-set lookups left *beam_ptr2 empty.
// For each status in *beam_ptr1 (until *beam_ptr2 is full), the head in-edge
// is determined:
//   - exactly one reversed in-edge: it is the head;
//   - several: the one whose source node has the smallest nodes_rank_ is
//     kept (earliest wins ties) and the other head edges are set to the
//     EMPTY state;
//   - none, with no out-edges and a top already chosen: the empty rule is
//     applied and the result goes into *beam_ptr3.
// If the data stream from the head's source leads back to this node, the
// last stream link into this node is removed.
// Statuses that needed such repairs go into *beam_ptr3; others go into
// *beam_ptr2. If *beam_ptr2 stays empty, the two beam pointers are swapped so
// the caller sees the *beam_ptr3 results through beam_ptr2.
// NOTE: the repairs mutate the statuses stored in *beam_ptr1 in place.
void BeamSearchTransducer::decodeDynamic(const EdsGraph::Node& node,
                                         int node_index,
                                         Beam* beam_ptr1,
                                         Beam*& beam_ptr2,
                                         Beam*& beam_ptr3) {
    int* node_source_indices = source_indices_[node_index];
    int head_indices[MAX_NODE_IN_EDGE_COUNT];

    beam_ptr3->clear();
    for (auto status_ptr : *beam_ptr1) {
        // Stop once the primary beam cannot take more statuses.
        if (beam_ptr2->full())
            break;

        current_status_ptr_ = status_ptr;
        computeRuleHead(node, node_index);

        Beam* current_beam_ptr = beam_ptr2;
        bool need_recompute_head = false;

        int head_index = -1;
        int head_count = computeRuleHeadIndex(node_index, head_indices);

        if (head_count == 1) { // standard case: a single head
            head_index = head_indices[0];
        } else if (head_count > 1) { // handle multiple heads
            head_index = head_indices[0];
            current_beam_ptr = beam_ptr3;
            need_recompute_head = true;

            // Keep the head whose source node has the smallest rank.
            int head_source_index = node_source_indices[head_index];
            for (int i = 0; i < head_count; ++i) {
                int source_index = node_source_indices[head_indices[i]];
                if (nodes_rank_[source_index] < nodes_rank_[head_source_index]) {
                    head_index = head_indices[i];
                    head_source_index = source_index;
                }
            }
            // Drop every other head edge.
            for (int i = 0; i < head_count; ++i) {
                int state_index = head_indices[i];
                if (head_index == state_index)
                    continue;
                int source_index = node_source_indices[state_index];
                auto arc = MAKE_ARC(source_index, node_index);
                // Set it directly to EMPTY state number 0
                status_ptr->edge_states[arc] = TokenManager::makeState(0, false);
                LOG_DEBUG(<< "Remove edge: "
                          << graph_ptr_->nodes[source_index] << " -> " << node);
            }
        } else if (head_count == 0 &&
                   node.out_edges.size() == 0 &&
                   status_ptr->top_index != -1) { // no head: apply the empty rule
            generator_ptr_->setRuleHead(rule_head_);
            generator_ptr_->generateEmptyRule();
            nextStepDynamic(node, node_index, beam_ptr3);
            continue;
        }

        if (head_count > 0) {
            // Follow the data stream from the head's source node; reaching
            // this node again means there is a cycle.
            // NOTE(review): if that source were node_index itself, last_index
            // would stay -1 and be used as an index below.
            int target_index = node_source_indices[head_index];
            int last_index = -1;
            while (target_index != -1 && target_index != node_index) {
                last_index = target_index;
                target_index = status_ptr->stream_target[target_index];
            }
            if (target_index != -1) { // cycle: cut the last link into this node
                current_beam_ptr = beam_ptr3;
                need_recompute_head = true;

                auto arc = MAKE_ARC(last_index, node_index);
                // Set it directly to EMPTY state number 0
                status_ptr->edge_states[arc] = TokenManager::makeState(0, false);
                status_ptr->stream_target[last_index] = -1;
                LOG_DEBUG(<< "Remove edge: "
                          << graph_ptr_->nodes[last_index] << " -> " << node);
            }
        }

        // Edges were removed: rebuild rule_head_ and take the first remaining
        // head position (or -1).
        if (need_recompute_head) {
            computeRuleHead(node, node_index);
            head_index = computeRuleHeadIndex(node_index, nullptr);
        }

        // Warn when generation added nothing to a beam that still had room.
        if (!current_beam_ptr->full() &&
            !decodeDynamicPart2(node, node_index,
                                head_index, current_beam_ptr))
            LOG_WARNING(<< "Why ??? " << graph_ptr_->filename);
    }
    // Hand the repaired statuses back through beam_ptr2 if it got nothing.
    if (beam_ptr2->size() == 0)
        std::swap(beam_ptr2, beam_ptr3);
}

void BeamSearchTransducer::initializeStackSpecial() {
    const static int arg1_index = TokenManager::indexOfEdgeLabel("ARG1");

    int node_count = graph_ptr_->nodes.size();
    for (int i = 0; i < node_count; ++i)
        node_in_edges_count_[i] = graph_ptr_->nodes[i].in_edges.size();

    std::vector< int > nodes_stack{0};
    std::unordered_set< int > visited_nodes{0};
    while (!nodes_stack.empty()) {
        auto index = nodes_stack.back();
        nodes_stack.pop_back();
        auto& node = graph_ptr_->nodes[index];
        auto& count = node_in_edges_count_[index];
        if (count == 0) {
            nodes_stack_.push_back(index);
            count = -1;
        } else {
            for (auto source_index : node.in_edges)
                if (!visited_nodes.count(source_index)) {
                    visited_nodes.insert(source_index);
                    nodes_stack.push_back(source_index);
                }
            if (node.out_edges.size()) {
                auto& edge = node.out_edges[0];
                if (edge.label_index == arg1_index &&
                    !visited_nodes.count(edge.target_index)) {
                    visited_nodes.insert(edge.target_index);
                    nodes_stack.push_back(edge.target_index);
                }
            }
        }
    }

    for (int i = 0; i < node_count; ++i)
        if (node_in_edges_count_[i] == 0)
            nodes_stack_.push_back(i);

    std::reverse(nodes_stack_.begin(), nodes_stack_.end());
}

// Resets node_in_edges_count_ to each node's in-degree and pushes every node
// without in-edges onto nodes_stack_ in increasing index order.
void BeamSearchTransducer::initializeStack() {
    const int node_count = graph_ptr_->nodes.size();
    for (int index = 0; index < node_count; ++index) {
        const int in_degree = graph_ptr_->nodes[index].in_edges.size();
        node_in_edges_count_[index] = in_degree;
        if (in_degree == 0)
            nodes_stack_.push_back(index);
    }
}

// Sets nodes_rank_[i] to the length of the longest path from a node without
// in-edges to node i, processing nodes in topological order (Kahn's
// algorithm on node_in_edges_count_ / nodes_stack_). Nodes on or behind a
// cycle are never popped and keep the rank reached so far.
void BeamSearchTransducer::computeNodesRank() {
    const int node_count = graph_ptr_->nodes.size();
    nodes_rank_.assign(node_count, 0);

    initializeStack();

    while (!nodes_stack_.empty()) {
        const int current = nodes_stack_.back();
        nodes_stack_.pop_back();

        for (const auto& edge : graph_ptr_->nodes[current].out_edges) {
            const int target = edge.target_index;
            if (nodes_rank_[target] < nodes_rank_[current] + 1)
                nodes_rank_[target] = nodes_rank_[current] + 1;
            if (--node_in_edges_count_[target] == 0)
                nodes_stack_.push_back(target);
        }
    }

    // The loop only exits with an empty stack; the caller relies on that.
    assert(nodes_stack_.empty());
}

// Beam-search decoding of *graph_ptr_. Nodes are taken from the back of
// nodes_stack_ (topological order from initializeStack, or the order from
// initializeStackSpecial when g_special_nodes_order is set). Every status of
// the current beam is expanded with the rules matching the node's rule head.
// In TRANSDUCE mode, the extended rule set is tried when the rule set has no
// tails, and decodeDynamic is tried when nothing was produced.
// When no status survives a node:
//   - TRAIN: decoding stops early (counted in n_early_update_graphs_) and the
//     partial beam is kept;
//   - otherwise, without early_update the beam is cleared (decoding fails);
//     with early_update the partial beam is kept and the failing node is
//     pushed back onto nodes_stack_ for printResultToStream.
// Afterwards beam_ptr_ is the final beam and current_status_ptr_ is its best
// status (nullptr when empty); update() is run for the nodes recorded in
// nodes_update_ (only filled in TRAIN mode).
void BeamSearchTransducer::decode(bool early_update) {
    bool is_early_updated = false;
    auto beam_ptr1 = &beam1_;
    auto beam_ptr2 = &beam2_;
    auto beam_ptr3 = &beam3_;
    beam1_.clear();
    beam2_.clear();
    beam3_.clear();
    nodes_stack_.clear();
    nodes_update_.clear();

    // nodes_rank_ is used by decodeDynamic to choose among several heads.
    computeNodesRank();
    int node_count = graph_ptr_->nodes.size();
    if (g_special_nodes_order)
        initializeStackSpecial();
    else
        initializeStack();

    // The initial status: a single status reset for node_count nodes.
    auto& initial_status = beam_ptr1->shrinkToOne();
    initial_status.reset(node_count);

    while (!nodes_stack_.empty()) {
        int node_index = nodes_stack_.back();
        nodes_stack_.pop_back();

        auto& node = graph_ptr_->nodes[node_index];
        // LOG_DEBUG(<< "Dequeue: " << node << '@' << node_index);
        // Expand every status of the current beam into beam_ptr2.
        for (auto status_ptr : *beam_ptr1) {
            current_status_ptr_ = status_ptr;
            computeRuleHead(node, node_index);

            auto& rule_tails = rule_set_ptr_->find(rule_head_);
            nextStep(node, node_index, beam_ptr2, rule_tails, false);
            if (n_state_ == TransducerState::TRANSDUCE &&
                rule_tails.size() == 0) {
                nextStep(node, node_index,
                         beam_ptr2, extended_rule_set_ptr_->find(rule_head_), true);
            }
        }

        // Last resort in TRANSDUCE mode: generate rules dynamically.
        if (beam_ptr2->size() == 0 && n_state_ == TransducerState::TRANSDUCE)
            decodeDynamic(node, node_index,
                          beam_ptr1, beam_ptr2, beam_ptr3);

        assert(beam_ptr1->size() != 0);
        // LOG_DEBUG(<< "Beam size: " << beam_ptr2->size());
        // No status survived this node.
        if (beam_ptr2->size() == 0) {
            if (n_state_ == TransducerState::TRAIN) {
                is_early_updated = true;
            } else {
                if (!early_update)
                    beam_ptr1->clear();
                else
                    nodes_stack_.push_back(node_index); // kept for printing the failing node
            }
            break;
        }

        if (n_state_ == TransducerState::TRAIN) {
            nodes_update_.push_back(node_index);
        }

        // The expanded beam becomes the current one; the old one is reused.
        beam_ptr1->clear();
        std::swap(beam_ptr1, beam_ptr2);

        // Release successors whose in-edges have all been processed.
        for (auto& edge : node.out_edges)
            if (--node_in_edges_count_[edge.target_index] == 0) {
                nodes_stack_.push_back(edge.target_index);
                // LOG_DEBUG(<< "Enqueue: "
                //           << graph_ptr_->nodes[edge.target_index]);
            }
    }

    if (is_early_updated)
        ++n_early_update_graphs_;

    beam_ptr_ = beam_ptr1;
    if (beam_ptr_->size() == 0) {
        current_status_ptr_ = nullptr;
        LOG_DEBUG(<< "Deocde failed " << graph_ptr_->filename);
    } else {
        ++n_transduced_graphs_;
        current_status_ptr_ = &beam_ptr_->best();
        if (!nodes_update_.empty())
            update();
        // LOG_DEBUG(<< "Best: " << current_status_ptr_->score);
    }
}

// Logs the running time and the accuracy statistics accumulated so far.
// running_time_ holds clock() ticks (see train()), so it is converted with
// CLOCKS_PER_SEC instead of assuming one tick per microsecond. Ratios whose
// denominator is zero (e.g. before any round or covered node) are reported
// as 0 instead of dividing by zero.
void BeamSearchTransducer::printInformation() {
    const auto ratio = [](double numerator, double denominator) -> double {
        return denominator != 0 ? numerator / denominator : 0.0;
    };
    const double seconds =
        static_cast< double >(running_time_) / CLOCKS_PER_SEC;
    const double average_ms = ratio(1000.0 * seconds, n_training_round_);

    LOG_INFO(<< "Total time: " << static_cast< int >(seconds)
             << "s Average time: "
             << static_cast< int >(average_ms)
             << "ms");
    LOG_INFO(<< "EarlyUpdate graphs:"
             << static_cast< float >(ratio(n_early_update_graphs_, n_training_round_)) << ' '
             << n_early_update_graphs_ << "/" << n_training_round_);
    LOG_INFO(<< "Correct graphs: "
             << static_cast< float >(ratio(n_correct_graphs_, n_training_round_)) << ' '
             << n_correct_graphs_ << "/" << n_training_round_);
    LOG_INFO(<< "Correct nodes: "
             << static_cast< float >(ratio(n_correct_nodes_, n_covered_nodes_)) << ' '
             << n_correct_nodes_ << "/" << n_covered_nodes_);
    LOG_INFO(<< "Covered nodes: "
             << static_cast< float >(ratio(n_covered_nodes_, n_total_nodes_)) << ' '
             << n_covered_nodes_ << "/" << n_total_nodes_);
    std::cout.flush();
}

// Prints the current (best) status to `os`:
//   - the node left on nodes_stack_ (decode() pushes back the node where it
//     got stuck when early_update is set): its in-states, itself and its
//     out-edge labels;
//   - score and top index;
//   - the stream tree (only when with_graph is false);
//   - for every node, the predicted rule and, if the node has one, the gold
//     rule.
// With with_graph set, the data stream is drawn through drawStream at the end.
void BeamSearchTransducer::printResultToStream(std::ostream& os,
                                               bool with_graph) {
    if (current_status_ptr_ == nullptr)
        return;

    const auto& nodes = graph_ptr_->nodes;

    if (!nodes_stack_.empty()) {
        const int stuck_index = nodes_stack_.back();
        const bool in_range =
            stuck_index >= 0 &&
            static_cast< std::size_t >(stuck_index) < nodes.size();
        if (in_range) {
            const auto& stuck_node = nodes[stuck_index];
            computeRuleHead(stuck_node, stuck_index);
            os << "Node: ";
            printStates(os, rule_head_.in_states);
            os << ' ' << stuck_node << " {";
            for (const auto& edge : stuck_node.out_edges)
                os << TokenManager::edgeLabelAt(edge.label_index) << ' ';
            os << "}\n";
        }
    }

    os << "Score: " << current_status_ptr_->score << '\n'
       << "Top: " << current_status_ptr_->top_index << '\n';
    if (!with_graph)
        current_status_ptr_->printTree(os, *graph_ptr_);

    os << "Rules: " << std::endl;
    std::size_t node_pos = 0;
    for (auto predicted : current_status_ptr_->rule_choices) {
        const auto& node = nodes[node_pos++];
        os << node << "\n => predict: ";
        if (predicted == -1)
            os << "none";
        else
            os << ruleAt(predicted);
        os << "\n";

        if (node.rule_index == -1)
            continue;
        if (predicted == node.rule_index)
            os << " => gold: the same\n";
        else
            os << " => gold: " << ruleAt(node.rule_index) << '\n';
    }
    os.flush();

    if (with_graph)
        drawStream(graph_ptr_->filename,
                   current_status_ptr_->stream_target,
                   current_status_ptr_->edge_states,
                   *graph_ptr_);
}

void BeamSearchTransducer::train(const EdsGraph& graph,
                                 int round) {
    graph_ptr_ = &graph;
    n_training_round_ = round;
    n_total_nodes_ += graph.nodes.size();
    auto start_time = clock();

    decode(true);

    if (current_status_ptr_ &&
        isGraphCorrect(current_status_ptr_->rule_choices))
        ++n_correct_graphs_;

    running_time_ += clock() - start_time;
    if (round % OUTPUT_STEP == 0)
        printInformation();
}

// Decodes `graph` and, when every node received a rule, runs the solver
// (result_ holds its output).
// If decoding chose no top node, the last node of the longest data-stream
// chain (following stream_target until -1) becomes the top.
// Returns true iff a complete rule assignment was found and solved.
bool BeamSearchTransducer::transduce(const EdsGraph& graph,
                                     bool with_detail) {
    result_.clear();
    graph_ptr_ = &graph;
    generator_ptr_->clearTemporary();

    decode(with_detail);
    if (current_status_ptr_ == nullptr)
        return false;

    auto& status = *current_status_ptr_;
    if (status.top_index == -1) {
        const int node_count = graph.nodes.size();
        // A chain starting at an already visited node is a suffix of a chain
        // measured earlier, so it cannot be longer.
        std::vector< char > seen(node_count, 0);
        int best_end = 0;
        int best_length = 0;
        for (int start = 0; start < node_count; ++start) {
            if (seen[start])
                continue;
            int length = 0;
            int end = -1;
            for (int cur = start; cur != -1; cur = status.stream_target[cur]) {
                seen[cur] = 1;
                end = cur;
                ++length;
            }
            if (length > best_length) {
                best_length = length;
                best_end = end;
            }
        }
        status.top_index = best_end;
    }

    const auto& choices = status.rule_choices;
    if (std::find(choices.cbegin(), choices.cend(), -1) != choices.cend())
        return false;

    solve();
    return true;
}

// Looks up (amount == 0 when scoring from nextStep) or updates (amount is
// +1/-1 from update()) the single-node features of `node`, seen from
// direction `dir`, for rule `rule_index`, accumulating into return_value_:
// lemma x dir, label x dir, and label x lemma x dir.
void BeamSearchTransducer::getOrUpdateNodeScore(const EdsGraph::Node& node,
                                                int rule_index,
                                                int dir, int amount) {
    WeightSecondOrder* const table = dynamic_cast< WeightSecondOrder* >(weight_);

    // lemma x direction x rule
    key3.refer(node.lemma_index, dir, rule_index);
    table->lemma_dir_index.getOrUpdateScore(
        return_value_, key3, n_score_type_, amount, n_training_round_);

    // label x direction x rule
    key3.refer(node.label_index, dir, rule_index);
    table->label_dir_index.getOrUpdateScore(
        return_value_, key3, n_score_type_, amount, n_training_round_);

    // label x lemma x direction x rule
    key4.refer(node.label_index, node.lemma_index, dir, rule_index);
    table->label_lemma_dir_index.getOrUpdateScore(
        return_value_, key4, n_score_type_, amount, n_training_round_);
}

// Looks up or updates the node-pair features of (node1, node2), seen from
// direction `dir`, for rule `rule_index`, accumulating into return_value_:
// label pair, lemma pair, and label pair x lemma pair.
void BeamSearchTransducer::getOrUpdateNode2Score(const EdsGraph::Node& node1,
                                                 const EdsGraph::Node& node2,
                                                 int rule_index,
                                                 int dir, int amount) {
    WeightSecondOrder* const table = dynamic_cast< WeightSecondOrder* >(weight_);

    // label pair
    key4.refer(node1.label_index, node2.label_index, dir, rule_index);
    table->node2_dir_index.getOrUpdateScore(
        return_value_, key4, n_score_type_, amount, n_training_round_);

    // lemma pair
    key4.refer(node1.lemma_index, node2.lemma_index, dir, rule_index);
    table->lemma2_dir_index.getOrUpdateScore(
        return_value_, key4, n_score_type_, amount, n_training_round_);

    // label pair x lemma pair
    key6.refer(node1.label_index, node2.label_index,
               node1.lemma_index, node2.lemma_index, dir, rule_index);
    table->node2_lemma2_dir_index.getOrUpdateScore(
        return_value_, key6, n_score_type_, amount, n_training_round_);
}

// Looks up (amount == 0, from nextStep) or updates (amount +1/-1, from
// update()) every feature of applying rule `rule_index` at node `node_index`,
// accumulating into return_value_. The features are:
//   - whether the node is node 0;
//   - the node itself;
//   - each out-neighbour, its out-neighbours and its other in-neighbours,
//     and the out-edge label combined with the out-neighbour;
//   - each in-neighbour, its other out-neighbours and its in-neighbours, and
//     the in-edge label combined with the in-neighbour.
void BeamSearchTransducer::getOrUpdateScore(int node_index,
                                            int rule_index, int amount) {
    auto& node = graph_ptr_->nodes[node_index];
    WeightSecondOrder* weight = dynamic_cast< WeightSecondOrder* >(weight_);

    key2.refer(node_index == 0, rule_index);
    weight->is_top.getOrUpdateScore(return_value_, key2,
                                    n_score_type_, amount,
                                    n_training_round_);

    getOrUpdateNodeScore(node, rule_index, eHere, amount);

    for (auto& edge : node.out_edges) {
        auto& out_node = graph_ptr_->nodes[edge.target_index];
        getOrUpdateNodeScore(out_node, rule_index, eOut, amount);

        for (auto& edge2 : out_node.out_edges) { // node -> out_node -> out_node2
            auto& out_node2 = graph_ptr_->nodes[edge2.target_index];
            getOrUpdateNodeScore(out_node2,
                                 rule_index, eOutOut, amount);
            getOrUpdateNode2Score(out_node, out_node2,
                                  rule_index, eOutOut, amount);
        }

        for (auto& source_index : out_node.in_edges) {
            if (node_index == source_index)
                continue;
            // node -> out_node <- in_node
            auto& in_node = graph_ptr_->nodes[source_index];
            getOrUpdateNodeScore(in_node,
                                 rule_index, eOutIn, amount);
            getOrUpdateNode2Score(out_node, in_node,
                                  rule_index, eOutIn, amount);
        }

        key3.refer(edge.label_index, out_node.lemma_index, rule_index);
        weight->edge_lemma_index.getOrUpdateScore(return_value_, key3,
                                                  n_score_type_, amount,
                                                  n_training_round_);

        key3.refer(edge.label_index, out_node.label_index, rule_index);
        weight->edge_label_index.getOrUpdateScore(return_value_, key3,
                                                  n_score_type_, amount,
                                                  n_training_round_);

        key4.refer(edge.label_index, out_node.label_index,
                   out_node.lemma_index, rule_index);
        weight->edge_label_lemma_index.getOrUpdateScore(return_value_, key4,
                                                        n_score_type_, amount,
                                                        n_training_round_);
    }

    for (auto& source_index : node.in_edges) {
        auto& in_node = graph_ptr_->nodes[source_index];
        getOrUpdateNodeScore(in_node, rule_index, eIn, amount);

        for (auto& edge2 : in_node.out_edges) {
            if (node_index == edge2.target_index)
                continue;
            // out_node <- in_node -> node
            auto& out_node = graph_ptr_->nodes[edge2.target_index];
            getOrUpdateNodeScore(out_node,
                                 rule_index, eInOut, amount);
            getOrUpdateNode2Score(out_node, in_node,
                                  rule_index, eInOut, amount);
        }

        // Named distinctly so it does not shadow the outer source_index,
        // which is used again below for the edge lookup.
        for (auto& source_index2 : in_node.in_edges) {
            // in_node2 -> in_node -> node
            auto& in_node2 = graph_ptr_->nodes[source_index2];
            getOrUpdateNodeScore(in_node2, rule_index, eInIn, amount);
            getOrUpdateNode2Score(in_node, in_node2,
                                  rule_index, eInIn, amount);
        }

        auto iter = graph_ptr_->edges.find(MAKE_ARC(source_index,
                                                    node_index));

        assert(iter != graph_ptr_->edges.end());
        // With assertions disabled, skip the edge-label features rather than
        // dereferencing the end iterator.
        if (iter == graph_ptr_->edges.end())
            continue;

        int edge_label_index = iter->second;

        key3.refer(in_node.label_index, edge_label_index, rule_index);
        weight->label_edge_index.getOrUpdateScore(return_value_, key3,
                                                  n_score_type_, amount,
                                                  n_training_round_);

        key3.refer(in_node.lemma_index, edge_label_index, rule_index);
        weight->lemma_edge_index.getOrUpdateScore(return_value_, key3,
                                                  n_score_type_, amount,
                                                  n_training_round_);

        key4.refer(in_node.label_index, in_node.lemma_index,
                   edge_label_index, rule_index);
        weight->label_lemma_edge_index.getOrUpdateScore(return_value_, key4,
                                                        n_score_type_, amount,
                                                        n_training_round_);
    }
}
}
