__author__ = 'thomas'
from argparse import ArgumentParser
import csv
import logging
import os
import sys

from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver
from sacred.observers import SqlObserver
from sklearn.externals import joblib
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold
import numpy as np

from telicity.util import path_utils


# Directory that contains this module.
PACKAGE_PATH = os.path.dirname(__file__)
# Two directory levels above PACKAGE_PATH.
PROJECT_PATH = os.path.dirname(os.path.dirname(PACKAGE_PATH))

# Command-line interface of the experiment runner. The parser itself only enforces `--config-name` and
# `--experiment-file`; all other options are optional from the parser's point of view.
parser = ArgumentParser()
parser.add_argument('-i', '--input-file', type=str, help='input file')
parser.add_argument('-ip', '--input-path', type=str, help='path to input file')
parser.add_argument('-i2', '--input-file-2', type=str, help='input file 2')
parser.add_argument('-ip2', '--input-path-2', type=str, help='path to input file 2')
parser.add_argument('-op', '--output-path', type=str, help='path to output file')
parser.add_argument('-cn', '--config-name', type=str, required=True, help='name of config')
parser.add_argument('-ef', '--experiment-file', type=str, required=True)
parser.add_argument('-s', '--store-scores', action='store_true', help='Store individual scores in file.')
parser.add_argument('-sp', '--score-path', type=str, help='path to store the score file.')
parser.add_argument('-eid', '--experiment-id', type=int, default=-1, help='experiment id to use.')
# The accepted names mirror the keys of the module-level OBSERVERS registry, so a mistyped observer is
# reported as a usage error right away instead of surfacing later as a KeyError during the lookup.
parser.add_argument('-obs', '--sacred-observers', nargs='+', type=str, default=['file'],
	choices=['mongo', 'sqlite', 'file'], help='sacred observers to add')
parser.add_argument('-ll', '--log-level', type=str, default='INFO', help='logging level',
	choices=['CRITICAL', 'FATAL', 'ERROR', 'WARNING', 'WARN', 'INFO', 'DEBUG'])
parser.add_argument('-lp', '--log-path', type=str, help='path for file log')

# Sacred experiment object; the `config` and `run` functions of this module are registered on it
# through the `@ex.config` / `@ex.main` decorators.
ex = Experiment('multilingual_aspect')

# Registry of the sacred observers that can be selected by name. All three observers are
# instantiated when the module is loaded, regardless of which ones end up being used.
_SACRED_RESOURCE_PATH = os.path.join(PACKAGE_PATH, 'resources', 'sacred')
OBSERVERS = dict(
	mongo=MongoObserver.create(db_name='MultilingualAspect'),
	sqlite=SqlObserver.create('sqlite:///' + os.path.join(_SACRED_RESOURCE_PATH, 'MultilingualAspect.sqlite')),
	file=FileStorageObserver.create(os.path.join(_SACRED_RESOURCE_PATH, 'MultilingualAspect.fs')),
)


@ex.config
def config():
	"""Default sacred configuration; every entry is a placeholder (empty string or -1)."""
	config_name = ''  # logged by `run`
	exp_name = ''  # accepted by `run` but not read there
	data_file = ''  # loaded with np.loadtxt in `run`
	label_file = ''  # loaded with np.loadtxt in `run`
	label_encoder_file = ''  # loaded with joblib.load in `run`
	num_folds = -1  # number of StratifiedKFold splits in `run`
	random_seed = -1  # random_state of the StratifiedKFold in `run`
	experiment_id = -1  # accepted by `run` but not read there


@ex.main
def run(config_name, data_file, label_file, label_encoder_file, exp_name, num_folds, random_seed,
		experiment_id):
	"""Cross-validate a logistic regression classifier and log accuracy and F1 scores.

	Features and labels are read with `np.loadtxt`, the label encoder with `joblib.load`. A shuffled
	`StratifiedKFold` produces the folds and a fresh `LogisticRegression(solver='liblinear')` is fitted
	on every fold. The F1 scores are computed once over the predictions pooled from all folds.

	:param config_name: name of the configuration; only used in the first log message
	:param data_file: path of the text file holding the feature matrix
	:param label_file: path of the text file holding the labels
	:param label_encoder_file: path of the stored label encoder (its `classes_` and `inverse_transform`
		are used for reporting)
	:param exp_name: not read inside this function
	:param num_folds: number of cross-validation splits
	:param random_seed: random state handed to the fold splitter
	:param experiment_id: not read inside this function
	:return: the accuracy averaged over all folds
	"""
	logging.info(f'Running experiments with config_name={config_name} and with data_file={data_file}')
	features = np.loadtxt(data_file)
	labels = np.loadtxt(label_file)
	label_encoder = joblib.load(label_encoder_file)

	logging.info('Data and labels loaded!')

	logging.info(f'Starting {num_folds}-cross validation...')
	splitter = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=random_seed)

	fold_accuracies, predicted, gold = [], [], []
	for fold, (train_rows, test_rows) in enumerate(splitter.split(features, labels), 1):
		logging.debug(f'K-Fold split {fold}/{num_folds}...')
		train_features, test_features = features[train_rows], features[test_rows]
		train_labels, test_labels = labels[train_rows], labels[test_rows]

		logging.debug('\tFitting model...')
		classifier = LogisticRegression(solver='liblinear')
		classifier.fit(train_features, train_labels)
		fold_predictions = classifier.predict(test_features)

		# Pool predictions and gold labels so F1 can be computed over all folds at once.
		predicted += fold_predictions.tolist()
		gold += test_labels.tolist()
		fold_accuracy = accuracy_score(test_labels, fold_predictions)
		fold_accuracies.append(fold_accuracy)
		logging.debug(f'\tAccuracy={fold_accuracy}')
		logging.debug('\t---------------------------------------------')

	avg_acc = np.average(fold_accuracies)
	std_acc = np.std(fold_accuracies)
	gold_array = np.array(gold)
	predicted_array = np.array(predicted)

	if len(label_encoder.classes_) <= 2:
		# At most two classes: report one F1 score per encoded label (1 and 0).
		f1_y1 = f1_score(gold_array, predicted_array, pos_label=1)
		f1_y0 = f1_score(gold_array, predicted_array, pos_label=0)
		logging.info(f'Average Accuracy: {avg_acc} +/- {std_acc}; '
					 f'Total F1["{label_encoder.inverse_transform(np.array([0]))}"]={f1_y0}; '
					 f'Total F1["{label_encoder.inverse_transform(np.array([1]))}"]={f1_y1}')
		return avg_acc

	# More than two classes: report the weighted, micro and macro averaged F1 scores.
	f1_weighted = f1_score(gold_array, predicted_array, average='weighted')
	f1_micro = f1_score(gold_array, predicted_array, average='micro')
	f1_macro = f1_score(gold_array, predicted_array, average='macro')

	logging.info(f'Average Accuracy: {avg_acc} (+/- {std_acc}); '
				 f'F1 weighted={f1_weighted}; F1 micro={f1_micro}; F1 macro={f1_macro}')

	return avg_acc


if (__name__ == '__main__'):
	# Script entry point: parse the command line, configure logging, attach the requested sacred
	# observers and run one sacred experiment per row of the experiment file.
	args = parser.parse_args()

	# Every data/label/encoder file is joined onto the input path, so report a missing value as a
	# usage error here instead of failing later with a TypeError from `os.path.join(None, ...)`.
	if (args.input_path is None):
		parser.error('the following arguments are required: -ip/--input-path')

	log_formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)s - %(message)s', datefmt='[%d/%m/%Y %H:%M:%S %p]')
	root_logger = logging.getLogger()
	root_logger.setLevel(getattr(logging, args.log_level))

	console_handler = logging.StreamHandler(sys.stdout)
	console_handler.setFormatter(log_formatter)
	root_logger.addHandler(console_handler)

	# Optional file log, written into a timestamped sub-folder of the requested log path.
	if (args.log_path is not None):
		timestamped_foldername = path_utils.timestamped_foldername()
		log_path = os.path.join(args.log_path, timestamped_foldername)

		if (not os.path.exists(log_path)):
			os.makedirs(log_path)

		file_handler = logging.FileHandler(os.path.join(log_path, f'experiment_lr_{args.config_name}.log'))
		file_handler.setFormatter(log_formatter)
		root_logger.addHandler(file_handler)

	if (args.output_path is not None and not os.path.exists(args.output_path)):
		os.makedirs(args.output_path)

	for obs in args.sacred_observers:
		ex.observers.append(OBSERVERS[obs])
	ex.logger = root_logger

	# Load the experiment definitions. Rows are numbered by their 1-based position in the file so that
	# `--experiment-id` and the `experiment_id` config entry always refer to the same row. Blank rows
	# are skipped (they cannot be unpacked into the five expected fields) but still counted.
	experiment_file = os.path.join(PACKAGE_PATH, 'resources', 'experiments', args.experiment_file)
	with open(experiment_file, 'r', newline='') as csv_file:
		experiments = [(row_id, row) for row_id, row in enumerate(csv.reader(csv_file), 1) if row]

	# A positive experiment id restricts the run to that single row while keeping its original id.
	if (args.experiment_id > 0):
		experiments = [(row_id, row) for row_id, row in experiments if row_id == args.experiment_id]
		if (not experiments):
			parser.error(f'no experiment with id {args.experiment_id} in {experiment_file}')

	for experiment_id, (data_file, label_file, label_encoder_file, num_folds, random_seed) in experiments:
		config_dict = {
			'data_file': os.path.join(args.input_path, data_file),
			'label_file': os.path.join(args.input_path, label_file),
			'label_encoder_file': os.path.join(args.input_path, label_encoder_file),
			'config_name': args.config_name,
			'exp_name': args.config_name,
			'random_seed': int(random_seed),
			'num_folds': int(num_folds),
			'experiment_id': experiment_id
		}
		ex.run(config_updates=config_dict)
		logging.info('--------------------------------------------------')
