__author__ = 'thomas'
from argparse import ArgumentParser
import collections
import csv
import logging
import os
import sys

from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver
from sacred.observers import SqlObserver
from sacred.observers import TinyDbObserver
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
import numpy as np

from telicity.util import io
from telicity.util import path_utils
from telicity.util import wittgenstein
from telicity.models.embedding import FastTextEmbedding


# Directory containing this module; bundled resources are resolved relative to it.
PACKAGE_PATH = os.path.dirname(__file__)
# Two directory levels above PACKAGE_PATH.
PROJECT_PATH = os.path.dirname(os.path.dirname(PACKAGE_PATH))

# Command line interface, declared as (flags, add_argument keyword arguments) pairs in the
# order in which they are registered on the parser.
_CLI_ARGUMENTS = (
	(('-i', '--input-file'), {'type': str, 'help': 'input file'}),
	(('-ip', '--input-path'), {'type': str, 'help': 'path to input file'}),
	(('-i2', '--input-file-2'), {'type': str, 'help': 'input file 2'}),
	(('-ip2', '--input-path-2'), {'type': str, 'help': 'path to input file 2'}),
	(('-op', '--output-path'), {'type': str, 'help': 'path to output file'}),
	(('-cn', '--config-name'), {'type': str, 'required': True, 'help': 'name of config'}),
	(('-ef', '--experiment-file'), {'type': str, 'required': True}),
	(('-s', '--store-scores'), {'action': 'store_true', 'help': 'Store individual scores in file.'}),
	(('-sp', '--score-path'), {'type': str, 'help': 'path to store the score file.'}),
	(('-eid', '--experiment-id'), {'type': int, 'default': -1, 'help': 'experiment id to use.'}),
	# The values given here are used as keys into the OBSERVERS mapping by the __main__ block.
	(('-obs', '--sacred-observers'), {
		'nargs': '+',
		'type': str,
		'default': ['file'],
		'help': 'mongo observers to add',
	}),
	(('-ll', '--log-level'), {
		'type': str,
		'default': 'INFO',
		'help': 'logging level',
		'choices': ['CRITICAL', 'FATAL', 'ERROR', 'WARNING', 'WARN', 'INFO', 'DEBUG'],
	}),
	(('-lp', '--log-path'), {'type': str, 'help': 'path for file log'}),
)

parser = ArgumentParser()
for _flags, _options in _CLI_ARGUMENTS:
	parser.add_argument(*_flags, **_options)
del _flags, _options

# The Sacred experiment that the config scope and the main function below are attached to.
ex = Experiment('multilingual_aspect')

# Directory under the package resources that holds the sqlite and file-storage observer data.
_SACRED_RESOURCE_PATH = os.path.join(PACKAGE_PATH, 'resources', 'sacred')

# Observer instances by name; the __main__ block looks these up with the `-obs` CLI values.
OBSERVERS = dict(
	mongo=MongoObserver.create(db_name='MultilingualAspect'),
	sqlite=SqlObserver.create(f"sqlite:///{os.path.join(_SACRED_RESOURCE_PATH, 'MultilingualAspect.sqlite')}"),
	file=FileStorageObserver.create(os.path.join(_SACRED_RESOURCE_PATH, 'MultilingualAspect.fs')),
)


@ex.config
def config():
	"""Declare the experiment's configuration entries with placeholder defaults.

	The ``__main__`` block of this file overwrites every entry below through
	``config_updates`` before running the experiment; the parameter names of
	``run`` mirror these entries.
	"""
	config_name = ''  # included in run()'s start-up log message
	exp_name = ''  # first part of the suffix of the analysis output file names
	vector_file = ''  # passed to FastTextEmbedding as embedding_path
	dataset_file = ''  # first argument handed to the dataset load function
	lowercase = False  # if True, tokens are lowercased before the embedding lookup
	error_analysis_output_path = ''  # directory the analysis resources are written to
	num_folds = -1  # n_splits of the StratifiedKFold used in cross validation
	random_seed = -1  # random_state of the StratifiedKFold used in cross validation
	dataset_load_function = ''  # resolved to a callable via wittgenstein.create_function
	data_index = -1  # forwarded to the dataset load function
	target_index = -1  # forwarded to the dataset load function
	skip_header = False  # forwarded to the dataset load function
	drop_target = None  # forwarded to the dataset load function
	experiment_id = -1  # second part of the suffix of the analysis output file names
	evaluation_mode = ''  # 'cross_validation' selects k-fold CV; any other value a train/test split


def _encode_phrases(emb, phrases, lowercase, oov):
	"""Encode each phrase as the mean of its token embeddings.

	:param emb: embedding object offering ``dimensionality()`` and ``get(token, default)``
	:param phrases: iterable of token sequences
	:param lowercase: if True, every token is lowercased before the lookup
	:param oov: default handed to ``emb.get`` (serves as the out-of-vocabulary vector)
	:return: ``np.ndarray`` with one row per phrase; an empty phrase yields an all-zero row
	"""
	dim = emb.dimensionality()
	encoded = []
	for phrase in phrases:
		denom_phrase = len(phrase)
		x = np.zeros((dim,))
		for w in phrase:
			token = w.lower() if lowercase else w
			x += (emb.get(token, oov) / denom_phrase)
		encoded.append(x)
	return np.array(encoded)


def _majority_class_accuracy(y_reference, y_eval):
	"""Accuracy on ``y_eval`` when always predicting the most frequent class of ``y_reference``."""
	y_majority_class = np.full(y_eval.shape, np.argmax(np.bincount(y_reference)))
	return accuracy_score(y_eval, y_majority_class)


@ex.main
def run(config_name, vector_file, dataset_file, lowercase, error_analysis_output_path, exp_name, num_folds, random_seed,
		dataset_load_function, data_index, target_index, skip_header, experiment_id, evaluation_mode, drop_target):
	"""Train and evaluate a logistic regression classifier on averaged fastText phrase embeddings.

	With ``evaluation_mode == 'cross_validation'`` the loaded dataset is evaluated with stratified
	k-fold cross validation. With any other value the loader is expected to return
	``(train, test)`` pairs for both data and targets, and a single train/test evaluation is run.

	Afterwards per-word accuracies and counts, the gold/predicted labels together with the
	evaluated phrases, and the confusion matrix are stored under ``error_analysis_output_path``.

	:param config_name: name of the configuration (used for logging only)
	:param vector_file: path handed to ``FastTextEmbedding`` as ``embedding_path``
	:param dataset_file: path handed to the dataset load function
	:param lowercase: lowercase tokens before the embedding lookup
	:param error_analysis_output_path: directory the analysis resources are written to
	:param exp_name: experiment name, first part of the output file suffix
	:param num_folds: number of cross validation folds
	:param random_seed: random state for the fold shuffling
	:param dataset_load_function: specification resolved via ``wittgenstein.create_function``
	:param data_index: forwarded to the dataset load function
	:param target_index: forwarded to the dataset load function
	:param skip_header: forwarded to the dataset load function
	:param experiment_id: id of the experiment, second part of the output file suffix
	:param evaluation_mode: ``'cross_validation'`` or anything else for a train/test split
	:param drop_target: forwarded to the dataset load function
	:return: mean accuracy over all folds (the single accuracy in train/test mode)
	"""
	logging.info(f'Running experiments with config_name={config_name} and with vector_file={vector_file} '
				 f'and evaluation_mode={evaluation_mode} and drop_target={drop_target}...')
	emb = FastTextEmbedding(embedding_path=vector_file)
	logging.info('fastText embeddings loaded!')

	logging.info(f'Preparing data for experiment from {dataset_file}...')
	dataset_loader = wittgenstein.create_function(dataset_load_function)
	logging.debug(f'Loading data with data_index={data_index}; target_index={target_index}; skip_header={skip_header}; '
				  f'drop_target={drop_target}')
	data, targets = dataset_loader(dataset_file, data_index, target_index, skip_header, drop_target)
	logging.debug(f'Loaded data with len={len(data)} and len(targets)={len(targets)}!')

	le = LabelEncoder()
	oov = np.zeros((emb.dimensionality(),))

	# Filled by both evaluation modes; `test_phrases` is kept aligned with `pred_all`/`gold_all`.
	accs = []
	pred_all = []
	gold_all = []
	test_phrases = []

	if evaluation_mode == 'cross_validation':
		logging.info(f'Starting {num_folds}-cross validation...')
		skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=random_seed)

		logging.info('Encoding labels...')
		y = le.fit_transform(targets)
		logging.info(f'{len(le.classes_)} Labels ({le.classes_}) encoded!')

		# Majority class "predictions"
		acc_majority_class = _majority_class_accuracy(y, y)

		logging.info('Encoding data...')
		X = _encode_phrases(emb, data, lowercase, oov)
		logging.info(f'Encoded data with shape X.shape={X.shape} and labels with shape={y.shape}!')

		for idx, (train_idx, test_idx) in enumerate(skf.split(X, y), 1):
			logging.debug(f'K-Fold split {idx}/{num_folds}...')
			X_train, X_test = X[train_idx], X[test_idx]
			y_train, y_test = y[train_idx], y[test_idx]

			logging.debug('\tFitting model...')
			lr = LogisticRegression(solver='lbfgs', multi_class='multinomial')
			lr.fit(X_train, y_train)
			y_pred = lr.predict(X_test)

			pred_all.extend(y_pred.tolist())
			gold_all.extend(y_test.tolist())
			# Predictions are collected in fold order, so the phrases have to be collected in
			# the same order for the per-word analysis below.
			test_phrases.extend(data[i] for i in test_idx)
			acc = accuracy_score(y_test, y_pred)
			accs.append(acc)
			logging.debug(f'\tAccuracy={acc}')
			logging.debug('\t---------------------------------------------')

		# One F1 score per encoded class; for two classes this is identical to calling
		# f1_score with pos_label=0 and pos_label=1, but it also works for more classes.
		f1_per_class = f1_score(np.array(gold_all), np.array(pred_all),
								labels=np.arange(len(le.classes_)), average=None)
		f1_report = '; '.join(
			f'Total F1["{le.inverse_transform(np.array([class_idx]))}"]={f1_value}'
			for class_idx, f1_value in enumerate(f1_per_class)
		)
		logging.info(f'Average Accuracy: {np.average(accs)} +/- {np.std(accs)}; '
					 f'Majority class baseline Accuracy: {acc_majority_class}; '
					 f'{f1_report}')
	else:
		logging.info('Starting train/test split evaluation...')
		train_phrases, test_phrases = data[0], list(data[1])
		train_targets, test_targets = targets[0], targets[1]

		logging.info('Encoding labels...')
		# Fit on the labels of both splits so that neither transform can hit an unseen label.
		le.fit(list(train_targets) + list(test_targets))
		y_train = le.transform(train_targets)
		y_test = le.transform(test_targets)
		logging.info(f'{len(le.classes_)} Labels ({le.classes_}) encoded!')

		# Majority class of the training split, scored on the test split.
		acc_majority_class = _majority_class_accuracy(y_train, y_test)

		logging.info('Encoding training data...')
		X_train = _encode_phrases(emb, train_phrases, lowercase, oov)

		logging.info('Encoding testing data...')
		X_test = _encode_phrases(emb, test_phrases, lowercase, oov)
		logging.info(f'Encoded training data with shape X_train.shape={X_train.shape} and'
					 f' testing data with shape X_test={X_test.shape}!')

		lr = LogisticRegression(solver='lbfgs', multi_class='multinomial')
		lr.fit(X_train, y_train)
		y_pred = lr.predict(X_test)

		acc = accuracy_score(y_test, y_pred)
		accs.append(acc)

		gold_all = y_test.tolist()
		pred_all = y_pred.tolist()

		logging.info(f'Accuracy: {acc}; Majority class baseline Accuracy: {acc_majority_class}')

	logging.info('Running analysis...')
	# Fix the label order explicitly so that the matrix rows/columns always line up with
	# le.classes_, even if a class never occurs among the evaluated items.
	cm = confusion_matrix(np.array(gold_all), np.array(pred_all), labels=np.arange(len(le.classes_)))

	correct_per_word = collections.defaultdict(int)
	totals_per_word = collections.defaultdict(int)
	for yi_pred, yi_test, phrase in zip(pred_all, gold_all, test_phrases):
		is_correct = yi_pred == yi_test
		for w in phrase:
			totals_per_word[w] += 1
			if is_correct:
				correct_per_word[w] += 1

	cross_val_analysis = {
		'y_test': le.inverse_transform(np.array(gold_all)).tolist(),
		'y_pred': le.inverse_transform(np.array(pred_all)).tolist(),
		'test_phrases': test_phrases
	}

	accuracy_per_context_word = {}
	counts_per_context_word = {}
	for w, total in totals_per_word.items():
		correct = correct_per_word.get(w, 0)
		accuracy_per_context_word[w] = correct / total
		counts_per_context_word[w] = (correct, total)

	logging.info(f'Storing analysis resources at {error_analysis_output_path}...')
	out_file_suffix = f'{exp_name}-{experiment_id}'
	io.save_structured_resource(accuracy_per_context_word, os.path.join(error_analysis_output_path,
																		f'accuracy_by_word_{out_file_suffix}.dill'))
	io.save_structured_resource(counts_per_context_word, os.path.join(error_analysis_output_path,
																	  f'counts_by_word_{out_file_suffix}.dill'))
	# NOTE: the 'cross_va_analysis' file name is kept as is, in case existing tooling reads it.
	io.save_structured_resource(cross_val_analysis, os.path.join(error_analysis_output_path,
																 f'cross_va_analysis_{out_file_suffix}.dill'))
	cm_wrapped = {
		'labels': le.classes_,
		'data': cm
	}
	io.save_structured_resource(cm_wrapped, os.path.join(error_analysis_output_path, f'confusion_matrix_{out_file_suffix}.dill'))
	logging.info('Finished!')

	return np.average(accs)


# Script entry point: set up logging and the Sacred observers, read the experiment definitions
# from the CSV file and run one Sacred experiment per row.
if __name__ == '__main__':
	args = parser.parse_args()

	# run() joins every analysis file name onto the output path; without it the experiment would
	# only fail with a TypeError after all the work has been done, so fail up front instead.
	if args.output_path is None:
		parser.error('-op/--output-path is required')

	unknown_observers = [obs for obs in args.sacred_observers if obs not in OBSERVERS]
	if unknown_observers:
		parser.error(f'unknown sacred observer(s) {unknown_observers}; choose from {sorted(OBSERVERS)}')

	log_formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)s - %(message)s', datefmt='[%d/%m/%Y %H:%M:%S %p]')
	root_logger = logging.getLogger()
	root_logger.setLevel(getattr(logging, args.log_level))

	console_handler = logging.StreamHandler(sys.stdout)
	console_handler.setFormatter(log_formatter)
	root_logger.addHandler(console_handler)

	if args.log_path is not None:
		timestamped_foldername = path_utils.timestamped_foldername()
		log_path = os.path.join(args.log_path, timestamped_foldername)
		os.makedirs(log_path, exist_ok=True)

		file_handler = logging.FileHandler(os.path.join(log_path, f'experiment_lr_{args.config_name}.log'))
		file_handler.setFormatter(log_formatter)
		root_logger.addHandler(file_handler)

	os.makedirs(args.output_path, exist_ok=True)

	for obs in args.sacred_observers:
		ex.observers.append(OBSERVERS[obs])
	ex.logger = root_logger

	# Load experiment id file; every row is paired with its 1-based position in the file, which
	# serves as the experiment id.
	experiment_file = os.path.join(PACKAGE_PATH, 'resources', 'experiments', args.experiment_file)
	with open(experiment_file, 'r', newline='') as csv_file:
		experiments = list(enumerate(csv.reader(csv_file), 1))

	if args.experiment_id > 0:
		if args.experiment_id > len(experiments):
			parser.error(f'--experiment-id {args.experiment_id} is out of range; '
						 f'{args.experiment_file} defines {len(experiments)} experiment(s)')
		# Keep the original id of the selected row, so the output file names carry the requested
		# experiment id rather than always 1.
		experiments = [experiments[args.experiment_id - 1]]

	for experiment_id, row in experiments:
		if not row:
			# Blank line in the CSV file (e.g. a trailing newline).
			continue

		(vector_file, dataset_file, lowercase, num_folds, random_seed, dataset_load_function, data_index,
			target_index, skip_header, drop_target, evaluation_mode) = row
		config_dict = {
			# Without an input path the file names from the CSV file are used as they are.
			'vector_file': os.path.join(args.input_path or '', vector_file),
			'dataset_file': os.path.join(args.input_path_2 or '', dataset_file),
			'lowercase': lowercase.lower() == 'true',
			'config_name': args.config_name,
			'exp_name': args.config_name,
			'error_analysis_output_path': args.output_path,
			'random_seed': int(random_seed),
			'num_folds': int(num_folds),
			'dataset_load_function': dataset_load_function,
			'data_index': int(data_index),
			'target_index': int(target_index),
			'skip_header': skip_header.lower() == 'true',
			'drop_target': drop_target if drop_target != 'None' else None,
			'experiment_id': experiment_id,
			'evaluation_mode': evaluation_mode
		}
		ex.run(config_updates=config_dict)
		logging.info('--------------------------------------------------')
