package com.zhaohanphd;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Random;

import weka.classifiers.evaluation.Evaluation;
//import weka.classifiers.trees.REPTree;
import weka.core.Attribute;
import weka.core.Instances;
import weka.core.converters.CSVLoader;
import weka.filters.unsupervised.attribute.Remove;
import weka.filters.Filter;

/**
 * Trains and evaluates {@link CustomREPTree} models on the CSV file at
 * {@link #INPUT_FILE_PATH}, once with all features and once per excluded
 * feature, then prints the collected metrics as LaTeX table rows.
 */
public class ReferringFormDecisionTree {

	private static final String INPUT_FILE_PATH = "/home/zhao/Dropbox/mines-research/onr-giveness-hierarchy/annotation/intermediate-output/model-input.csv";

	/** Name of the attribute used as the class (target) variable. */
	private static final String TARGET_ATTRIBUTE = "Referring Form";

	/**
	 * Metric name -> one formatted value per evaluated model, in evaluation order.
	 * Initialised eagerly so that {@link #model(String, String)} can be called
	 * without going through {@link #main(String[])} first.
	 */
	private static final HashMap<String, ArrayList<String>> results = new HashMap<>();

	/**
	 * Evaluates models M1..M5 and prints one table row per metric.
	 *
	 * @param args unused
	 * @throws Exception if loading, training, evaluation or visualisation fails
	 */
	public static void main(String[] args) throws Exception {
		String[] metrics = {"Accuracy", "RMSE", "Precision", "Recall", "F1 score", "Coverage", "Leaves"};
		results.clear();
		for (String metric : metrics) {
			results.put(metric, new ArrayList<String>());
		}

		model("M1", "");
		model("M2", "Cognitive Status");
		model("M3", "Number of Distractors");
		model("M4", "Physical Distance (Instructor)");
		model("M5", "Recency");

		System.out.println("\n\n");

		for (String metric : metrics) {
			// A horizontal rule separates the cross-validation metrics from the tree statistics.
			if (metric.equals("Coverage")) {
				System.out.println("\\hline");
			}
			System.out.print(metric);
			for (String value : results.get(metric)) {
				System.out.print(" & " + value);
			}
			System.out.println("\\\\");
		}
	}

	/**
	 * Cross-validates a tree, rebuilds it on the full data set, records the
	 * "Coverage" and "Leaves" metrics and writes a visualisation of the tree.
	 *
	 * @param name             model label, also used in the output file names
	 * @param featureExclusion name of the attribute to drop before training;
	 *                         {@code null} or empty keeps every attribute
	 * @throws Exception if loading, training, evaluation or visualisation fails
	 */
	public static void model(String name, String featureExclusion) throws Exception {
		boolean hasExclusion = featureExclusion != null && !featureExclusion.isEmpty();

		System.out.print("\n\n\n" + name);
		if (hasExclusion) {
			System.out.println(" (no " + featureExclusion + ")");
		} else {
			System.out.println();
		}

		Instances data = getData(featureExclusion);
		setTargetVariable(data, TARGET_ATTRIBUTE);

		CustomREPTree tree = new CustomREPTree();

		// 15-fold cross-validation; records accuracy, error, precision, recall and F1.
		evaluate(data, tree);

		// Rebuild on the full data set for the coverage/leaf statistics and the visualisation.
		tree.buildClassifier(data);

		// Coverage: number of distinct values the tree predicts on the training data.
		Evaluation evaluation = new Evaluation(data);
		double[] predictions = evaluation.evaluateModel(tree, data);
		long coverage = Arrays.stream(predictions).distinct().count();
		outputMetric("Coverage", String.format("%d", coverage));

		outputMetric("Leaves", String.format("%d", tree.numLeaves()));

		visTree(tree, name, featureExclusion);
	}

	/**
	 * Writes the tree as {@code tree-<name>.dot} with shortened attribute labels
	 * and asks Graphviz {@code dot} to render it to {@code tree-<name>.pdf}.
	 * A missing or failing {@code dot} is reported on stderr but is not fatal.
	 *
	 * @param tree             the built tree to visualise
	 * @param name             model label used in the output file names
	 * @param featureExclusion unused; kept for interface compatibility
	 * @throws IOException if the DOT file cannot be written
	 * @throws Exception   if the tree cannot be graphed or the wait for
	 *                     {@code dot} is interrupted
	 */
	private static void visTree(CustomREPTree tree, String name, String featureExclusion) throws Exception, FileNotFoundException, IOException {
		String dotFileName = String.format("tree-%s.dot", name);
		try (PrintWriter out = new PrintWriter(dotFileName)) {
			// All replacements are literal (String.replace does not interpret regex).
			String dotContent = tree.graph();
			dotContent = dotContent.replace("Cognitive Status", "Cog.\nStatus");
			dotContent = dotContent.replace("Recency", "Temp.\nDist.");
			dotContent = dotContent.replace("Number of Distractors", "Distractors");
			dotContent = dotContent.replace("Physical Distance", "Phys.\nDist.");
			dotContent = dotContent.replace(" (Instructor)", "");
			dotContent = dotContent.replace(" (", "\n (");
			dotContent = dotContent.replace("= ", "");

			// Strip numeric "N: " / "N : " prefixes. Counting down removes two-digit
			// numbers before their trailing single digit could match on its own.
			for (int i = 50; i > 0; i--) {
				dotContent = dotContent.replace(i + ": ", "");
				dotContent = dotContent.replace(i + " : ", "");
			}

			out.println(dotContent);
			// PrintWriter swallows I/O errors; surface them explicitly.
			if (out.checkError()) {
				throw new IOException("Failed to write " + dotFileName);
			}
		}

		String visFileExt = "pdf";
		String visFileName = "tree-" + name + "." + visFileExt;

		// Pass arguments as a list (no shell) and wait so the output file is
		// complete when this method returns.
		ProcessBuilder dot = new ProcessBuilder("dot", "-T" + visFileExt, dotFileName, "-o", visFileName);
		dot.inheritIO();
		try {
			int exitCode = dot.start().waitFor();
			if (exitCode != 0) {
				System.err.println("dot exited with status " + exitCode + " while rendering " + dotFileName);
			}
		} catch (IOException e) {
			// Rendering is best-effort: the metrics are still valid without the PDF.
			System.err.println("Could not run Graphviz 'dot' for " + dotFileName + ": " + e.getMessage());
		}
	}

	/**
	 * Runs 15-fold cross-validation with a fixed random seed and records the
	 * "Accuracy", "RMSE", "Precision", "Recall" and "F1 score" metrics.
	 *
	 * @param data data set with its class index already set
	 * @param tree classifier to cross-validate
	 * @throws Exception if cross-validation fails
	 */
	private static void evaluate(Instances data, CustomREPTree tree) throws Exception {
		CustomEvaluation eval = new CustomEvaluation(data);
		eval.crossValidateModel(tree, data, 15, new Random(1));

		outputMetric("Accuracy", String.format("%.2f", eval.pctCorrect()));
		// NOTE(review): this row is labelled "RMSE" but is filled from errorRate();
		// confirm that CustomEvaluation.errorRate() really is the intended quantity.
		outputMetric("RMSE", String.format("%.3f", eval.errorRate()));
		outputMetric("Precision", String.format("%.3f", eval.weightedPrecision()));
		outputMetric("Recall", String.format("%.3f", eval.weightedRecall()));
		outputMetric("F1 score", String.format("%.3f", eval.weightedFMeasure()));
	}

	/**
	 * Prints a metric and appends it to the results table.
	 *
	 * @param metricName row label
	 * @param metric     already formatted value
	 */
	private static void outputMetric(String metricName, String metric) throws Exception {
		System.out.printf("%-10s %s\n", metricName, metric);
		// computeIfAbsent avoids a NullPointerException for an unregistered metric name.
		results.computeIfAbsent(metricName, key -> new ArrayList<String>()).add(metric);
	}

	/**
	 * Loads the CSV at {@link #INPUT_FILE_PATH}, optionally drops one attribute
	 * and prints the remaining attributes.
	 *
	 * @param featureExclusion name of the attribute to remove; {@code null} or
	 *                         empty removes nothing
	 * @return the loaded (and possibly filtered) data set
	 * @throws Exception if the file cannot be loaded or the filter fails
	 */
	private static Instances getData(String featureExclusion) throws Exception {
		CSVLoader csvLoader = new CSVLoader();
		csvLoader.setSource(new File(INPUT_FILE_PATH));
		Instances data = csvLoader.getDataSet();

		if (featureExclusion != null && !featureExclusion.isEmpty()) {
			int attributeIndex = findAttributeIndex(data, featureExclusion);
			if (attributeIndex >= 0) {
				System.out.println("To remove " + data.attribute(attributeIndex).name());

				Remove remove = new Remove();
				// Remove takes 1-based attribute indices.
				remove.setAttributeIndices(String.valueOf(attributeIndex + 1));
				remove.setInputFormat(data);
				data = Filter.useFilter(data, remove);
			} else {
				System.err.println("Attribute to exclude not found: " + featureExclusion);
			}
		}

		printAttributes(data);

		return data;
	}

	/**
	 * Marks the attribute with the given name as the class attribute.
	 *
	 * @param data   data set to update
	 * @param string name of the target attribute
	 * @throws IllegalArgumentException if no attribute has that name
	 */
	private static void setTargetVariable(Instances data, String string) {
		int targetIndex = findAttributeIndex(data, string);
		if (targetIndex < 0) {
			throw new IllegalArgumentException("Target attribute not found: " + string);
		}
		data.setClassIndex(targetIndex);
	}

	/** Prints every attribute of the data set, one per line. */
	private static void printAttributes(Instances data) {
		System.out.println("Attributes:");
		for (int i = 0; i < data.numAttributes(); i++) {
			System.out.println(data.attribute(i));
		}
	}

	/** Returns the 0-based index of the first attribute with the given name, or -1 if there is none. */
	private static int findAttributeIndex(Instances data, String attributeName) {
		for (int i = 0; i < data.numAttributes(); i++) {
			Attribute attribute = data.attribute(i);
			if (attribute.name().equals(attributeName)) {
				return i;
			}
		}
		return -1;
	}

}
