#ifndef DATA_DRIVEN_TRANSDUCER_H
#define DATA_DRIVEN_TRANSDUCER_H

#include "../baseline/baseline_transducer.h"
#include "data_driven_rule_generator.h"

// Offset added to the rule indices of the extended rule set (see
// BeamSearchTransducer::setRuleSet) so they can be told apart from indices of
// the base rule set; ruleAt() and prototypeIndex() subtract it again.
#define RULESET_LEVEL_OFFSET 10000000

namespace data_driven {

// Decoder state type shared with the baseline transducer.
using baseline::Status;
// Global switches defined in another translation unit.
// NOTE(review): presumably enables prediction of the rule head during
// decoding — confirm against the defining .cpp.
extern bool g_predict_rule_head;
// NOTE(review): presumably selects the special node ordering
// (cf. initializeStackSpecial) — confirm against the defining .cpp.
extern bool g_special_nodes_order;

/// Beam-search transducer over EDS graphs (see train() / transduce()).
///
/// Every node of the input graph is assigned a rule. Rule indices returned by
/// currentRuleChoice() and accepted by ruleAt() use this encoding:
///   * [0, RULESET_LEVEL_OFFSET)   -> rule `index` of the base rule set;
///   * [RULESET_LEVEL_OFFSET, ...) -> rule `index - RULESET_LEVEL_OFFSET` of
///                                    the extended rule set;
///   * <= -2                       -> rule `-index - 2` of the dynamic rule
///                                    generator (index -1 is never valid).
class BeamSearchTransducer : public DAGTransducerBase {
private:
    EquationSolverBase solver; // equation solver
    // Generator of dynamic rules (negative rule indices). Only stored by
    // setGenerator(); NOTE(review): presumably not owned — confirm in dtor.
    DynamicRuleGenerator* generator_ptr_ = nullptr;

    // NOTE(review): presumably scratch keys for feature lookups — confirm.
    TwoInt key2;  // two-int key
    ThreeInt key3; // three-int key
    FourInt key4;  // four-int key
    SixInt key6;   // six-int key

    // Pointer to the current status while decoding; hasResultComputed()
    // reports whether it is non-null.
    Status* current_status_ptr_ = nullptr;
    // Rule sets, set via setRuleSet(); they point at caller-owned objects.
    const RuleSet* rule_set_ptr_ = nullptr;
    const RuleSet* extended_rule_set_ptr_ = nullptr;
    // Prototype index of each extended rule, addressed by
    // (rule_index - RULESET_LEVEL_OFFSET); see prototypeIndex().
    std::vector< int > prototype_indices_;
    // Next status while decoding.
    Status next_status_;

    double running_time_ = 0.0;

    tscore return_value_;
    // Whether the current rule is a 'top' rule ($G).
    bool is_current_rule_top_;
    // Index of the current stream target node.
    int stream_target_index_;
    int in_state_counts_[MAX_GRAPH_NODE_COUNT];
    int source_indices_[MAX_GRAPH_NODE_COUNT][MAX_NODE_IN_EDGE_COUNT];
    int target_indices_[MAX_GRAPH_NODE_COUNT][MAX_NODE_OUT_EDGE_COUNT];
    int node_in_edges_count_[MAX_GRAPH_NODE_COUNT];
    int local_edge_states_[MAX_NODE_OUT_EDGE_COUNT];
    std::vector< int > nodes_stack_;
    std::vector< int > nodes_update_;

    RuleSet::Head rule_head_;
    const RuleSet::Tail* rule_tail_ptr_ = nullptr;

    using Beam = SimpleBeam< Status >;
    Beam beam1_;
    Beam beam2_;
    Beam beam3_;
    Beam* beam_ptr_ = nullptr;

    // Used to record the right head index when multiple candidates exist.
    std::vector< int > nodes_rank_;

private:
    void update();

    void computeSourceIndices(const EdsGraph::Node& node, int node_index);
    bool computeTargetIndices(const EdsGraph::Node& node, int node_index);

    void computeAllIndices();

    bool computeRuleTail(const EdsGraph::Node& node, int node_index);
    void computeRuleHead(const EdsGraph::Node& node, int node_index);
    void computeNextStatus(const EdsGraph::Node& node, int node_index);

    void getOrUpdateNodeScore(const EdsGraph::Node& node,
                              int rule_index, int dir, int amount);

    void getOrUpdateNode2Score(const EdsGraph::Node& node1,
                               const EdsGraph::Node& node2,
                               int rule_index, int dir, int amount);

    virtual void getOrUpdateScore(int node_index, int rule_index, int amount);

    int computeRuleHeadIndex(int node_index, int* head_indices);
    void computeNodesRank();

    void initializeStack();
    void initializeStackSpecial();

    void nextStep(const EdsGraph::Node& node,
                  int node_index,
                  Beam* beam_ptr,
                  const RuleSet::Tails& rule_tails,
                  bool are_rules_extended);

    void nextStepDynamic(const EdsGraph::Node& node,
                         int node_index,
                         Beam* beam_ptr);

    bool decodeDynamicPart2(const EdsGraph::Node& node,
                            int node_index,
                            int head_index,
                            Beam* beam_ptr);

    void decodeDynamic(const EdsGraph::Node& node,
                       int node_index,
                       Beam* beam_ptr1,
                       Beam*& beam_ptr2,
                       Beam*& beam_ptr3);

    void decode(bool early_update);

    void solve();

public:
    BeamSearchTransducer(const std::string& feature_input,
                         const std::string& feature_output,
                         TransducerState n_state,
                         int beam_size);

    ~BeamSearchTransducer();

    bool hasResultComputed() override {
        return current_status_ptr_ != nullptr;
    }

    // The current* accessors below require hasResultComputed() to be true.
    int currentTopIndex() override {
        assert(current_status_ptr_ != nullptr);
        return current_status_ptr_->top_index;
    }

    int currentRuleChoice(int node_index) override {
        assert(current_status_ptr_ != nullptr);
        return current_status_ptr_->rule_choices[node_index];
    }

    int currentStreamTarget(int node_index) override {
        assert(current_status_ptr_ != nullptr);
        return current_status_ptr_->stream_target[node_index];
    }

    int targetOrSourceIndex(int node_index, int var_major) override {
        int in_state_count = in_state_counts_[node_index];
        // Target node of the data stream: the first `in_state_count`
        // variables address the node's source (incoming) indices, the
        // remaining ones its target (outgoing) indices.
        return var_major < in_state_count
                   ? source_indices_[node_index][var_major]
                   : target_indices_[node_index][var_major - in_state_count];
    }

    /// Resolves an encoded rule index (see the class comment) to its rule.
    const RuleSet::Rule& ruleAt(int index) const override {
        // `>=`: setRuleSet() maps extended rule 0 to exactly
        // RULESET_LEVEL_OFFSET, which must resolve to the extended set.
        if (index >= RULESET_LEVEL_OFFSET) {
            assert(extended_rule_set_ptr_ != nullptr);
            return extended_rule_set_ptr_->ruleAt(index - RULESET_LEVEL_OFFSET);
        } else if (index < 0) {
            assert(generator_ptr_ != nullptr);
            // -1 would map to generator index -1, which is never valid.
            assert(index <= -2);
            return generator_ptr_->ruleAt(-index - 2);
        }
        assert(rule_set_ptr_ != nullptr);
        return rule_set_ptr_->ruleAt(index);
    }

    std::vector< int >& prototypeIndices() {
        return prototype_indices_;
    }

    /// Prototype index of an extended rule; `rule_index` must carry the
    /// RULESET_LEVEL_OFFSET added by setRuleSet(..., true).
    int prototypeIndex(int rule_index) const {
        assert(rule_index >= RULESET_LEVEL_OFFSET);
        assert(static_cast< std::size_t >(rule_index - RULESET_LEVEL_OFFSET) <
               prototype_indices_.size());
        return prototype_indices_[rule_index - RULESET_LEVEL_OFFSET];
    }

    /// Registers a (caller-owned) rule set. When `extended` is true, every
    /// rule's tail.rule_index is shifted in place by RULESET_LEVEL_OFFSET.
    /// NOTE: registering the same extended rule set twice shifts it twice.
    void setRuleSet(RuleSet& rule_set, bool extended = false) {
        if (extended) {
            for (auto& rule : rule_set.rules())
                rule.tail.rule_index += RULESET_LEVEL_OFFSET;
            extended_rule_set_ptr_ = &rule_set;
        } else
            rule_set_ptr_ = &rule_set;
    }

    void setGenerator(DynamicRuleGenerator* generator_ptr) {
        generator_ptr_ = generator_ptr;
    }

    void printInformation();
    void printResultToStream(std::ostream& os, bool with_graph = false);

    void train(const EdsGraph& graph, int round);

    bool transduce(const EdsGraph& graph, bool with_detail = false);

    // Averages the feature weights over the training rounds, then saves them.
    void finishTraining() override {
        weight_->computeAverageFeatureWeights(n_training_round_);
        weight_->saveScores();
    }
};
}


#endif /* DATA_DRIVEN_TRANSDUCER_H */

// Local Variables:
// mode: c++
// End:
