#ifndef DAG_TRANSDUCER_BASE_H
#define DAG_TRANSDUCER_BASE_H

#include "include/basic.h"
#include "include/perceptron/score.h"
#include "weight_base.h"
#include "macros_base.h"

class RuleSet;
class EdsGraph;

// Operating mode a transducer is constructed in.
// DAGTransducerBase uses ScoreType::eNonAverage in TRAIN mode and
// ScoreType::eAverage in every other mode.
// Values are spelled out explicitly; numbering starts at 1.
enum TransducerState {
    TRAIN = 1,
    TRANSDUCE = 2,
    GOLDTEST = 3
};

/// Shared state and abstract interface for DAG transducers.
///
/// Holds a pointer to the graph currently being decoded, the feature-weight
/// store, the operating mode, and a set of evaluation counters. Decoder-specific
/// queries are left as pure virtual functions for derived classes.
class DAGTransducerBase {
protected:
    // Pointer to the current graph when decoding. This class never deletes it.
    // It is null until a derived class assigns it.
    const EdsGraph* graph_ptr_;
    // Feature-weight store used by finishTraining(). It is null until a
    // derived class assigns it, and this class never deletes it.
    // NOTE(review): confirm which class owns this object.
    WeightBase* weight_;
    // Mode passed to the constructor.
    TransducerState n_state_;
    // ScoreType::eNonAverage in TRAIN mode, ScoreType::eAverage otherwise.
    int n_score_type_;
    // Counters. All are zeroed by the constructor and by clear(); derived
    // classes are expected to update them.
    int n_training_round_;
    int n_correct_nodes_;
    int n_covered_nodes_;
    int n_total_nodes_;
    int n_early_update_graphs_;
    int n_transduced_graphs_;
    int n_correct_graphs_;

    // Text returned by result().
    std::string result_;

public:
    /// @param n_state Operating mode. It also selects the score type:
    ///        non-averaged weights while training, averaged weights otherwise.
    explicit DAGTransducerBase(TransducerState n_state)
    : graph_ptr_(nullptr),
      weight_(nullptr),
      n_state_(n_state),
      n_score_type_(n_state == TransducerState::TRAIN
                        ? ScoreType::eNonAverage
                        : ScoreType::eAverage),
      n_training_round_(0),
      n_correct_nodes_(0),
      n_covered_nodes_(0),
      n_total_nodes_(0),
      n_early_update_graphs_(0),
      n_transduced_graphs_(0),
      n_correct_graphs_(0) {}

    /// Resets every counter to zero. This includes n_training_round_.
    /// The mode, score type, graph/weight pointers and result_ are left unchanged.
    void clear() {
        n_training_round_ = 0;
        n_correct_nodes_ = 0;
        n_covered_nodes_ = 0;
        n_total_nodes_ = 0;
        n_early_update_graphs_ = 0;
        n_transduced_graphs_ = 0;
        n_correct_graphs_ = 0;
    }

    virtual ~DAGTransducerBase() = default;

    /// Asks weight_ to compute averaged feature weights, passing
    /// n_training_round_, and then to save its scores.
    /// Precondition: weight_ has been set to a non-null store by a derived class.
    virtual void finishTraining() {
        weight_->computeAverageFeatureWeights(n_training_round_);
        weight_->saveScores();
    }

    // ---- Decoder-state queries implemented by derived classes ----
    virtual bool hasResultComputed() = 0;
    virtual int targetOrSourceIndex(int node_index, int var_major) = 0;
    virtual int currentTopIndex() = 0;
    virtual int currentRuleChoice(int node_index) = 0;
    virtual int currentStreamTarget(int node_index) = 0;

    virtual const RuleSet::Rule& ruleAt(int rule_index) const = 0;

    /// Graph currently being processed. Precondition: graph_ptr_ is non-null.
    const EdsGraph& currentGraph() const {
        return *graph_ptr_;
    }

    /// Returns the stored result text (result_).
    const std::string& result() const { return result_; }

    /// Returns true iff every node of the current graph, taken in iteration
    /// order, has a rule_index equal to the matching entry of rule_choices.
    /// Returns false if rule_choices has fewer entries than the graph has
    /// nodes; entries beyond the node count are ignored.
    /// Precondition: graph_ptr_ is non-null.
    bool isGraphCorrect(const std::vector< int >& rule_choices) const {
        auto choice = rule_choices.begin();
        for (const auto& node : graph_ptr_->nodes) {
            // Running out of choices used to read past the end of the vector.
            if (choice == rule_choices.end() || node.rule_index != *choice)
                return false;
            ++choice;
        }
        return true;
    }
};

#endif /* DAG_TRANSDUCER_BASE_H */

// Local Variables:
// mode: c++
// End:
