#ifndef DATA_DRIVEN_RULE_GENERATOR_H
#define DATA_DRIVEN_RULE_GENERATOR_H

#include "../baseline/baseline_transducer.h"

// Size of the second dimension of DynamicRuleGenerator::vertibi_bp_, i.e. the
// number of state slots reserved in each row of the decoder's back-pointer
// table (rows are sized MAX_NODE_OUT_EDGE_COUNT).
#define MAX_STATES_COUNT 500
// Packs a state id and a head flag into a single integer: bit 0 holds the flag
// (1 when `is_head` is truthy, 0 otherwise) and the remaining bits hold
// `state`. Each macro argument is expanded exactly once. `state` should be
// non-negative: left-shifting a negative value is undefined behaviour before
// C++20.
#define MAKE_STATE_INDEX(is_head, state) \
    (((state) << 1) | ((is_head) ? 1 : 0))

namespace data_driven {

// Make baseline::Status usable unqualified inside data_driven.
using baseline::Status;

// Global switch declared here and defined in another translation unit.
// NOTE(review): judging by the name, it presumably controls whether
// transition scores (see DynamicRuleGenerator::transitionScore) are used by
// the generator; confirm against its definition.
extern bool g_generator_use_transition;

/// Generates rules at run time ("dynamic rules") for nodes of an EdsGraph.
///
/// Candidate out-states and equations are scored with a learned weight model
/// (dynamic_weight_). Rules produced here are stored in temporary_rules_
/// (see saveCurrentRule() / ruleAt()) rather than in a RuleSet, and they are
/// tagged with negative rule indices (-2, -3, ...).
///
/// Pointers installed through the set*() methods are non-owning: the pointed-to
/// objects must outlive every call that uses them.
///
/// NOTE(review): the object embeds vertibi_bp_
/// (MAX_NODE_OUT_EDGE_COUNT x MAX_STATES_COUNT ints) inline, so it is large;
/// avoid placing instances on the stack or copying them.
class DynamicRuleGenerator {
private:
    TransducerState n_state_;
    int n_score_type_ = 0;
    // Number of training rounds; passed to computeAverageFeatureWeights() in
    // finishTraining().
    int n_round_ = 0;

    // Statistics counters (presumably reported by printInformation() --
    // TODO confirm in the implementation file).
    int n_total_nodes_ = 0;
    int n_total_states_ = 0;
    int n_equation_generated_ = 0;
    int n_state_correct_ = 0;
    int n_equation_correct_ = 0;

    // Scratch multi-int keys, presumably used to build feature lookups.
    TwoInt key2; // two int key
    ThreeInt key3;
    FourInt key4;
    FiveInt key5;
    SixInt key6;

    const RuleSet* rule_set_ptr_ = nullptr;
    const EdsGraph* graph_ptr_ = nullptr; // set by setGraph(); non-owning

    // Weights used to generate dynamic rules.
    // NOTE(review): raw pointer, presumably owned by this object and released
    // in ~DynamicRuleGenerator(); if so, the implicitly generated copy
    // operations would double-free it, so do not copy instances.
    WeightBase* dynamic_weight_ = nullptr;

    // Location of the rule head:
    // - head_index_ == -1: there is no head.
    // - head_from_in_edges_ == true: the head is among the in-edges, and
    //   head_index_ is an index into rule_head_ptr_->in_states.
    // - head_from_in_edges_ == false: the head is among the out-edges, and
    //   head_index_ is an index into node.out_edges.
    int head_index_ = -1;
    bool head_from_in_edges_ = false;
    tscore states_score_{};

    // Rules saved by saveCurrentRule(); the k-th entry has rule_index -2 - k.
    std::vector< RuleSet::Rule > temporary_rules_;

    const RuleSet::Head* rule_head_ptr_ = nullptr; // set by setRuleHead()
    // Tail of the rule currently being generated; cleared by
    // generateEmptyRule() and moved out by saveCurrentRule().
    RuleSet::Tail rule_tail_;
    EdgeStateMap* edge_states_ptr_ = nullptr; // set by setEdgeStates()
    int* source_indices_ = nullptr;           // set by setSourceIndices()
    // Per-out-edge scratch arrays (bounded by MAX_NODE_OUT_EDGE_COUNT).
    int target_indices_[MAX_NODE_OUT_EDGE_COUNT];
    int states_to_edges_[MAX_NODE_OUT_EDGE_COUNT];
    int edges_to_states_[MAX_NODE_OUT_EDGE_COUNT];
    int gold_states_[MAX_NODE_OUT_EDGE_COUNT];

    // Beams of (score, state index) pairs.
    using StateBeam = SimpleBeam< std::pair< tscore, int > >;
    StateBeam state_beam1_;
    StateBeam state_beam2_;

    // Back-pointer table used by decodeVertibi() (presumably Viterbi
    // decoding; the "Vertibi" spelling is kept for link compatibility).
    int vertibi_bp_[MAX_NODE_OUT_EDGE_COUNT][MAX_STATES_COUNT];

private:
    // Scoring helpers, defined in the implementation file.
    // NOTE(review): the `amount` parameter suggests these also apply weight
    // updates of that size during training -- TODO confirm.
    tscore stateNodeScore(const EdsGraph::Node& node,
                          int state_index, int dir, int amount);

    tscore stateEdgeScore(int source_index, int target_index,
                          int edge_label_index,
                          int state_index, int dir, int amount);

    tscore stateScore(const EdsGraph::Node& node, int node_index,
                      int edge_index,
                      int state_index,
                      int amount);
    tscore transitionScore(const EdsGraph::Node& node, int node_index,
                           int edge_index,
                           int last_state_index,
                           int state_index,
                           int amount);
    tscore variableScore(const EdsGraph::Node& node, int node_index,
                         int var_major,
                         int var_minor,
                         int amount);

    void updateRuleScore(const EdsGraph::Node& node, int node_index);
    void updateEquationScore(const EdsGraph::Node& node, int node_index);

    // bool decodeSimple(const EdsGraph::Node& node, int node_index);
    void decodeVertibi(const EdsGraph::Node& node, int node_index);

    void computeHeadIndex(const EdsGraph::Node& node,
                          bool convert_to_edge_index);

    bool computeSourceIndices(const EdsGraph::Node& node,
                              int node_index);
    bool computeTargetIndices(const EdsGraph::Node& node,
                              int node_index,
                              const RuleSet::Tail& rule_tail);

public:
    DynamicRuleGenerator(const std::string& feature_input,
                         const std::string& feature_output,
                         TransducerState n_state,
                         int generator_k_best);
    ~DynamicRuleGenerator();

    /// Pointer to the tail currently being generated (owned by this object).
    const RuleSet::Tail* ruleTailPtr() const {
        return &rule_tail_;
    }
    /// Returns the index-th rule saved by saveCurrentRule().
    /// `index` is not bounds-checked; it must be in [0, number of saved rules).
    const RuleSet::Rule& ruleAt(int index) const {
        return temporary_rules_[index];
    }

    /// Stores a non-owning pointer to the caller's source-index array.
    void setSourceIndices(int* source_indices) {
        source_indices_ = source_indices;
    }

    /// Stores a non-owning pointer to `graph`; it must outlive its use here.
    void setGraph(const EdsGraph& graph) {
        graph_ptr_ = &graph;
    }
    /// Sets head_index_ / head_from_in_edges_ (see their member comment).
    void setRuleHeadIndex(int index, bool from_in_edges) {
        head_index_ = index;
        head_from_in_edges_ = from_in_edges;
    }
    /// Stores a non-owning pointer to the head used by saveCurrentRule().
    void setRuleHead(const RuleSet::Head& rule_head) {
        rule_head_ptr_ = &rule_head;
    }
    /// Stores a non-owning pointer to the caller's edge -> state map.
    void setEdgeStates(std::unordered_map< EdsGraph::Arc,
                                           RuleSet::State >& edge_states) {
        edge_states_ptr_ = &edge_states;
    }

    void generateRuleTail(const EdsGraph::Node& node, int node_index);
    void generateRuleTailAuto(const EdsGraph::Node& node, int node_index);

    void generateEquation(const EdsGraph::Node& node, int node_index);

    /// Resets the current tail to have no out-states and no equations.
    void generateEmptyRule() {
        rule_tail_.out_states.clear();
        rule_tail_.equations.clear();
    }

    /// Discards all rules saved by saveCurrentRule().
    void clearTemporary() {
        temporary_rules_.clear();
    }

    /// Appends {*rule_head_ptr_, current tail} to temporary_rules_.
    /// The k-th saved rule (0-based) gets rule_index -2 - k, so ruleAt(k)
    /// returns it. Requires setRuleHead() to have been called.
    /// rule_tail_ is left in a moved-from state; repopulate or clear it
    /// (e.g. via generateEmptyRule()) before reading it again.
    void saveCurrentRule() {
        // Convert the size to int first: "-2 - size()" would otherwise be
        // computed in unsigned size_t arithmetic and only wrap back to the
        // intended negative value via an implementation-defined conversion.
        rule_tail_.rule_index =
            -2 - static_cast< int >(temporary_rules_.size());
        temporary_rules_.push_back({*rule_head_ptr_, std::move(rule_tail_)});
    }

    void printInformation();

    void trainRuleGenerator(const EdsGraph& graph,
                            const RuleSet& rule_set,
                            int round);

    /// Averages the feature weights over n_round_ rounds, then calls
    /// saveScores() on the weight model.
    void finishTraining() {
        dynamic_weight_->computeAverageFeatureWeights(n_round_);
        dynamic_weight_->saveScores();
    }
};
}


#endif /* DATA_DRIVEN_RULE_GENERATOR_H */

// Local Variables:
// mode: c++
// End:
