#ifndef BASELINE_TRANSDUCER_H
#define BASELINE_TRANSDUCER_H

#include <iosfwd>
#include <string>
#include <vector>

#include "common/transducer/dag_transducer_base.h"
#include "common/transducer/beam.h"
#include "common/transducer/solver.h"
#include "baseline_macros.h"
#include "baseline_status.h"


namespace baseline {

// Transducer over EDS graphs derived from DAGTransducerBase that keeps its
// decoding hypotheses (`Status` objects) in `SimpleBeam` containers.
// The result of the last decode is exposed through `current_status_ptr_`;
// callers should check hasResultComputed() before using the current*()
// accessors, which dereference that pointer unconditionally.
class BeamSearchTransducer : public DAGTransducerBase {
protected:
    EquationSolverBase solver; // equations solver
    TwoInt key2;               // two int key
    ThreeInt key3;
    FourInt key4;

    // Pointer to the current status while decoding. Defaults to nullptr so
    // that hasResultComputed() is well-defined before any decoding has run,
    // even if a constructor does not set it explicitly.
    Status* current_status_ptr_ = nullptr;
    // current rule set (not owned; must outlive this transducer)
    const RuleSet& rule_set_;
    // next status when decoding
    Status next_status_;
    // gold status of current graph
    Status gold_status_;

    double running_time_ = 0.0;

    int total_error_count_ = 0;
    tscore return_value_;
    // whether current rule is a 'top' rule ($G)
    bool is_current_rule_top_ = false;
    // current stream target node index
    int stream_target_index_;
    // Per-node count of incoming states; splits the variable index space used
    // by targetOrSourceIndex() into a "source" part and a "target" part.
    int in_state_counts_[MAX_GRAPH_NODE_COUNT];
    int source_indices_[MAX_GRAPH_NODE_COUNT][MAX_NODE_IN_EDGE_COUNT];
    int target_indices_[MAX_GRAPH_NODE_COUNT][MAX_NODE_OUT_EDGE_COUNT];
    int node_in_edges_count_[MAX_GRAPH_NODE_COUNT];
    int local_edge_states_[MAX_NODE_OUT_EDGE_COUNT];
    std::vector< int > nodes_stack_;
    std::vector< int > nodes_update_;

    RuleSet::Head rule_head_;
    // Points into rule storage; nullptr until a rule tail has been computed.
    const RuleSet::Tail* rule_tail_ptr_ = nullptr;

    using Beam = SimpleBeam< Status >;
    Beam beam1_;
    Beam beam2_;
    // Presumably selects which of beam1_/beam2_ is active during decoding
    // (NOTE(review): confirm against decode()); nullptr until assigned.
    Beam* beam_ptr_ = nullptr;

    void update();

    void computeSourceIndices(const EdsGraph::Node& node, int node_index);
    bool computeTargetIndices(const EdsGraph::Node& node, int node_index);

    bool computeRuleTail(const EdsGraph::Node& node, int node_index);
    void computeRuleHead(const EdsGraph::Node& node, int node_index);
    void computeNextStatus(const EdsGraph::Node& node, int node_index);
    void computeGoldStatus(const EdsGraph::Node& node, int node_index);

    void getOrUpdateNodeScore(const EdsGraph::Node& node,
                              int rule_index, int dir, int amount);

    virtual void getOrUpdateScore(int node_index, int rule_index, int amount);

    virtual void decode(bool early_update);

    void solve();

public:
    BeamSearchTransducer(const std::string& feature_input,
                         const std::string& feature_output,
                         const RuleSet& rule_set,
                         TransducerState n_state,
                         int beam_size);
    ~BeamSearchTransducer();

    // True once a decoding result is available (current status is set).
    bool hasResultComputed() override {
        return current_status_ptr_ != nullptr;
    }

    // Top node index of the current result. Requires hasResultComputed().
    int currentTopIndex() override {
        return current_status_ptr_->top_index;
    }

    // Rule chosen for `node_index` in the current result.
    // Requires hasResultComputed().
    int currentRuleChoice(int node_index) override {
        return current_status_ptr_->rule_choices[node_index];
    }

    // Stream target recorded for `node_index` in the current result.
    // Requires hasResultComputed().
    int currentStreamTarget(int node_index) override {
        return current_status_ptr_->stream_target[node_index];
    }

    // Maps a node-local variable index to a node index in the data flow.
    // Indices below the node's incoming-state count select a source node;
    // the remaining indices (offset by that count) select a target node.
    int targetOrSourceIndex(int node_index, int var_major) override {
        const int in_state_count = in_state_counts_[node_index];
        return var_major < in_state_count
                   ? source_indices_[node_index][var_major]
                   : target_indices_[node_index][var_major - in_state_count];
    }

    const RuleSet::Rule& ruleAt(int rule_index) const override {
        return rule_set_.ruleAt(rule_index);
    }

    void printInformation();

    void printResultToStream(std::ostream& os, bool with_graph = false);

    void train(const EdsGraph& graph, int round);

    bool transduce(const EdsGraph& graph, bool with_detail = false);
};
}

#endif /* BASELINE_TRANSDUCER_H */

// Local Variables:
// mode: c++
// End:
