__author__ = 'thomas'
from argparse import ArgumentParser
import csv
import logging
import os
import sys

from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver
from sacred.observers import SqlObserver
from sacred.observers import TinyDbObserver
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold
from sklearn.utils import shuffle
import numpy as np

from telicity.util import io
from telicity.util import path_utils
from telicity.util.data_util import VectorisedCrossLingualDataset


# Directory containing this module; resource files (sacred storage, experiment CSVs) are resolved against it.
PACKAGE_PATH = os.path.dirname(__file__)
# Two directory levels above this module's directory.
PROJECT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

ex = Experiment('multilingual_aspect')

# Sacred observers selectable with -obs/--sacred-observers, keyed by their command-line name.
# NOTE(review): all of them are instantiated at import time, including those not requested on the command line.
OBSERVERS = {
	'mongo': MongoObserver.create(db_name='MultilingualAspect'),
	'sqlite': SqlObserver.create('sqlite:///{}'.format(os.path.join(PACKAGE_PATH, 'resources', 'sacred', 'MultilingualAspect.sqlite'))),
	'file': FileStorageObserver.create(os.path.join(PACKAGE_PATH, 'resources', 'sacred', 'MultilingualAspect.fs'))
}

parser = ArgumentParser()
parser.add_argument('-i', '--input-file', type=str, help='input file')
parser.add_argument('-ip', '--input-path', type=str, help='path to input file')
parser.add_argument('-i2', '--input-file-2', type=str, help='input file 2')
parser.add_argument('-ip2', '--input-path-2', type=str, help='path to input file 2')
parser.add_argument('-op', '--output-path', type=str, help='path to output file')
parser.add_argument('-cn', '--config-name', type=str, required=True, help='name of config')
parser.add_argument('-ef', '--experiment-file', type=str, required=True)
parser.add_argument('-s', '--store-scores', action='store_true', help='Store individual scores in file.')
parser.add_argument('-sp', '--score-path', type=str, help='path to store the score file.')
parser.add_argument('-eid', '--experiment-id', type=int, default=-1, help='experiment id to use.')
# Validate observer names up front so a typo is reported by argparse instead of a KeyError on OBSERVERS later.
parser.add_argument('-obs', '--sacred-observers', nargs='+', type=str, default=['file'], choices=list(OBSERVERS),
					help='sacred observers to add')
parser.add_argument('-ll', '--log-level', type=str, default='INFO', help='logging level',
					choices=['CRITICAL', 'FATAL', 'ERROR', 'WARNING', 'WARN', 'INFO', 'DEBUG'])
parser.add_argument('-lp', '--log-path', type=str, help='path for file log')


# Default value of every config entry; the actual values are passed per experiment through
# `config_updates` when the experiment is run.
@ex.config
def config():
	# Naming: used for logging and for the suffix of output file names
	config_name = exp_name = ''
	experiment_id = -1
	# Input dataset, evaluation protocol and output location
	dataset_file = evaluation_mode = error_analysis_output_path = ''
	num_folds = random_seed = -1


@ex.main
def run(config_name, dataset_file, error_analysis_output_path, exp_name, num_folds, random_seed, experiment_id,
		evaluation_mode):
	logging.info(f'Running experiments with config_name={config_name} and with dataset_file={dataset_file} '
				 f'and evaluation_mode={evaluation_mode}...')
	dataset = VectorisedCrossLingualDataset.from_file(dataset_file)
	logging.info(f'Vectorised Dataset loaded!')

	out_file_suffix = f'{exp_name}-{experiment_id}'

	if evaluation_mode == 'cross_validation':
		logging.info(f'Starting {num_folds}-cross validation...')
		skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=random_seed)

		# Majority class "predictions"
		y = dataset.cross_validation_labels
		y_majority_class = np.full(y.shape, np.argmax(np.bincount(y)))
		acc_majority_class = accuracy_score(y, y_majority_class)

		_, fname = os.path.split(dataset_file)
		_, train_langs, test_langs = fname.split('.')[0].rsplit('_', 2)


		accs = []
		pred_all = []
		gold_all = []
		cross_val_analysis = []
		for idx, (train_idx, test_idx) in enumerate(skf.split(dataset.cross_validation_data, y), 1):
			logging.debug(f'K-Fold split {idx}/{num_folds}...')
			curr_analysis = {}
			train_phrases, test_phrases = dataset.cross_validation_data[train_idx], dataset.cross_validation_data[test_idx]

			train_data = []
			for phrase in np.concatenate((train_phrases, dataset.additional_training_data)):
				train_data.append(dataset.embeddings[phrase])
			X_train = np.array(train_data)
			y_train = np.concatenate((y[train_idx], dataset.additional_training_labels))
			X_train, y_train = shuffle(X_train, y_train)
			logging.debug(f'\tX_train.shape={X_train.shape}; y_train.shape={y_train.shape}!')

			test_data = []
			for phrase in test_phrases:
				test_data.append(dataset.embeddings[phrase])
			X_test = np.array(test_data)
			y_test = y[test_idx]
			logging.debug(f'\tX_test.shape={X_test.shape}; y_test.shape={y_test.shape}; class_balance={np.bincount(y_train)/len(y_train)}!')

			logging.debug('\tFitting model...')
			lr = LogisticRegression(solver='lbfgs', multi_class='multinomial')
			lr.fit(X_train, y_train)
			y_pred = lr.predict(X_test)

			curr_analysis['y_test'] = dataset.label_encoder.inverse_transform(y_test).tolist()
			curr_analysis['y_pred'] = dataset.label_encoder.inverse_transform(y_pred).tolist()
			curr_analysis['test_phrases'] = test_phrases.tolist()
			cross_val_analysis.append(curr_analysis)

			pred_all.extend(y_pred.tolist())
			gold_all.extend(y_test.tolist())
			acc = accuracy_score(y_test, y_pred)
			accs.append(acc)
			logging.debug(f'\tAccuracy={acc}')
			logging.debug('\t---------------------------------------------')

		f1_y1 = f1_score(np.array(gold_all), np.array(pred_all), pos_label=1)
		f1_y0 = f1_score(np.array(gold_all), np.array(pred_all), pos_label=0)
		logging.info(f'Average Accuracy[{train_langs}-{test_langs}]: {np.average(accs)} +/- {np.std(accs)}; '
					 f'Majority class baseline Accuracy: {acc_majority_class}; '
					 f'Total F1["{dataset.label_encoder.inverse_transform(np.array([0]))}"]={f1_y0}; '
					 f'Total F1["{dataset.label_encoder.inverse_transform(np.array([1]))}"]={f1_y1}')

		io.save_structured_resource(cross_val_analysis, os.path.join(error_analysis_output_path, f'cross_val_analysis_{out_file_suffix}.dill'))
	else:
		logging.info(f'Starting train/test split evaluation...')

		# Majority class "predictions"
		y_test = dataset.test_labels()
		y_majority_class = np.full(y_test.shape, np.argmax(np.bincount(y_test)))
		acc_majority_class = accuracy_score(y_test, y_majority_class)

		train_data = []
		for phrase in dataset.train_data():
			train_data.append(dataset.embeddings[phrase])
		X_train = np.array(train_data)
		y_train = dataset.train_labels()

		test_data = []
		for phrase in dataset.test_data():
			test_data.append(dataset.embeddings[phrase])
		X_test = np.array(test_data)
		logging.info(f'X_train.shape={X_train.shape}, y_train.shape={y_train.shape}, X_test.shape={X_test.shape} '
					 f'y_test.shape={y_test.shape}')

		logging.debug('\tFitting model...')
		lr = LogisticRegression(solver='lbfgs', multi_class='multinomial')
		lr.fit(X_train, y_train)
		y_pred = lr.predict(X_test)

		acc = accuracy_score(y_test, y_pred)

		f1_y1 = f1_score(y_test, y_pred, pos_label=1)
		f1_y0 = f1_score(y_test, y_pred, pos_label=0)
		logging.info(f'Accuracy: {acc} ; '
					 f'Majority class baseline Accuracy: {acc_majority_class}; '
					 f'Total F1["{dataset.label_encoder.inverse_transform(np.array([0]))}"]={f1_y0}; '
					 f'Total F1["{dataset.label_encoder.inverse_transform(np.array([1]))}"]={f1_y1}')

		# Slightly silly, but fewest changes necessary this way...
		gold_all = y_test.tolist()
		pred_all = y_pred.tolist()

		logging.info(f'Accuracy: {acc}; Majority class baseline Accuracy: {acc_majority_class}')
		accs = [acc]

	logging.info('Running analysis...')
	cm = confusion_matrix(np.array(gold_all), np.array(pred_all))

	'''
	correct_per_word = collections.defaultdict(int)
	totals_per_word = collections.defaultdict(int)
	for yi_pred, yi_test, phrase in zip(pred_all, gold_all, data):
		for w in phrase:
			totals_per_word[w] += 1

		if yi_pred == yi_test: # Correct prediction
			for w in phrase:
				correct_per_word[w] += 1

	accuracy_per_context_word = {}
	counts_per_context_word = {}
	for w in totals_per_word.keys():
		accuracy_per_context_word[w] = correct_per_word.get(w, 0) / totals_per_word[w]
		counts_per_context_word[w] = (correct_per_word.get(w, 0), totals_per_word[w])
	logging.info(f'Storing analysis resources at {error_analysis_output_path}...')
	io.save_structured_resource(accuracy_per_context_word, os.path.join(error_analysis_output_path,
																		f'accuracy_by_word_{out_file_suffix}.dill'))
	io.save_structured_resource(counts_per_context_word, os.path.join(error_analysis_output_path,
																	  f'counts_by_word_{out_file_suffix}.dill'))
																	  '''
	cm_wrapped = {
		'labels': dataset.label_encoder.classes_,
		'data': cm
	}
	io.save_structured_resource(cm_wrapped, os.path.join(error_analysis_output_path, f'confusion_matrix_{out_file_suffix}.dill'))
	logging.info('Finished!')

	return np.average(accs)


def configure_logging(log_level, log_path, config_name):
	"""Configure the root logger to write to stdout and, optionally, to a log file.

	:param log_level: name of a `logging` level attribute, e.g. 'INFO'.
	:param log_path: directory under which a timestamped sub-folder holding the log file is created;
		None disables file logging.
	:param config_name: used to name the log file (`experiment_lr_<config_name>.log`).
	:return: the configured root logger.
	"""
	log_formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)s - %(message)s', datefmt='[%d/%m/%Y %H:%M:%S %p]')
	root_logger = logging.getLogger()
	root_logger.setLevel(getattr(logging, log_level))

	console_handler = logging.StreamHandler(sys.stdout)
	console_handler.setFormatter(log_formatter)
	root_logger.addHandler(console_handler)

	if log_path is not None:
		timestamped_log_path = os.path.join(log_path, path_utils.timestamped_foldername())
		os.makedirs(timestamped_log_path, exist_ok=True)

		file_handler = logging.FileHandler(os.path.join(timestamped_log_path, f'experiment_lr_{config_name}.log'))
		file_handler.setFormatter(log_formatter)
		root_logger.addHandler(file_handler)

	return root_logger


def select_experiments(rows, experiment_id=-1):
	"""Pick the experiment definitions to run from the rows of an experiment CSV file.

	Rows that are empty or contain only whitespace cells (e.g. a trailing blank line) are ignored and do not
	count towards the numbering.

	:param rows: parsed CSV rows, each expected to be [dataset_file, num_folds, random_seed, evaluation_mode].
	:param experiment_id: 1-based index of a single experiment to run; values <= 0 select every experiment.
	:return: list of (experiment_id, row) pairs, where experiment_id is the row's 1-based position, so a
		single selected experiment keeps its own id (and therefore its own output file names).
	:raises ValueError: if experiment_id exceeds the number of experiments.
	"""
	experiments = [row for row in rows if any(cell.strip() for cell in row)]
	if experiment_id <= 0:
		return list(enumerate(experiments, 1))
	if experiment_id > len(experiments):
		raise ValueError(f'experiment id {experiment_id} out of range: only {len(experiments)} experiments defined')
	return [(experiment_id, experiments[experiment_id - 1])]


if __name__ == '__main__':
	args = parser.parse_args()
	root_logger = configure_logging(args.log_level, args.log_path, args.config_name)

	if args.output_path is not None:
		os.makedirs(args.output_path, exist_ok=True)

	for obs in args.sacred_observers:
		ex.observers.append(OBSERVERS[obs])
	ex.logger = root_logger

	# Load experiment id file: one experiment per row (dataset_file, num_folds, random_seed, evaluation_mode)
	experiment_file = os.path.join(PACKAGE_PATH, 'resources', 'experiments', args.experiment_file)
	with open(experiment_file, 'r', newline='') as csv_file:
		rows = list(csv.reader(csv_file))

	try:
		experiments = select_experiments(rows, args.experiment_id)
	except ValueError as err:
		parser.error(str(err))

	for experiment_id, (dataset_file, num_folds, random_seed, evaluation_mode) in experiments:
		config_dict = {
			# Without --input-path the dataset paths from the experiment file are used as given.
			'dataset_file': os.path.join(args.input_path or '', dataset_file),
			'config_name': args.config_name,
			'exp_name': args.config_name,
			# '' (the experiment config default) resolves output files against the current directory.
			'error_analysis_output_path': args.output_path or '',
			'random_seed': int(random_seed),
			'num_folds': int(num_folds),
			'experiment_id': experiment_id,
			'evaluation_mode': evaluation_mode
		}
		ex.run(config_updates=config_dict)
		logging.info('--------------------------------------------------')
