__author__ = 'thomas'
from argparse import ArgumentParser
from itertools import zip_longest
import collections
import csv
import logging
import os
import sys

from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.observers import MongoObserver
from sacred.observers import SqlObserver
from sacred.observers import TelegramObserver
from sacred.observers import TinyDbObserver
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.model_selection import KFold
from sklearn.preprocessing import LabelEncoder
from wolkenatlas.embedding import Embedding
import numpy as np

from semantx.util import io
from semantx.util import path_utils


# Directory of this module; observer storage and the experiment CSV files are resolved relative to it.
PACKAGE_PATH = os.path.dirname(__file__)
# Three directory levels above this module.
PROJECT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

# Command line interface. Only a subset of these options is read by the __main__ section of this file.
parser = ArgumentParser()
parser.add_argument('-i', '--input-file', type=str, help='input file')
parser.add_argument('-ip', '--input-path', type=str, required=True, help='path to input file')
parser.add_argument('-i2', '--input-file-2', type=str, help='input file 2')
parser.add_argument('-ip2', '--input-path-2', type=str, required=True, help='path to input file 2')
parser.add_argument('-i3', '--input-file-3', type=str, help='input file 3')
parser.add_argument('-ip3', '--input-path-3', type=str, help='path to input file 3')
parser.add_argument('-op', '--output-path', type=str, help='path to output file')
parser.add_argument('-op2', '--output-path-2', type=str, help='path to output file 2')
parser.add_argument('-cn', '--config-name', type=str, required=True, help='name of config')
parser.add_argument('-ef', '--experiment-file', type=str, required=True)
parser.add_argument('-s', '--store-scores', action='store_true', help='Store individual scores in file.')
parser.add_argument('-sp', '--score-path', type=str, help='path to store the score file.')
parser.add_argument('-eid', '--experiment-id', type=int, default=-1, help='experiment id to use.')
# The default must only name observers that are registered in OBSERVERS below; 'telegram' is currently
# disabled there, so listing it in the default made every default invocation fail with a KeyError.
parser.add_argument('-obs', '--sacred-observers', nargs='+', type=str, default=['mongo', 'sqlite'],
					help='sacred observers to add')
parser.add_argument('-ll', '--log-level', type=int, default=logging.INFO, help='logging level')
parser.add_argument('-bl', '--blacklist', type=str, nargs='+', default=['punct'])
parser.add_argument('-wl', '--whitelist', type=str, nargs='+', default=[])

ex = Experiment('verb_telicity')
# Observers selectable via --sacred-observers, keyed by their command line name.
# The commented-out entries are disabled observers; re-enable them here before requesting them on the CLI.
OBSERVERS = {
	'mongo': MongoObserver.create(db_name='VerbTelicity'),
	#'telegram': TelegramObserver.from_config(os.path.join(PROJECT_PATH, 'resources/sacred/telegram.json')),
	'sqlite': SqlObserver.create('sqlite:///{}'.format(os.path.join(PACKAGE_PATH, 'resources', 'sacred', 'VerbTelicity.sqlite'))),
	#'tinydb': TinyDbObserver.create(os.path.join(PACKAGE_PATH, 'resources', 'sacred', 'VerbTelicity.tinydb')),
	'file': FileStorageObserver.create(os.path.join(PACKAGE_PATH, 'resources', 'sacred', 'VerbTelicity.fs'))
}


@ex.config
def config():
	"""Default values for the experiment configuration.
	Each local variable is a config entry that is passed to `run` by name; the
	__main__ section overrides all of them via `config_updates`."""
	config_name = ''  # name of the configuration, used in log messages
	exp_name = ''  # experiment name, used in log messages
	vector_file = ''  # file handed to Embedding(model_file=...)
	context_type = ''  # key of the context list within each context entry, or 'none' for verb-only input
	lowercase = False  # lowercase the verb and context tokens before the embedding lookup
	context_file = ''  # file loaded with io.load_structured_resource
	error_analysis_output_path = ''  # directory for the error analysis .dill files
	blacklist = []  # first tuple elements (ctx[0]) of context items to exclude at test time
	whitelist = []  # if non-empty, only context items whose first tuple element is listed are used at test time
	lemma = False  # use index 0 instead of index 1 of an entry's 'lemma_word_map'
	statistical_significance_output_path = ''  # directory for the stored predictions and gold labels
	target_labels = []  # labels to keep; entries with any other label are skipped
	num_splits = 0  # number of folds passed to KFold(n_splits=...)


# Separator between a token and its PoS tag inside a context string.
_POS_SEPARATOR = '_###POS###_'
# Context words whose occurrences in correctly classified examples are recorded in `sent_dict`.
_TRACKED_CONTEXT_WORDS = frozenset([
	'you', 'me', 'that', 'she', 'her', 'into', 'would', 'back', 'had', 'his', 'on', 'i', 'he', 'not', 'we'
])


def _split_token_tag(token_tag):
	"""Split a 'token_###POS###_tag' string into (token, tag); tag is None if there is no separator."""
	if (_POS_SEPARATOR in token_tag):
		token, tag = token_tag.split(_POS_SEPARATOR)
		return token, tag
	return token_tag, None


@ex.main
def run(config_name, vector_file, context_type, context_file, lowercase, error_analysis_output_path, num_splits,
		exp_name, blacklist, whitelist, lemma, statistical_significance_output_path, target_labels):
	"""Run a k-fold cross-validated logistic regression over summed verb + context embeddings.

	Every example is represented by the embedding of its verb plus the embeddings of its context
	tokens (unless `context_type` is 'none'). After all folds, per-item counts for correctly and
	incorrectly classified test examples, the predictions and the gold labels are written to disk.

	:param config_name: name of the configuration (only logged)
	:param vector_file: file passed to `Embedding(model_file=...)`
	:param context_type: key of the context list inside each context entry, or 'none' for verb-only input
	:param context_file: file loaded with `io.load_structured_resource`; a mapping from example id to an
		entry with at least 'label', 'lemma_word_map', 'target_word_pos' and the `context_type` key
	:param lowercase: lowercase the verb and the context tokens before the embedding lookup
	:param error_analysis_output_path: directory the error analysis resources are stored in
	:param num_splits: number of cross-validation folds
	:param exp_name: experiment name (only logged)
	:param blacklist: first tuple elements of context items to exclude at test time
	:param whitelist: if non-empty, only context items whose first tuple element is listed are used at test time
	:param lemma: use index 0 instead of index 1 of an entry's 'lemma_word_map'
	:param statistical_significance_output_path: directory the predictions and gold labels are stored in
	:param target_labels: labels to keep; examples with any other label are skipped
	:return: dict with (mean, std) tuples for accuracy and the two F1 scores
	"""
	# Fixed: the arguments were passed in the wrong order, so vector_file and exp_name were swapped in the message.
	logging.info('Running experiments with config_name={} (and exp_name={}) with vector_file={}...'.format(config_name, exp_name, vector_file))
	emb = Embedding(model_file=vector_file)
	logging.info('VSM loaded!')

	logging.info('Loading context file from {}...'.format(context_file))
	context = io.load_structured_resource(context_file)
	logging.info('Context file loaded!')

	le = LabelEncoder()
	le.fit(np.array(target_labels))
	# Kept as 1-element arrays because their string representation is part of the log messages and result keys.
	label_0 = le.inverse_transform(np.array([0]))
	label_1 = le.inverse_transform(np.array([1]))

	accuracies = []
	f1s_0 = []
	f1s_1 = []
	correct_data = collections.defaultdict(int)
	correct_deps = collections.defaultdict(int)
	correct_pos = collections.defaultdict(int)
	correct_pos_per_class = collections.defaultdict(lambda: collections.defaultdict(int))

	incorrect_data = collections.defaultdict(int)
	incorrect_deps = collections.defaultdict(int)
	incorrect_pos = collections.defaultdict(int)
	incorrect_pos_per_class = collections.defaultdict(lambda: collections.defaultdict(int))
	sent_dict = collections.defaultdict(lambda: collections.defaultdict(list))

	# Per test example: the context words, dependency labels and PoS tags that went into its representation.
	data = collections.defaultdict(list)
	deps = collections.defaultdict(list)
	pos = collections.defaultdict(list)
	oovs = set()

	pred_all = {}
	gold_all = {}

	idx = np.array(list(context.keys()))
	kf = KFold(n_splits=num_splits, shuffle=True, random_state=29306)
	# NOTE(review): with average='weighted' sklearn presumably ignores pos_label, which would make
	# both F1 scores below identical in the multi-class case - confirm against the sklearn docs.
	f1_avg = 'weighted' if len(target_labels) > 2 else 'binary'

	for fold_idx, (train_idx, test_idx) in enumerate(kf.split(idx), 1):
		train_data = []
		train_labels = []
		test_data = []
		test_labels = []
		# Ids of the test examples that are actually used, aligned with test_data / test_labels.
		test_ids = []

		for ti in idx[train_idx]:
			label = context[ti]['label']
			if label not in target_labels:
				continue

			verb = context[ti]['lemma_word_map'][0 if lemma else 1]
			verb = verb.lower() if lowercase else verb
			x = emb[verb]

			if context_type != 'none':
				# NOTE(review): unlike the test loop below, no whitelist/blacklist filtering is applied
				# to the training contexts - confirm whether this asymmetry is intended.
				for ctx in context[ti][context_type]:
					token_tag = ctx[1] if isinstance(ctx, tuple) else ctx
					token, _ = _split_token_tag(token_tag)
					w = token.lower() if lowercase else token
					x = x + emb[w]

			train_data.append(x)
			train_labels.append(label)

		X_train = np.array(train_data)
		y_train = le.transform(np.array(train_labels))

		for ti in idx[test_idx]:
			label = context[ti]['label']
			if label not in target_labels:
				continue

			verb = context[ti]['lemma_word_map'][0 if lemma else 1]
			verb = verb.lower() if lowercase else verb
			x = emb[verb]

			if context_type != 'none':
				ctx_list = context[ti][context_type]
				logging.debug(f'Length of context: {len(ctx_list)}!')
				if len(ctx_list) <= 0:
					data[ti].append('__NO_CONTEXT__')
					if (context_type.startswith('dep')):
						deps[ti].append('__NO_CONTEXT__')
				else:
					for ctx in ctx_list:
						if (isinstance(ctx, tuple)):
							# Tuple contexts carry a label in ctx[0] that is checked against the white-/blacklist.
							token, tag = _split_token_tag(ctx[1])
							in_whitelist = len(whitelist) <= 0 or ctx[0] in whitelist
							not_in_blacklist = ctx[0] not in blacklist
							if (in_whitelist and not_in_blacklist):
								w = token if not lowercase else token.lower()  # word2vec doesn't care about the syntax
								if (w not in emb):
									oovs.add(w)
								x = x + emb[w]
								data[ti].append(w if w in emb else '_OOV_')
								deps[ti].append(ctx[0])
								if (tag is not None):
									pos[ti].append(tag)
							else:
								data[ti].append('__NO_CONTEXT__')
								deps[ti].append('__NO_CONTEXT__')
								if (tag is not None):
									pos[ti].append('__NO_CONTEXT__')
						else:
							token, tag = _split_token_tag(ctx)
							w = token if not lowercase else token.lower()
							if (w not in emb):
								oovs.add(w)
							x = x + emb[w]
							data[ti].append(w if w in emb else '_OOV_')
							if (tag is not None):
								pos[ti].append(tag)

			else:
				data[ti].append('__NO_CONTEXT__')
				pos[ti].append(context[ti]['target_word_pos'])

			test_data.append(x)
			test_labels.append(label)
			test_ids.append(ti)

		X_test = np.array(test_data)
		y_test = le.transform(np.array(test_labels))

		lr = LogisticRegression(solver='liblinear')
		lr.fit(X_train, y_train)
		y_pred = lr.predict(X_test)

		acc = accuracy_score(y_test, y_pred)
		accuracies.append(acc)

		f1_0 = f1_score(y_test, y_pred, pos_label=0, average=f1_avg)
		f1_1 = f1_score(y_test, y_pred, pos_label=1, average=f1_avg)
		f1s_0.append(f1_0)
		f1s_1.append(f1_1)
		logging.info(f'Performance for fold={fold_idx}: Accuracy={acc}; '
					 f'F1["{label_0}"]={f1_0}; '
					 f'F1["{label_1}"]={f1_1}')

		# Fixed: pair predictions with the ids of the examples that were actually classified. Zipping with
		# idx[test_idx] misaligned ids and predictions whenever an example was skipped due to its label.
		for example_idx, pred, gold in zip(test_ids, y_pred, y_test):
			gold_all[example_idx] = gold
			pred_all[example_idx] = pred

	logging.info(f'---- {context_type} Average Performance: Accuracy={np.average(accuracies)} (+/- {np.std(accuracies)}); '
				 f'F1["{label_0}"]={np.average(f1s_0)} (+/- {np.std(f1s_0)}); '
				 f'F1["{label_1}"]={np.average(f1s_1)} (+/- {np.std(f1s_1)}) ----')

	# Error analysis: count which context items occur in correctly / incorrectly classified examples
	for example_idx, g_i in gold_all.items():
		p_i = pred_all[example_idx]
		lbl = le.inverse_transform(np.array([g_i]))[0]
		logging.debug('\tstr(g_i)={}'.format(lbl))

		if (p_i == g_i):
			for item in data[example_idx]:
				correct_data[item] += 1
			for item in deps[example_idx]:
				correct_deps[item] += 1
			for item in pos[example_idx]:
				correct_pos[item] += 1
				correct_pos_per_class[lbl][item] += 1

			for lex, dep, tag in zip_longest(data[example_idx], deps[example_idx], pos[example_idx], fillvalue='_eps_'):
				if (lex in _TRACKED_CONTEXT_WORDS):
					# Fixed: store the id of this example; previously the whole array of all ids was appended.
					sent_dict[lex]['idx'].append(example_idx)
					sent_dict[lex]['dep'].append(dep)
					sent_dict[lex]['pos'].append(tag)
					sent_dict[lex]['label'].append(le.inverse_transform(np.array([g_i])))
		else:
			for item in data[example_idx]:
				incorrect_data[item] += 1
			for item in deps[example_idx]:
				incorrect_deps[item] += 1
			for item in pos[example_idx]:
				incorrect_pos[item] += 1
				incorrect_pos_per_class[lbl][item] += 1

	logging.info('Storing error analysis dicts at {}...'.format(error_analysis_output_path))
	error_analysis_resources = [
		('correct_data', correct_data),
		('correct_deps', correct_deps),
		('correct_pos', correct_pos),
		('incorrect_data', incorrect_data),
		('incorrect_deps', incorrect_deps),
		('incorrect_pos', incorrect_pos),
		('correct_pos_per_class', correct_pos_per_class),
		('incorrect_pos_per_class', incorrect_pos_per_class),
		('sent_dict', sent_dict),
		('oovs', oovs)
	]
	for resource_name, resource in error_analysis_resources:
		io.save_structured_resource(resource, os.path.join(error_analysis_output_path,
														   '{}_{}.dill'.format(resource_name, context_type)))
	logging.info('Error analysis dicts stored!')

	logging.info('Storing statistical significance resources at {}...'.format(statistical_significance_output_path))
	io.save_structured_resource(pred_all, os.path.join(statistical_significance_output_path, 'model_predictions_{}.dill'.format(context_type)))
	io.save_structured_resource(gold_all, os.path.join(statistical_significance_output_path, 'gold_labels_{}.dill'.format(context_type)))
	logging.info('Statistical signficance resources stored!')

	results = {
		'Accuracy': (np.average(accuracies), np.std(accuracies)),
		f'F1-{label_0}': (np.average(f1s_0), np.std(f1s_0)),
		f'F1-{label_1}': (np.average(f1s_1), np.std(f1s_1))
	}
	return results


# Script entry point: set up logging and observers, then run one Sacred experiment per row of the experiment file.
if (__name__ == '__main__'):
	args = parser.parse_args()

	# Both output paths are joined with file names at the very end of `run`; without them the experiment
	# would only fail after the complete cross-validation has finished, so fail right away instead.
	if (args.output_path is None or args.output_path_2 is None):
		parser.error('both --output-path and --output-path-2 are required to store the results')

	timestamped_foldername = path_utils.timestamped_foldername()
	log_path = os.path.join(path_utils.get_log_path(), timestamped_foldername)
	os.makedirs(log_path, exist_ok=True)

	log_formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)s - %(message)s', datefmt='[%d/%m/%Y %H:%M:%S %p]')
	root_logger = logging.getLogger()
	root_logger.setLevel(args.log_level)

	file_handler = logging.FileHandler(os.path.join(log_path, 'experiment_telicity_lr_{}.log'.format(args.config_name)))
	file_handler.setFormatter(log_formatter)
	root_logger.addHandler(file_handler)

	console_handler = logging.StreamHandler(sys.stdout)
	console_handler.setFormatter(log_formatter)
	root_logger.addHandler(console_handler)

	os.makedirs(args.output_path, exist_ok=True)
	os.makedirs(args.output_path_2, exist_ok=True)

	for obs in args.sacred_observers:
		# Observers can be disabled in OBSERVERS; skip unknown names with a warning instead of a bare KeyError.
		if (obs not in OBSERVERS):
			logging.warning('Skipping unknown sacred observer "{}" (available: {})'.format(obs, sorted(OBSERVERS)))
			continue
		ex.observers.append(OBSERVERS[obs])
	ex.logger = root_logger

	# Load experiment id file (newline='' as recommended for files handed to the csv module)
	with open(os.path.join(PACKAGE_PATH, 'resources', 'experiments', args.experiment_file), 'r', newline='') as csv_file:
		experiments = list(csv.reader(csv_file))

	# A positive experiment id selects a single (1-based) row, anything else runs all rows.
	if (args.experiment_id > 0):
		if (args.experiment_id > len(experiments)):
			parser.error('--experiment-id {} is out of range, {} contains {} rows'.format(
				args.experiment_id, args.experiment_file, len(experiments)))
		experiments = [experiments[args.experiment_id - 1]]

	# Blank lines are read as empty rows and would break the unpacking below; they are dropped only
	# after the selection above so that the row numbering of --experiment-id is unaffected.
	experiments = [row for row in experiments if row]

	for input_file, context_file, context_type, lowercase, whitelist, target_labels, lemma, num_splits in experiments:
		logging.info('Running with context_type={}, lowercase={} and target_labels={}...'.format(
			context_type, lowercase, target_labels)
		)
		# All CSV fields are strings: booleans are encoded as 'True', lists are '-' separated and a
		# whitelist of 'none' falls back to the whitelist given on the command line.
		config_dict = {
			'vector_file': os.path.join(args.input_path, input_file),
			'context_type': context_type,
			'lowercase': lowercase=='True',
			'context_file': os.path.join(args.input_path_2, context_file),
			'config_name': args.config_name,
			'exp_name': args.config_name,
			'error_analysis_output_path': args.output_path,
			'blacklist': args.blacklist,
			'whitelist': args.whitelist if whitelist == 'none' else whitelist.split('-'),
			'statistical_significance_output_path': args.output_path_2,
			'lemma': lemma=='True',
			'target_labels': target_labels.split('-'),
			'num_splits': int(num_splits)
		}
		ex.run(config_updates=config_dict)
		logging.info('--------------------------------------------------')
