#include <ctime>
#include <memory>
#include <fstream>
#include <iostream>
#include <algorithm>
#include <cstdlib>
#include <numeric>

#include "io/edsgraph.h"
#include "io/rule.h"
#include "io/command_line.h"

#include "data_driven_transducer.h"
#include "data_driven_run.h"

namespace data_driven {

// Passed as the last constructor argument of BeamSearchTransducer
// (presumably the beam width). Run::initialize() overwrites it with the
// "beam_size" command-line option when that option is present.
int g_beam_size = 128;
// Passed as the last constructor argument of DynamicRuleGenerator
// (presumably the k-best list size). Run::initialize() overwrites it with the
// "generator_k_best" command-line option when that option is present.
int g_generator_k_best = 32;

// Logs the mean and the maximum of `vals`, each divided by 1,000,000.
//
// vals:    samples to summarise; may be empty, in which case a single
//          "no graphs" line is logged instead of the statistics.
// message: label inserted into the log lines ("Mean for <message> graphs").
//
// NOTE(review): the only callers visible in this file pass std::clock()
// tick differences, so the 1,000,000 divisor yields seconds only where
// CLOCKS_PER_SEC == 1000000 (as on POSIX) -- confirm for other platforms.
template < typename T >
void printStatistics(const std::vector< T >& vals, const char* message) {
    // An empty sample has neither a mean nor a maximum: max_element would
    // return end(), and dereferencing it is undefined behaviour.
    if (vals.empty()) {
        LOG_INFO(<< "No " << message << " graphs to compute statistics for");
        return;
    }

    const double scale = 1000000;
    const double sum =
        std::accumulate(vals.begin(), vals.end(), static_cast< T >(0));
    const double mean = sum / vals.size() / scale;
    const double max =
        static_cast< double >(*std::max_element(vals.begin(), vals.end())) / scale;

    LOG_INFO(<< "Mean for " << message << " graphs: " << mean);
    LOG_INFO(<< "Max for " << message << " graphs: " << max);
}

// Reads whitespace-separated integers from `filename` into
// `prototype_indeices`.
//
// The vector is not resized here: exactly prototype_indeices.size() values
// are expected, so the caller must size it before calling. If the file cannot
// be opened an error is logged and the vector is left untouched. If the file
// holds fewer readable integers than expected, an error is logged and the
// remaining entries keep the values they had on entry.
void loadPrototypeIndices(const std::string& filename,
                          std::vector< int >& prototype_indeices) {
    std::ifstream is(filename);
    if (!is) {
        LOG_ERROR(<< "Can't find file " << filename);
        return;
    }

    std::size_t loaded = 0;
    for (auto& index : prototype_indeices) {
        // Once extraction fails the stream stays in a failed state, so
        // further reads would be silent no-ops; report and stop instead.
        if (!(is >> index)) {
            LOG_ERROR(<< "Expected " << prototype_indeices.size()
                      << " prototype indices in " << filename
                      << " but could only read " << loaded);
            break;
        }
        ++loaded;
    }
}

namespace {

// Copies the integer value of command-line option `key` into `target` when
// that option is present in `options`; otherwise `target` is left unchanged.
template < typename Options, typename Target >
void assignIntOption(Options& options, const char* key, Target& target) {
    if (options.count(key))
        target = options[key].template as< int >();
}

} // namespace

// Overrides the global tuning variables with the matching command-line
// options. Each global keeps its current value when its option is absent.
void Run::initialize() const {
    auto& options = commandLineVariables();

    assignIntOption(options, "beam_size", g_beam_size);
    assignIntOption(options, "generator_k_best", g_generator_k_best);
    assignIntOption(options, "generator_use_transition",
                    g_generator_use_transition);
    assignIntOption(options, "predict_rule_head", g_predict_rule_head);
    assignIntOption(options, "special_nodes_order", g_special_nodes_order);
}

void Run::train(const std::string& input_file,
                const std::string& rules_file,
                const std::string& feature_input,
                const std::string& feature_output) const {

    initialize();

    auto& vm = commandLineVariables();

    if (!vm.count("dynamic_feature")) {
        LOG_ERROR(<< "Need: "
                  << "--dynamic_feature");
        return;
    }

    RuleSet rule_set; // level 0
    loadRulesFromFile(rules_file, rule_set);

    std::vector< EdsGraph > graphs;
    std::vector< int > random;
    loadGraphsFromFile(input_file, graphs, random);

    auto& dynamic_feature = vm["dynamic_feature"].as< std::string >();

    if (vm["model"].as< std::string >().find("generator") != std::string::npos) {
        LOG_INFO(<< "Training generator");
        std::unique_ptr< DynamicRuleGenerator >
            generator(new DynamicRuleGenerator(dynamic_feature, dynamic_feature,
                                               TransducerState::TRAIN,
                                               g_generator_k_best));
        int graph_count = random.size();
        for (int round = 1; round <= graph_count; ++round)
            generator->trainRuleGenerator(graphs[random[round - 1]], rule_set, round);
        generator->printInformation();

        generator->finishTraining();
    } else {
        LOG_INFO(<< "Training transducer");
        std::unique_ptr< BeamSearchTransducer >
            transducer(new BeamSearchTransducer(feature_input, feature_output,
                                                TransducerState::TRAIN,
                                                g_beam_size));
        transducer->setRuleSet(rule_set);
        // transducer->setGenerator(generator.get());

        int graph_count = random.size();
        for (int round = 1; round <= graph_count; ++round)
            transducer->train(graphs[random[round - 1]], round);
        transducer->printInformation();

        transducer->finishTraining();
    }
}

// Currently a no-op: the body is empty and none of the three parameters
// (input_file, rules_file, feature_input) is used.
void Run::goldTest(const std::string& input_file,
                   const std::string& rules_file,
                   const std::string& feature_input) const {}

void Run::transduce(const std::string& input_file,
                    const std::string& output_file,
                    const std::string& rules_file,
                    const std::string& feature_file,
                    bool with_detail) const {
    EdsGraph graph;

    initialize();

    auto& vm = commandLineVariables();

    if (!vm.count("rules_extended") ||
        !vm.count("dynamic_feature") ||
        !vm.count("prototype_indeices")) {
        LOG_ERROR(<< "Need: "
                  << "--dynamic_feature, "
                  << "--rules_extended, "
                  << "--prototype_indeices");
        return;
    }

    RuleSet rule_set;
    RuleSet rule_set_extended;
    loadRulesFromFile(rules_file, rule_set);
    loadRulesFromFile(vm["rules_extended"].as< std::string >(),
                      rule_set_extended);

    std::ifstream is(input_file);
    std::ofstream os(output_file, std::ios::out);
    if (!is) {
        LOG_ERROR(<< "Can't find file " << input_file);
        return;
    }

    auto& dynamic_feature = vm["dynamic_feature"].as< std::string >();

    std::unique_ptr< BeamSearchTransducer >
        transducer(new BeamSearchTransducer(feature_file, feature_file,
                                            TransducerState::TRANSDUCE,
                                            g_beam_size));

    std::unique_ptr< DynamicRuleGenerator >
        generator(new DynamicRuleGenerator(dynamic_feature, dynamic_feature,
                                           TransducerState::TRAIN,
                                           g_generator_k_best));

    transducer->setRuleSet(rule_set);
    transducer->setRuleSet(rule_set_extended, true);
    transducer->setGenerator(generator.get());

    auto& indices = transducer->prototypeIndices();
    indices.resize(rule_set_extended.rules().size());
    loadPrototypeIndices(vm["prototype_indeices"].as< std::string >(),
                         indices);

    int round = 0, count = 0;
    std::vector< std::clock_t > succ_times;
    std::vector< std::clock_t > all_times;
    while (is >> graph && graph.nodes.size() < MAX_GRAPH_NODE_COUNT) {
        LOG_DEBUG(<< "Transduce: " << graph.filename);
        os << graph.filename << ": ";
        auto start_time = std::clock();
        bool succ = transducer->transduce(graph, with_detail);
        auto interval = std::clock() - start_time;
        if (succ) {
            succ_times.push_back(interval);
            all_times.push_back(interval);
            os << transducer->result();
            ++count;
        } else
            all_times.push_back(interval);
        os << '\n';
        if (with_detail && transducer->hasResultComputed()) {
            transducer->printResultToStream(os, true);
            os << '\n';
        }
        if (++round % 200 == 0)
            LOG_INFO(<< "Transduced " << ( float )count / round << ' '
                     << count << '/' << round);
    }

    LOG_INFO(<< "Transduced " << ( float )count / round << ' '
             << count << '/' << round);
    os.close();

    printStatistics(all_times, "all");
    printStatistics(succ_times, "successful");
}
}
