#include "common/transducer/weight_second_order.h"
#include "data_driven_transducer.h"
#include "data_driven_generator_weight.h"

#include "io/token_manager.h"
#include "io/rule.h"
#include "io/utility.h"
#include "io/edsgraph.h"

// Helpers for a state index treated as a packed integer: the lowest bit is
// read as the "head" flag and the remaining bits as the underlying state.
// NOTE(review): presumably the inverse of MAKE_STATE_INDEX (defined
// elsewhere) — confirm the bit layout there.
#define IS_HEAD_STATE(state_index) ((state_index) & 1)
// Uses the macro argument itself; the expansion must not depend on a
// variable named `state` happening to exist at the call site.
#define ORIGINAL_STATE(state_index) ((state_index) >> 1)

namespace data_driven {

// Stand-in node that variableScore() uses as the "other" node when the
// predicted variable is the lemma (var_major < 0) instead of an edge endpoint.
// NOTE(review): the four -1 fields and "#" look like "no index" sentinels —
// confirm against the EdsGraph::Node field order.
EdsGraph::Node g_lemma_node{-1, -1, -1, -1, "#"};

// Scores a single graph node against a candidate state.
//
// Three feature keys are built in turn and handed to the matching weight
// table's getOrUpdateScore(): (lemma, dir, state), (label, dir, state) and
// (label, lemma, dir, state). `dir`, `state_index` and `amount` are passed
// through unchanged; `amount` is presumably the update delta applied to the
// weights — TODO confirm against getOrUpdateScore().
//
// Returns the accumulator (initially 0) that every getOrUpdateScore() call
// receives as its first argument.
tscore DynamicRuleGenerator::stateNodeScore(const EdsGraph::Node& node,
                                            int state_index,
                                            int dir,
                                            int amount) {
    // The cast result is used without a null check, exactly as before.
    auto* weights = dynamic_cast< DynamicWeight* >(dynamic_weight_);
    tscore total = 0;

    // Lemma only.
    key3.refer(node.lemma_index, dir, state_index);
    weights->lemma_dir_index.getOrUpdateScore(
        total, key3, n_score_type_, amount, n_round_);

    // Label only.
    key3.refer(node.label_index, dir, state_index);
    weights->label_dir_index.getOrUpdateScore(
        total, key3, n_score_type_, amount, n_round_);

    // Label and lemma together.
    key4.refer(node.label_index, node.lemma_index, dir, state_index);
    weights->label_lemma_dir_index.getOrUpdateScore(
        total, key4, n_score_type_, amount, n_round_);

    return total;
}

// Scores the edge source_index -> target_index against a candidate state.
//
// Always scored (each combined with `dir` and `state_index`):
//   - label pair, label pair + edge label,
//   - lemma pair, lemma pair + edge label.
// Extra features depend on `dir`:
//   - eHere: direction-free label pair, lemma pair, and label+lemma pair;
//   - eIn:   features pairing `state_index` with the state index derived
//            from the entry stored for this arc in edge_states_ptr_.
// Any other `dir` adds nothing beyond the common features.
//
// `amount` is forwarded untouched to every getOrUpdateScore() call.
// Returns the accumulator (initially 0) those calls receive.
tscore DynamicRuleGenerator::stateEdgeScore(int source_index,
                                            int target_index,
                                            int edge_label_index,
                                            int state_index,
                                            int dir,
                                            int amount) {
    tscore total = 0;
    auto* weights = dynamic_cast< DynamicWeight* >(dynamic_weight_);
    auto& src_node = graph_ptr_->nodes[source_index];
    auto& tgt_node = graph_ptr_->nodes[target_index];

    // --- Features common to every direction -----------------------------
    key4.refer(src_node.label_index, tgt_node.label_index, dir, state_index);
    weights->node2_dir_index.getOrUpdateScore(
        total, key4, n_score_type_, amount, n_round_);

    key5.refer(src_node.label_index, tgt_node.label_index,
               edge_label_index, dir, state_index);
    weights->node2_edge_dir_index.getOrUpdateScore(
        total, key5, n_score_type_, amount, n_round_);

    key4.refer(src_node.lemma_index, tgt_node.lemma_index, dir, state_index);
    weights->lemma2_dir_index.getOrUpdateScore(
        total, key4, n_score_type_, amount, n_round_);

    key5.refer(src_node.lemma_index, tgt_node.lemma_index,
               edge_label_index, dir, state_index);
    weights->lemma2_edge_dir_index.getOrUpdateScore(
        total, key5, n_score_type_, amount, n_round_);

    if (dir == eHere) {
        // --- The edge being labelled: direction-free pair features ------
        key3.refer(src_node.label_index, tgt_node.label_index, state_index);
        weights->node2_index.getOrUpdateScore(
            total, key3, n_score_type_, amount, n_round_);

        key3.refer(src_node.lemma_index, tgt_node.lemma_index, state_index);
        weights->lemma2_index.getOrUpdateScore(
            total, key3, n_score_type_, amount, n_round_);

        key5.refer(src_node.label_index, tgt_node.label_index,
                   src_node.lemma_index, tgt_node.lemma_index, state_index);
        weights->node2_lemma2_index.getOrUpdateScore(
            total, key5, n_score_type_, amount, n_round_);
    } else if (dir == eIn) {
        // --- Incoming edge: combine with the state recorded for it ------
        auto arc = MAKE_ARC(source_index, target_index);
        auto found = edge_states_ptr_->find(arc);
        // The arc is required to be present; only checked in assert builds.
        assert(found != edge_states_ptr_->end());
        auto state = found->second;
        int in_state_index = MAKE_STATE_INDEX(IS_EDGE_REVERSED(state),
                                              STATE_INDEX(GET_STATE(state)));

        // Target label alone, then with the source label written through
        // refer1() into the same key.
        key4.refer(-1, tgt_node.label_index, in_state_index, state_index);
        weights->node2_index2.getOrUpdateScore(
            total, key4, n_score_type_, amount, n_round_);
        key4.refer1(src_node.label_index);
        weights->node2_index2.getOrUpdateScore(
            total, key4, n_score_type_, amount, n_round_);

        // Same pattern with lemmas.
        // NOTE(review): these lemma keys are scored in node2_index2, the
        // same table as the label keys above — confirm this sharing is
        // intended rather than a copy-paste of the table name.
        key4.refer(-1, tgt_node.lemma_index, in_state_index, state_index);
        weights->node2_index2.getOrUpdateScore(
            total, key4, n_score_type_, amount, n_round_);
        key4.refer1(src_node.lemma_index);
        weights->node2_index2.getOrUpdateScore(
            total, key4, n_score_type_, amount, n_round_);

        // State pair on its own.
        key2.refer(in_state_index, state_index);
        weights->index2.getOrUpdateScore(
            total, key2, n_score_type_, amount, n_round_);
    }
    return total;
}

// Scores assigning `state_index` to the out-edge `node.out_edges[edge_index]`.
//
// With DECODER_CHECK defined, no weights are consulted: the function returns
// 10 when `state_index` equals the index rebuilt from the entry stored for
// this arc in edge_states_ptr_ (plus the head flag), and 0 otherwise.
//
// Otherwise the score is the sum of stateNodeScore()/stateEdgeScore() over
// the neighbourhood of the edge: the node itself, the edge's target, the
// node's other out-edges and in-edges, and the target's out- and in-edges.
// `amount` is forwarded to every helper call.
tscore DynamicRuleGenerator::stateScore(const EdsGraph::Node& node,
                                        int node_index,
                                        int edge_index,
                                        int state_index,
                                        int amount) {
#ifdef DECODER_CHECK
    auto arc = MAKE_ARC(node_index,
                        node.out_edges[edge_index].target_index);
    auto found = edge_states_ptr_->find(arc);
    assert(found != edge_states_ptr_->end());

    // This edge counts as the head only when the head is not taken from
    // the in-edges and head_index_ points at it.
    bool is_head = !head_from_in_edges_ && head_index_ == edge_index;
    auto state = STATE_INDEX(GET_STATE(found->second));
    const bool matches = (MAKE_STATE_INDEX(is_head, state) == state_index);
    return matches ? 10 : 0;
#else
    tscore total = 0;
    auto& picked_edge = node.out_edges[edge_index];
    auto& picked_target = graph_ptr_->nodes[picked_edge.target_index];

    // The two endpoints of the edge being scored.
    total += stateNodeScore(node, state_index, eSrc, amount);
    total += stateNodeScore(picked_target, state_index, eTar, amount);

    // All out-edges of the node. Every edge whose target equals the scored
    // edge's target is treated as "here"; the rest are sibling out-edges.
    for (auto& out_edge : node.out_edges) {
        if (out_edge.target_index == picked_edge.target_index) {
            total += stateEdgeScore(node_index, out_edge.target_index,
                                    out_edge.label_index,
                                    state_index, eHere, amount);
            continue;
        }
        total += stateNodeScore(graph_ptr_->nodes[out_edge.target_index],
                                state_index, eSrcTar, amount);
        total += stateEdgeScore(node_index, out_edge.target_index,
                                out_edge.label_index,
                                state_index, eOut, amount);
    }

    // In-edges of the node; the edge label comes from graph_ptr_->edges.
    for (auto parent_index : node.in_edges) {
        total += stateNodeScore(graph_ptr_->nodes[parent_index],
                                state_index, eSrcSrc, amount);
        auto arc = MAKE_ARC(parent_index, node_index);
        auto found = graph_ptr_->edges.find(arc);
        // Every in-edge is required to be registered; assert-only check.
        assert(found != graph_ptr_->edges.end());
        total += stateEdgeScore(parent_index, node_index,
                                found->second,
                                state_index, eIn, amount);
    }

    // Nodes reachable from the target.
    for (auto& out_edge : picked_target.out_edges)
        total += stateNodeScore(graph_ptr_->nodes[out_edge.target_index],
                                state_index, eTarTar, amount);

    // Other parents of the target (the node itself is skipped).
    for (auto parent_index : picked_target.in_edges) {
        if (parent_index == node_index)
            continue;
        total += stateNodeScore(graph_ptr_->nodes[parent_index],
                                state_index, eTarSrc, amount);
    }
    return total;
#endif
}

// Scores the transition between the states of two consecutive out-edges of
// `node`: the edge at `edge_index` (paired with `state_index1`) and the edge
// just before it, at `edge_index - 1` (paired with `state_index2`).
//
// Returns 0 without touching any weights when g_generator_use_transition is
// false, or when `edge_index` is 0 (there is no previous edge; this case is
// also logged as an error). `node_index` is not used by this function.
// `amount` is forwarded unchanged to every getOrUpdateScore() call.
//
// Features, each keyed together with (state_index1, state_index2):
//   - the state pair alone, and with both edge labels;
//   - node labels: centre only, both targets only, all three;
//   - lemmas:      centre only, both targets only, all three.
// A -1 in a key marks the slot that is left out of that feature.
tscore DynamicRuleGenerator::transitionScore(const EdsGraph::Node& node,
                                             int node_index,
                                             int edge_index,
                                             int state_index1,
                                             int state_index2,
                                             int amount) {
    if (!g_generator_use_transition)
        return 0;
    if (edge_index == 0) {
        // Fixed spelling of the diagnostic ("tansition" -> "transition").
        LOG_ERROR(<< "No transition at first edge");
        return 0;
    }
    tscore ret = 0;

    DynamicWeight* weight = dynamic_cast< DynamicWeight* >(dynamic_weight_);
    auto& edge1 = node.out_edges[edge_index];      // current edge
    auto& edge2 = node.out_edges[edge_index - 1];  // preceding edge
    auto& target1 = graph_ptr_->nodes[edge1.target_index];
    auto& target2 = graph_ptr_->nodes[edge2.target_index];

    // State pair only.
    key4.refer(-1, -1, state_index1, state_index2);
    weight->transition_edge2_index2.getOrUpdateScore(ret, key4,
                                                     n_score_type_, amount,
                                                     n_round_);
    // State pair with both edge labels.
    key4.refer(edge1.label_index, edge2.label_index,
               state_index1, state_index2);
    weight->transition_edge2_index2.getOrUpdateScore(ret, key4,
                                                     n_score_type_, amount,
                                                     n_round_);

    // Node labels: centre node only.
    key5.refer(-1, node.label_index, -1,
               state_index1, state_index2);
    weight->transition_node3_index2.getOrUpdateScore(ret, key5,
                                                     n_score_type_, amount,
                                                     n_round_);
    // Node labels: the two targets only.
    key5.refer(target1.label_index, -1, target2.label_index,
               state_index1, state_index2);
    weight->transition_node3_index2.getOrUpdateScore(ret, key5,
                                                     n_score_type_, amount,
                                                     n_round_);

    // Node labels: targets and centre together.
    key5.refer(target1.label_index, node.label_index, target2.label_index,
               state_index1, state_index2);
    weight->transition_node3_index2.getOrUpdateScore(ret, key5,
                                                     n_score_type_, amount,
                                                     n_round_);

    // Lemmas: centre node only.
    key5.refer(-1, node.lemma_index, -1,
               state_index1, state_index2);
    weight->transition_lemma3_index2.getOrUpdateScore(ret, key5,
                                                      n_score_type_, amount,
                                                      n_round_);
    // Lemmas: the two targets only.
    key5.refer(target1.lemma_index, -1, target2.lemma_index,
               state_index1, state_index2);
    weight->transition_lemma3_index2.getOrUpdateScore(ret, key5,
                                                      n_score_type_, amount,
                                                      n_round_);

    // Lemmas: targets and centre together.
    key5.refer(target1.lemma_index, node.lemma_index, target2.lemma_index,
               state_index1, state_index2);
    weight->transition_lemma3_index2.getOrUpdateScore(ret, key5,
                                                      n_score_type_, amount,
                                                      n_round_);

    return ret;
}

// Scores predicting a variable for `node`.
//
// The predicted variable is MAKE_VAR(var_major, var_minor) when
// `var_major` is non-negative and LEMMA_VAR otherwise.
//
// With DECODER_CHECK defined, the return value is the position of that
// variable inside the single equation of the rule at `node.rule_index`
// (equal to the equation's length when the variable is absent).
//
// Otherwise `var_major` selects the "other" node and a direction:
//   - negative:                       g_lemma_node, eHere;
//   - below the in-state count:       source_indices_[var_major], eIn;
//   - at or above the in-state count: the matching target_indices_ entry,
//                                     eOut.
// The score then sums features over the variable, the direction, this
// node's label/lemma and the other node's label/lemma. `node_index` is not
// used; `amount` is forwarded to every getOrUpdateScore() call.
tscore DynamicRuleGenerator::variableScore(const EdsGraph::Node& node,
                                           int node_index,
                                           int var_major,
                                           int var_minor,
                                           int amount) {
    int predict_var = LEMMA_VAR;
    if (var_major >= 0)
        predict_var = MAKE_VAR(var_major, var_minor);
#ifdef DECODER_CHECK
    auto& rule = rule_set_ptr_->ruleAt(node.rule_index);

    assert(rule.tail.equations.size() == 1);
    assert(rule.head.in_states == rule_head_ptr_->in_states);

    auto& equation = rule.tail.equations[0];

    // Offset of the predicted variable within the equation.
    return std::find(equation.begin(), equation.end(), predict_var) -
           equation.begin();
#else
    auto* weights = dynamic_cast< DynamicWeight* >(dynamic_weight_);
    int in_state_count = rule_head_ptr_->in_states.size();
    const EdsGraph::Node* other_node_ptr;
    int dir;
    tscore total = 0;

    if (var_major < 0) { // Lemma
        other_node_ptr = &g_lemma_node;
        dir = eHere;
    } else if (var_major < in_state_count) { // In edges
        auto source_index = source_indices_[var_major];
        other_node_ptr = &graph_ptr_->nodes[source_index];
        dir = eIn;
    } else { // Out edges
        auto target_index = target_indices_[var_major - in_state_count];
        other_node_ptr = &graph_ptr_->nodes[target_index];
        dir = eOut;
    }

    // Variable and direction alone.
    key2.refer(predict_var, dir);
    weights->var_dir_index.getOrUpdateScore(
        total, key2, n_score_type_, amount, n_round_);

    // This node: label only, lemma only, then both (-1 marks the unused
    // slot).
    key4.refer(node.label_index, -1, dir, predict_var);
    weights->var_node_lemma_dir_index.getOrUpdateScore(
        total, key4, n_score_type_, amount, n_round_);

    key4.refer(-1, node.lemma_index, dir, predict_var);
    weights->var_node_lemma_dir_index.getOrUpdateScore(
        total, key4, n_score_type_, amount, n_round_);

    key4.refer(node.label_index, node.lemma_index, dir, predict_var);
    weights->var_node_lemma_dir_index.getOrUpdateScore(
        total, key4, n_score_type_, amount, n_round_);

    // This node paired with the other node: labels, then lemmas.
    key4.refer(node.label_index, other_node_ptr->label_index,
               dir, predict_var);
    weights->var_node2_dir_index.getOrUpdateScore(
        total, key4, n_score_type_, amount, n_round_);

    key4.refer(node.lemma_index, other_node_ptr->lemma_index,
               dir, predict_var);
    weights->var_lemma2_dir_index.getOrUpdateScore(
        total, key4, n_score_type_, amount, n_round_);

    return total;
#endif
}
}
