#include <fstream>

#include "include/basic.h"
#include "io/token_manager.h"
#include "data_driven_generator_weight.h"

namespace data_driven {
// Builds every feature-weight table of the dynamic model, then loads them.
//
// input_file / output_file are forwarded unchanged to WeightBase; how they are
// used is not visible here.
// NOTE(review): presumably WeightBase::loadScores() opens input_file and
// dispatches to the loadScores(std::istream&) override defined in this file —
// confirm against WeightBase.
//
// Each table is constructed with a name string.
// NOTE(review): ten of these names start with a leading space
// (e.g. " lemma_dir_index", all transition_* and var_* tables) while the
// others do not. Also, transition_edge2_index2 is named
// " transition_edge_index_edge_index", which does not match the member name.
// These look like typos. However, if the score-map type writes or validates
// its name in the model file, "fixing" them would make existing model files
// unreadable — confirm the file format before changing any of them.
DynamicWeight::DynamicWeight(const std::string& input_file,
                             const std::string& output_file)
: WeightBase(input_file, output_file),
  // first-order neighbours
  lemma_dir_index(" lemma_dir_index"),
  label_dir_index(" label_dir_index"),
  label_lemma_dir_index(" label_lemma_dir_index"),
  // edges
  node2_dir_index("node2_dir_index"),
  lemma2_dir_index("lemma2_dir_index"),
  node2_edge_dir_index("node2_edge_dir_index"),
  lemma2_edge_dir_index("lemma2_edge_dir_index"),

  node2_index("node2_index"),
  lemma2_index("lemma2_index"),
  node2_lemma2_index("node2_lemma2_index"),

  node2_index2("node2_index2"),
  lemma2_index2("lemma2_index2"),
  index2("index2"),
  // -------------------- transition --------------------
  transition_edge2_index2(" transition_edge_index_edge_index"),
  transition_node3_index2(" transition_node3_index2"),
  transition_lemma3_index2(" transition_lemma3_index2"),
  // -------------------- variable --------------------
  var_dir_index(" var_dir_index"),
  var_node_lemma_dir_index(" var_node_lemma_dir_index"),
  var_node2_dir_index(" var_node2_dir_index"),
  var_lemma2_dir_index(" var_lemma2_dir_index") {
    // The explicit WeightBase:: qualification is required: the
    // loadScores(std::istream&) overload declared by this class hides the
    // base-class loadScores name (unless the header adds a using-declaration).
    WeightBase::loadScores();
    LOG_INFO(<< "Features load complete");
}

/**
 * Reads every feature table of the model from @p is.
 *
 * Tables are extracted in exactly the sequence that saveScores() inserts
 * them. Keep the two member lists in lock-step whenever a table is added,
 * removed or reordered.
 *
 * No error checking is done here; the state of @p is is left for the caller
 * to inspect.
 *
 * The chained form assumes each table's operator>> returns the stream by
 * reference, which is the conventional signature.
 *
 * @param is  input stream holding the serialized tables.
 */
void DynamicWeight::loadScores(std::istream& is) {
    // first-order neighbours
    is >> lemma_dir_index
       >> label_dir_index
       >> label_lemma_dir_index;

    // edges
    is >> node2_dir_index
       >> lemma2_dir_index
       >> node2_edge_dir_index
       >> lemma2_edge_dir_index;

    is >> node2_index
       >> lemma2_index
       >> node2_lemma2_index;

    is >> node2_index2
       >> lemma2_index2
       >> index2;

    // transition tables
    is >> transition_edge2_index2
       >> transition_node3_index2
       >> transition_lemma3_index2;

    // variable tables
    is >> var_dir_index
       >> var_node_lemma_dir_index
       >> var_node2_dir_index
       >> var_lemma2_dir_index;
}

/**
 * Writes every feature table of the model to @p os.
 *
 * The insertion order is the order loadScores() reads the tables back in.
 * Any change to this list must be mirrored there.
 *
 * The chained form assumes each table's operator<< returns the stream by
 * reference, which is the conventional signature.
 *
 * @param os  output stream receiving the serialized tables.
 */
void DynamicWeight::saveScores(std::ostream& os) const {
    // first-order neighbours
    os << lemma_dir_index
       << label_dir_index
       << label_lemma_dir_index;

    // edges
    os << node2_dir_index
       << lemma2_dir_index
       << node2_edge_dir_index
       << lemma2_edge_dir_index;

    os << node2_index
       << lemma2_index
       << node2_lemma2_index;

    os << node2_index2
       << lemma2_index2
       << index2;

    // transition tables
    os << transition_edge2_index2
       << transition_node3_index2
       << transition_lemma3_index2;

    // variable tables
    os << var_dir_index
       << var_node_lemma_dir_index
       << var_node2_dir_index
       << var_lemma2_dir_index;
}

void DynamicWeight::computeAverageFeatureWeights(const int& round) {
    // 1 阶相邻点
    lemma_dir_index.computeAverage(round);
    label_dir_index.computeAverage(round);
    label_lemma_dir_index.computeAverage(round);
    // 边
    node2_dir_index.computeAverage(round);
    lemma2_dir_index.computeAverage(round);
    node2_edge_dir_index.computeAverage(round);
    lemma2_edge_dir_index.computeAverage(round);

    node2_index.computeAverage(round);
    lemma2_index.computeAverage(round);
    node2_lemma2_index.computeAverage(round);

    node2_index2.computeAverage(round);
    lemma2_index2.computeAverage(round);
    index2.computeAverage(round);

    // -------------------- transition --------------------
    transition_edge2_index2.computeAverage(round);
    transition_node3_index2.computeAverage(round);
    transition_lemma3_index2.computeAverage(round);

    // -------------------- variable --------------------
    var_dir_index.computeAverage(round);
    var_node_lemma_dir_index.computeAverage(round);
    var_node2_dir_index.computeAverage(round);
    var_lemma2_dir_index.computeAverage(round);
}
}
