__author__ = 'thomas'
from argparse import ArgumentParser
import csv
import logging
import os
import sys

from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver
from sacred.observers import SqlObserver
from sacred.observers import TinyDbObserver
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
import numpy as np

from telicity.util import io
from telicity.util import path_utils
from telicity.util import wittgenstein


# Resolve locations from the absolute path of this file so that they do not depend on the
# working directory the script is started from (a relative `__file__` would yield '').
PACKAGE_PATH = os.path.dirname(os.path.abspath(__file__))
PROJECT_PATH = os.path.dirname(os.path.dirname(PACKAGE_PATH))

# Command line interface of the experiment runner.
parser = ArgumentParser()
parser.add_argument('-i', '--input-file', type=str, help='input file')
parser.add_argument('-ip', '--input-path', type=str, help='path to input file')
parser.add_argument('-i2', '--input-file-2', type=str, help='input file 2')
parser.add_argument('-ip2', '--input-path-2', type=str, help='path to input file 2')
parser.add_argument('-op', '--output-path', type=str, help='path to output file')
parser.add_argument('-cn', '--config-name', type=str, required=True, help='name of config')
parser.add_argument('-ef', '--experiment-file', type=str, required=True)
parser.add_argument('-s', '--store-scores', action='store_true', help='Store individual scores in file.')
parser.add_argument('-sp', '--score-path', type=str, help='path to store the score file.')
parser.add_argument('-eid', '--experiment-id', type=int, default=-1, help='experiment id to use.')
# Every default name must be a key of OBSERVERS below; the former default also listed
# 'telegram', which has no entry there and made a run with default arguments fail.
parser.add_argument('-obs', '--sacred-observers', nargs='+', type=str, default=['mongo', 'sqlite'],
					help='sacred observers to add')
parser.add_argument('-ll', '--log-level', type=str, default='INFO', help='logging level',
					choices=['CRITICAL', 'FATAL', 'ERROR', 'WARNING', 'WARN', 'INFO', 'DEBUG'])
parser.add_argument('-lp', '--log-path', type=str, help='path for file log')

ex = Experiment('multilingual_aspect')

# Observers selectable by name through --sacred-observers. All of them are created eagerly
# at import time; the file-based ones store their data below <package>/resources/sacred.
_SACRED_RESOURCE_PATH = os.path.join(PACKAGE_PATH, 'resources', 'sacred')
OBSERVERS = {
	'mongo': MongoObserver.create(db_name='MultilingualAspect'),
	'sqlite': SqlObserver.create('sqlite:///{}'.format(os.path.join(_SACRED_RESOURCE_PATH, 'MultilingualAspect.sqlite'))),
	'tinydb': TinyDbObserver.create(os.path.join(_SACRED_RESOURCE_PATH, 'MultilingualAspect.tinydb')),
	'file': FileStorageObserver.create(os.path.join(_SACRED_RESOURCE_PATH, 'MultilingualAspect.fs'))
}


@ex.config
def config():
	"""Placeholder defaults for every config entry that `run` receives; the script entry point overrides them via `config_updates`."""
	config_name = ''  # only written to the log by `run`
	exp_name = ''  # becomes part of the confusion matrix file name
	dataset_file = ''  # path handed to the dataset loader
	lowercase = False  # forwarded to CountVectorizer(lowercase=...)
	error_analysis_output_path = ''  # directory the confusion matrix file is written to
	num_folds = -1  # n_splits of the StratifiedKFold
	random_seed = -1  # random_state of the StratifiedKFold
	dataset_load_function = ''  # resolved to a callable with wittgenstein.create_function
	data_index = -1  # forwarded to the dataset loader (presumably the text column - TODO confirm)
	target_index = -1  # forwarded to the dataset loader (presumably the label column - TODO confirm)
	skip_header = False  # forwarded to the dataset loader
	experiment_id = -1  # becomes part of the confusion matrix file name


@ex.main
def run(config_name, dataset_file, lowercase, error_analysis_output_path, exp_name, num_folds, random_seed,
		dataset_load_function, data_index, target_index, skip_header, experiment_id):
	"""Cross-validate a bag-of-words logistic regression classifier on one dataset.

	The dataset is loaded with the function named by `dataset_load_function`, encoded with a
	whitespace-tokenising `CountVectorizer` and a `LabelEncoder`, and evaluated with a shuffled
	`StratifiedKFold`. The confusion matrix over the predictions of all folds is stored as
	`confusion_matrix_<exp_name>-<experiment_id>.dill` in `error_analysis_output_path`.

	:param config_name: name of the configuration (only used for logging)
	:param dataset_file: path passed to the dataset loader
	:param lowercase: whether the vectoriser lowercases the input
	:param error_analysis_output_path: directory the confusion matrix is written to
	:param exp_name: experiment name, part of the output file name
	:param num_folds: number of cross-validation folds
	:param random_seed: seed for shuffling the folds
	:param dataset_load_function: specification resolved by `wittgenstein.create_function`
	:param data_index: forwarded to the dataset loader
	:param target_index: forwarded to the dataset loader
	:param skip_header: forwarded to the dataset loader
	:param experiment_id: id of the experiment, part of the output file name
	:return: accuracy averaged over all folds
	"""
	logging.info(f'Running experiments with config_name={config_name}...')

	logging.info(f'Preparing data for experiment from {dataset_file}...')
	dataset_loader = wittgenstein.create_function(dataset_load_function)
	data, targets = dataset_loader(dataset_file, data_index, target_index, skip_header)
	logging.info(f'Loaded data with len={len(data)} and len(targets)={len(targets)} and targets={set(targets)}!')

	le = LabelEncoder()
	vec = CountVectorizer(lowercase=lowercase, tokenizer=lambda x: x.split())

	logging.info('Encoding data...')
	y = le.fit_transform(targets)
	X = vec.fit_transform(data)
	logging.info(f'Encoded data with shape X.shape={X.shape} and labels with shape={y.shape}!')

	# The original labels behind the encoded classes 0 and 1 never change between folds, so
	# look them up once instead of on every log statement.
	label_0 = le.inverse_transform(np.array([0]))
	label_1 = le.inverse_transform(np.array([1]))

	logging.info(f'Starting {num_folds}-cross validation...')
	skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=random_seed)

	accs = []
	f1y1s = []
	f1y0s = []
	pred_all = []
	gold_all = []
	for idx, (train_idx, test_idx) in enumerate(skf.split(X, y), 1):
		logging.debug(f'K-Fold split {idx}/{num_folds}...')
		X_train, X_test = X[train_idx], X[test_idx]
		y_train, y_test = y[train_idx], y[test_idx]

		logging.debug('\tFitting model...')
		lr = LogisticRegression(solver='lbfgs', multi_class='multinomial')
		lr.fit(X_train, y_train)
		y_pred = lr.predict(X_test)

		# Collect the predictions of every fold for the overall confusion matrix.
		pred_all.extend(y_pred.tolist())
		gold_all.extend(y_test.tolist())

		# NOTE: the per-class F1 scores use f1_score's default binary averaging, i.e. the
		# evaluation assumes a two-class problem.
		acc = accuracy_score(y_test, y_pred)
		f1y1 = f1_score(y_test, y_pred, pos_label=1)
		f1y0 = f1_score(y_test, y_pred, pos_label=0)
		accs.append(acc)
		f1y1s.append(f1y1)
		f1y0s.append(f1y0)
		logging.debug(f'\tAccuracy={acc}; F1["{label_0}"]={f1y0}; '
					  f'F1["{label_1}"]={f1y1}')
		logging.debug('\t---------------------------------------------')

	logging.info(f'Average Accuracy: {np.average(accs)} (+/- {np.std(accs)}); '
				 f'Average F1["{label_0}"]={np.average(f1y0s)} (+/- {np.std(f1y0s)}); '
				 f'Average F1["{label_1}"]={np.average(f1y1s)} (+/- {np.std(f1y1s)})')

	logging.info('Running analysis...')
	# The matrix is built from the pooled predictions of all folds. (A previous, never stored
	# analysis dict combined the last fold's predictions with the full data list and was removed.)
	cm = confusion_matrix(np.array(gold_all), np.array(pred_all))

	out_file_suffix = f'{exp_name}-{experiment_id}'
	cm_wrapped = {
		'labels': le.classes_,
		'data': cm
	}
	io.save_structured_resource(cm_wrapped, os.path.join(error_analysis_output_path, f'confusion_matrix_{out_file_suffix}.dill'))
	logging.info('Finished!')

	return np.average(accs)


if __name__ == '__main__':
	args = parser.parse_args()

	# Log to stdout and, if requested, additionally to a file in a timestamped sub-folder.
	log_formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)s - %(message)s', datefmt='[%d/%m/%Y %H:%M:%S %p]')
	root_logger = logging.getLogger()
	root_logger.setLevel(getattr(logging, args.log_level))

	console_handler = logging.StreamHandler(sys.stdout)
	console_handler.setFormatter(log_formatter)
	root_logger.addHandler(console_handler)

	if args.log_path is not None:
		log_path = os.path.join(args.log_path, path_utils.timestamped_foldername())
		os.makedirs(log_path, exist_ok=True)

		file_handler = logging.FileHandler(os.path.join(log_path, f'experiment_lr_cross_domain_{args.config_name}.log'))
		file_handler.setFormatter(log_formatter)
		root_logger.addHandler(file_handler)

	# Both paths are optional on the command line. Without them, fall back to paths relative
	# to the working directory instead of failing in os.path.join(None, ...) - for the output
	# path that failure would only surface after the whole cross-validation has been run.
	input_path = args.input_path or ''
	output_path = args.output_path or ''
	if output_path:
		os.makedirs(output_path, exist_ok=True)

	for obs in args.sacred_observers:
		if obs not in OBSERVERS:
			logging.warning(f'Unknown sacred observer "{obs}" (available: {sorted(OBSERVERS)}), skipping it.')
			continue
		ex.observers.append(OBSERVERS[obs])
	ex.logger = root_logger

	# Load experiment id file
	experiment_file = os.path.join(PACKAGE_PATH, 'resources', 'experiments', args.experiment_file)
	with open(experiment_file, 'r', newline='') as csv_file:
		# The experiment id is the 1-based row number in the file. Blank rows are dropped, but
		# still counted, so that the ids of the remaining rows stay stable.
		experiments = [(row_id, row) for row_id, row in enumerate(csv.reader(csv_file), 1) if row]

	if args.experiment_id > 0:
		# Keep the requested id attached to the selected row; re-enumerating the single
		# remaining row would label every selected experiment as id 1.
		experiments = [(row_id, row) for row_id, row in experiments if row_id == args.experiment_id]
		if not experiments:
			raise IndexError(f'No experiment with id {args.experiment_id} in {experiment_file}')

	for experiment_id, row in experiments:
		(dataset_file, lowercase, num_folds, random_seed, dataset_load_function, data_index, target_index,
		 skip_header) = row
		config_dict = {
			'dataset_file': os.path.join(input_path, dataset_file),
			'lowercase': lowercase.lower() == 'true',
			'config_name': args.config_name,
			'exp_name': args.config_name,
			'error_analysis_output_path': output_path,
			'random_seed': int(random_seed),
			'num_folds': int(num_folds),
			'dataset_load_function': dataset_load_function,
			'data_index': int(data_index),
			'target_index': int(target_index),
			'skip_header': skip_header.lower() == 'true',
			'experiment_id': experiment_id
		}
		ex.run(config_updates=config_dict)
		logging.info('--------------------------------------------------')
