__author__ = 'thomas'
from argparse import ArgumentParser
from functools import reduce
import csv
import glob
import logging
import os
import random
import re
import string
import sys

from nltk.tokenize import word_tokenize
from sklearn.preprocessing import LabelEncoder
import numpy as np

from telicity.models.embedding import FastTextEmbedding
from telicity.util.data_util import VectorisedCrossLingualDataset
from telicity.util import io
from telicity.util import path_utils
from telicity.util import wittgenstein

# Directory that contains this module; experiment files are looked up relative to it.
PACKAGE_PATH = os.path.dirname(__file__)

# Command line interface shared by all preprocessing actions; `--action` selects which one is executed.
parser = ArgumentParser()
parser.add_argument('-i', '--input-file', type=str, help='input file')
parser.add_argument('-i2', '--input-file-2', type=str, help='input file 2')
parser.add_argument('-ip', '--input-path', type=str, help='input path')
parser.add_argument('-ip2', '--input-path-2', type=str, help='input path 2')
parser.add_argument('-op', '--output-path', type=str, help='output path')
parser.add_argument('-op2', '--output-path-2', type=str, help='output path 2')
parser.add_argument('-op3', '--output-path-3', type=str, help='output path 3')
parser.add_argument('-o', '--output-file', type=str, help='output file')
parser.add_argument('-o2', '--output-file-2', type=str, help='output file 2')
parser.add_argument('-a', '--action', type=str, help='action to be executed', required=True)
parser.add_argument('-ef', '--experiment-file', type=str, help='experiment file to use')
parser.add_argument('-f', '--force-rewrite', action='store_true', help='force rewrite if exists')
parser.add_argument('-ll', '--log-level', type=str, default='INFO', help='logging level',
					choices=['CRITICAL', 'FATAL', 'ERROR', 'WARNING', 'WARN', 'INFO', 'DEBUG'])
parser.add_argument('-xt', '--max-tokens', type=int, default=20, help='maximum number of tokens in sentence.')
parser.add_argument('-mt', '--min-tokens', type=int, default=4, help='minimum number of tokens in sentence.')
parser.add_argument('-ns', '--num-samples', type=int, default=50, help='number of samples from corpus.')
parser.add_argument('-rs', '--random-seed', type=int, default=29306, help='random seed')
parser.add_argument('-lp', '--log-path', type=str, help='path for file log')
parser.add_argument('-enc', '--encoding', type=str, default='utf-8', help='Encoding for processing embeddings.')
# `--skip-first-line` is a store_true flag that already defaults to True, so on its own it could never be turned
# off. It is kept for backward compatibility; `--no-skip-first-line` writes False to the same destination.
parser.add_argument('-sfl', '--skip-first-line', dest='skip_first_line', action='store_true', default=True,
					help='Skip header line when processing embeddings.')
parser.add_argument('--no-skip-first-line', dest='skip_first_line', action='store_false',
					help='Do not skip the first line when processing embeddings.')
parser.add_argument('-ed', '--expected-dim', type=int, default=300, help='Expected dimension of fastText embeddings.')


def sample_sentences(input_file, output_file, sample_size, max_tokens, min_tokens, random_seed):
	"""Draw a reproducible random sample of length-filtered sentences from a corpus file.

	The corpus is expected to contain one whitespace pre-tokenised sentence per line. A sentence is eligible if its
	number of tokens (tokens found in `string.punctuation` are not counted) lies in `[min_tokens, max_tokens]`.
	The eligible sentences are shuffled with the global `random` generator seeded by `random_seed`; the first
	`sample_size` of them are written to `output_file`, one per line.

	:param input_file: path of the corpus file
	:param output_file: path the sampled sentences are written to
	:param sample_size: maximum number of sentences to sample; fewer are written if fewer are eligible
	:param max_tokens: maximum number of non-punctuation tokens per sentence (inclusive)
	:param min_tokens: minimum number of non-punctuation tokens per sentence (inclusive)
	:param random_seed: seed for the global `random` module generator (note: this reseeds the shared generator)
	"""
	logging.info(f'Sampling {sample_size} sentences from {input_file} (min_tokens={min_tokens}; max_tokens={max_tokens})')

	random.seed(random_seed)

	logging.info('Processing sentences...')
	sentences = []
	with open(input_file) as in_file:
		for line in in_file:
			sentence = line.strip()

			# Corpus is whitespace pre-tokenised
			tokens = sentence.split()

			# Exclude punctuation from token count.
			# NOTE(review): `string.punctuation` is a str, so this is a substring test - single punctuation marks
			# are excluded, but a token such as '...' still counts. Kept as-is so existing samples stay reproducible.
			num_tokens = sum(1 for token in tokens if token not in string.punctuation)

			if min_tokens <= num_tokens <= max_tokens:
				sentences.append(sentence)
	logging.info(f'Extracted {len(sentences)} sentences!')

	random.shuffle(sentences)
	sample = sentences[:sample_size]

	# Report the number of sentences actually written; it is smaller than `sample_size` for small corpora
	logging.info(f'Storing {len(sample)} sentences at {output_file}...')
	with open(output_file, 'w') as out_file:
		for sent in sample:
			out_file.write(f'{sent}\n')
	logging.info('Finished!')


def process_cornell_movie_dialogues(movie_conversations_file, movie_lines_file, output_file_conversation,
									output_file_sentence):
	"""Flatten the Cornell movie dialogue files into a conversation-per-line and an utterance-per-line file.

	Both input files are read as `+++$+++` separated records. The first field of a movie line record is used as its
	id and the last field as its text; the last field of a conversation record is a str-serialised python list of
	line ids (e.g. "['L194', 'L195', 'L196', 'L197']"). Utterances without a sentence end symbol get a full-stop
	appended. Blank records and utterances with empty text are skipped.

	:param movie_conversations_file: path of the conversations file
	:param movie_lines_file: path of the movie lines file
	:param output_file_conversation: output path; one conversation per line, utterances joined by a single blank
	:param output_file_sentence: output path; one utterance per line
	:raises KeyError: if a conversation refers to a line id that is not present in `movie_lines_file`
	:raises ValueError: if the id list of a conversation is not a valid python literal (`SyntaxError` is possible, too)
	"""
	# Function-scope import: `literal_eval` only accepts python literals, so - unlike `eval` - it cannot execute
	# arbitrary code that might be embedded in the corpus file.
	from ast import literal_eval

	logging.info(f'Processing {movie_conversations_file} and {movie_lines_file}...')

	logging.info('Building movie lines hash table...')
	lines = {}
	with open(movie_lines_file, errors='replace') as movie_lines:
		for line in movie_lines:
			if not line.strip():
				continue

			# Split on the bare separator and strip the fields afterwards. Stripping the line first would remove the
			# trailing blank of a record with empty text, the separator would no longer match, and the speaker name
			# (plus separator) would end up as the utterance.
			parts = [part.strip() for part in line.split('+++$+++')]
			sent = parts[-1]
			# Add a full-stop to the sentence if there is no sentence end symbol (better for readability & further processing)
			if sent and sent[-1] not in ['.', '?', '!']:
				sent += '.'
			lines[parts[0]] = sent
	logging.info(f'Collected {len(lines)} lines!')

	logging.info('Building dialogs...')
	num_conversations = 0
	num_sentences = 0
	with	open(movie_conversations_file, errors='replace') as movie_conversations, \
			open(output_file_conversation, 'w') as out_file_conv, \
			open(output_file_sentence, 'w') as out_file_sent:
		for line in movie_conversations:
			if not line.strip():
				continue

			# The conversation is stored as a str-serialised python list (e.g. "['L194', 'L195', 'L196', 'L197']")
			conv_ids = literal_eval(line.split('+++$+++')[-1].strip())

			# Map the conv_ids to their actual lines; utterances with empty text would only produce blank output
			utterances = [lines[conv_id] for conv_id in conv_ids]
			utterances = [utterance for utterance in utterances if utterance]
			if not utterances:
				continue

			num_conversations += 1
			num_sentences += len(utterances)

			conversation = ' '.join(utterances)
			sentences = '\n'.join(utterances)

			out_file_conv.write(f'{conversation}\n')
			out_file_sent.write(f'{sentences}\n')
	logging.info(f'Wrote {num_conversations} conversations ({num_sentences} sentences)!')


def process_walking_around_corpus(input_path, output_file):
	"""Extract the unique, cleaned utterances from all `*.csv` files in a directory.

	Only rows with exactly 3 columns and a non-empty last column are used. The markers '*', '(( ))', '(..)' and
	'(.)' are removed from the last column, whitespace runs are collapsed to a single blank and the result is
	stripped. Utterances that are empty after cleaning are dropped. The unique utterances are written in sorted
	order, one per line.

	:param input_path: directory containing the `*.csv` files
	:param output_file: path the utterances are written to
	"""
	logging.info(f'Processing file from {input_path}...')
	sents = set()
	for input_file in sorted(glob.glob(os.path.join(input_path, '*.csv'))):
		# The csv module expects files to be opened with newline='' so that it can handle line endings itself
		with open(input_file, newline='') as in_file:
			for row in csv.reader(in_file):
				if len(row) != 3 or row[-1] == '':
					continue

				# The order of the replacements is kept, removing '*' can expose one of the other markers
				sent = row[-1].replace('*', '').replace('(( ))', ' ').replace('(..)', ' ').replace('(.)', ' ')

				# Collapse *all* whitespace (incl. newlines embedded in quoted fields), so that every utterance
				# occupies exactly one line in the output file
				sent = re.sub(r'\s+', ' ', sent).strip()
				if sent:
					sents.add(sent)
	logging.info(f'Collected {len(sents)} sentences!')

	logging.info(f'Writing sentences to {output_file}...')
	with open(output_file, 'w') as out_file:
		# Set iteration order of str differs between interpreter runs; sort to get a reproducible file
		for sent in sorted(sents):
			out_file.write(f'{sent}\n')
	logging.info('Finished!')


def process_ubuntu_chat_corpus(input_path, output_file):
	"""Collect the unique values of the last column of all `*.tsv` files in the numbered sub-directories.

	Only sub-directories of `input_path` whose name can be parsed as an int are processed. Empty rows are skipped.
	The unique values are written in sorted order, one per line.

	:param input_path: directory containing the numbered sub-directories
	:param output_file: path the sentences are written to
	"""
	logging.info(f'Processing files from {input_path}...')
	sents = set()
	for d in sorted(os.listdir(input_path)):
		# Only directories with a numeric name are processed
		try:
			_ = int(d)
		except ValueError:
			continue
		logging.info(f'\tProcessing directory {d}...')
		for input_file in sorted(glob.glob(os.path.join(input_path, d, '*.tsv'))):
			# The csv module expects files to be opened with newline='' so that it can handle line endings itself
			with open(input_file, newline='') as in_file:
				# NOTE(review): the reader still uses the default quotechar ('"'), a chat message starting with a
				# double quote is therefore parsed as a quoted field - confirm that this matches the corpus format.
				csv_reader = csv.reader(in_file, delimiter='\t')
				for row in csv_reader:
					# Blank lines are returned as empty lists, row[-1] would raise an IndexError for them
					if not row:
						continue
					sents.add(row[-1])
	logging.info(f'Extracted {len(sents)} unique sentences!')

	logging.info(f'Storing sentences at {output_file}...')
	with open(output_file, 'w') as out_file:
		# Set iteration order of str differs between interpreter runs; sort to get a reproducible file
		for sent in sorted(sents):
			out_file.write(f'{sent}\n')
	logging.info('Finished!')


def tokenize_sentences(input_file, output_file):
	"""Tokenise a file line by line and store the result whitespace separated.

	Every line of `input_file` is stripped, passed through `word_tokenize` and written to `output_file` with the
	tokens joined by a single blank. Progress is logged every 10000 lines.

	:param input_file: path of the file with one sentence per line
	:param output_file: path the tokenised sentences are written to
	"""
	logging.info(f'Tokenizing sentences from {input_file}...')
	with open(input_file) as source, open(output_file, 'w') as sink:
		num_seen = 0
		for raw_line in source:
			num_seen += 1
			if num_seen % 10000 == 0:
				logging.info(f'\t{num_seen} sentences processed!')
			tokens = word_tokenize(raw_line.strip())
			sink.write(' '.join(tokens) + '\n')
	logging.info('Finished!')


def process_fasttext_embeddings(fasttext_file, output_file, expected_dim=300, skip_first_line=True, encoding='utf-8'):
	"""Convert a textual fastText embedding file into an inverted index and a dense matrix.

	Every line is split on whitespace; the first item is the lexeme, the remaining items are parsed as floats. A
	line is only accepted if `expected_dim` is positive and the vector has exactly `expected_dim` entries (hence a
	non-positive `expected_dim` rejects every line). Blank lines, unparseable lines and lines with an unexpected
	dimensionality are logged and counted as errors; repeated lexemes are logged and counted as duplicates (the
	first occurrence wins).

	The index is stored via `io.save_structured_resource` as `inverted_index.dill` below `output_file`, the
	matrix is handed to `io.numpy_to_hdf` together with `output_file` and the name `X.hdf`.

	:param fasttext_file: path of the textual fastText file
	:param output_file: location passed to the `io` helpers (used like a directory when building the index path)
	:param expected_dim: dimensionality every vector needs to have
	:param skip_first_line: whether the first line of the file is skipped
	:param encoding: encoding used for reading `fasttext_file`
	"""
	logging.info(f'Processing fastText file from {fasttext_file} with expected_dim={expected_dim}, encoding={encoding} '
				 f'and skipping first line={skip_first_line}...')
	inverted_index = {}
	dims = []
	num_lines = 0
	num_errs = 0
	num_dupes = 0
	with open(fasttext_file, encoding=encoding) as in_file:
		# Use a default for next(), an empty file would otherwise raise StopIteration
		if skip_first_line:
			next(in_file, None)

		for idx, line in enumerate(in_file):
			num_lines += 1
			if idx % 100000 == 0:
				logging.debug(f'\t{idx} lines processed!')

			parts = line.strip().split()
			if not parts:
				# parts[0] would raise an IndexError for blank lines
				logging.warning(f'Failed to parse line {idx}: line is empty')
				num_errs += 1
				continue
			lex = parts[0]

			try:
				vec_dim = [float(x) for x in parts[1:]]
			except ValueError as ex:
				logging.warning(f'Failed to parse line {idx}: {line}\n=====\nexpected_dim={expected_dim}, '
								f'vector_dim={len(parts) - 1}; lex={lex}')
				logging.warning(ex)
				num_errs += 1
				continue

			if expected_dim > 0 and len(vec_dim) == expected_dim:
				if lex in inverted_index:
					logging.warning(f'Item="{lex}" already exists at index={inverted_index[lex]}!')
					num_dupes += 1
				else:
					dims.append(vec_dim)
					inverted_index[lex] = len(inverted_index)
			else:
				logging.warning(f'Failed to parse line {idx}: {line}\n=====\nexpected_dim={expected_dim}, '
								f'vector_dim={len(vec_dim)}; lex={lex}')
				# Rejected lines are errors, too - keep the summary below consistent with the warnings
				num_errs += 1
	X = np.array(dims, dtype=np.float64)

	# max() of an empty sequence raises a ValueError, fall back to None if no line was accepted
	max_idx = max(inverted_index.values(), default=None)
	logging.info(f'Finished processing {num_lines} lines in {fasttext_file}; model.shape={X.shape}; len(inv_idx)={len(inverted_index)}; '
				 f'max(inv_idx)={max_idx}; num_errs={num_errs}; num_dupes={num_dupes}!')
	logging.info(f'Storing fastText embeddings at {output_file}...')
	io.save_structured_resource(inverted_index, os.path.join(output_file, 'inverted_index.dill'))
	io.numpy_to_hdf(X, output_file, 'X.hdf')
	logging.info('Finished!')


def prepare_cross_lingual_telicity_setup(embedding_files, dataset_files, dataset_load_functions, data_indices,
										 target_indices, skip_headers, lowercase, data_groups, output_file,
										 embedding_path, dataset_path, drop_target, delimiter):
	"""Build a `VectorisedCrossLingualDataset` from several (embedding, dataset) pairs and store it.

	`embedding_files`, `dataset_files`, `dataset_load_functions` and `data_groups` are '==' separated strings;
	they are processed in lockstep with `data_indices`, `target_indices` and `skip_headers` (zip semantics, i.e.
	surplus entries of the longer sequences are ignored). Every phrase is represented by the average of its token
	vectors, tokens for which `emb.get` falls back to the default contribute a zero vector. Phrases of the data
	groups 'cross_val' and 'train' form the cross-validation data, all other groups the additional training data.

	The labels of all datasets are encoded with one `LabelEncoder` that is fitted once on
	{'atelic', 'state', 'telic'} minus `drop_target`, so that a label id has the same meaning in every dataset
	and in the stored encoder.

	:param embedding_files: '==' separated embedding file names (relative to `embedding_path`)
	:param dataset_files: '==' separated dataset file names (relative to `dataset_path`)
	:param dataset_load_functions: '==' separated names resolved via `wittgenstein.create_function`
	:param data_indices: per dataset, the data index passed to the dataset load function
	:param target_indices: per dataset, the target index passed to the dataset load function
	:param skip_headers: per dataset, the skip header flag passed to the dataset load function
	:param lowercase: whether phrases and tokens are lowercased
	:param data_groups: '==' separated data group names
	:param output_file: path passed to `VectorisedCrossLingualDataset.to_file`
	:param embedding_path: directory of the embedding files
	:param dataset_path: directory of the dataset files
	:param drop_target: label excluded from the label encoder; also passed to the dataset load function
	:param delimiter: delimiter passed to the dataset load function
	:raises ValueError: raised by `LabelEncoder.transform` if a dataset contains a label the encoder was not fitted on
	"""
	le = LabelEncoder()
	all_labels = {'atelic', 'state', 'telic'}
	drop_labels = {drop_target}
	le.fit(list(all_labels - drop_labels))

	inv_idx = {}
	data_phrase = []
	cross_val_data = []
	cross_val_labels = []
	additional_data = []
	additional_labels = []

	for embedding_file, dataset_file, dataset_load_fn, data_group, data_index, target_index, skip_header in zip(
			embedding_files.split('=='),
			dataset_files.split('=='),
			dataset_load_functions.split('=='),
			data_groups.split('=='),
			data_indices,
			target_indices,
			skip_headers):
		logging.info(f'Loading dataset with function={dataset_load_fn}...')
		dataset_loader = wittgenstein.create_function(dataset_load_fn)

		logging.info(f'\tembedding_file={embedding_file}; dataset_file={dataset_file}; dataset_load_fn={dataset_load_fn}; '
					  f'data_group={data_group}; data_index={data_index}; target_index={target_index}; '
					  f'skip_header={skip_header}; drop_target={drop_target}; delimiter="{delimiter}"')

		logging.info(f'Loading embeddings from {embedding_file}...')
		emb = FastTextEmbedding(embedding_path=os.path.join(embedding_path, embedding_file))
		data, targets = dataset_loader(os.path.join(dataset_path, dataset_file), data_index, target_index, skip_header,
									   drop_target, delimiter)
		oov = np.zeros((emb.dimensionality(),))

		# Use the encoder fitted above instead of refitting it per dataset (`fit_transform`): a refit derives the
		# ids from the labels of the current dataset only, so ids may differ between datasets and the stored
		# encoder would only describe the last one.
		labels = le.transform(targets).tolist()

		logging.info(f'Labels for {embedding_file}: {labels}')

		for tokenised_phrase, label in zip(data, labels):
			phrase = ' '.join(tokenised_phrase).lower() if lowercase else ' '.join(tokenised_phrase)
			logging.debug(f'Data Group={data_group}; prhase={phrase}; label={label}')
			if data_group in ['cross_val', 'train']:
				cross_val_data.append(phrase)
				cross_val_labels.append(label)
			else:
				additional_data.append(phrase)
				additional_labels.append(label)

			# Every phrase is vectorised once, its row in X is given by inv_idx
			if phrase not in inv_idx:
				inv_idx[phrase] = len(inv_idx)

				denom_phrase = len(tokenised_phrase)
				x = np.zeros((emb.dimensionality(),))
				for w in tokenised_phrase:
					token = w if not lowercase else w.lower()
					x += (emb.get(token, oov) / denom_phrase)
				data_phrase.append(x)
	X = np.array(data_phrase)

	# An empty list is converted to a float array by default, which bincount may reject - force an integer dtype
	logging.info(f'Created dataset with X.shape={X.shape} and len(inv_idx)={len(inv_idx)}!; '
				 f'cross_val_labels={np.bincount(np.array(cross_val_labels, dtype=np.int64))}; '
				 f'additional_training_labels={np.bincount(np.array(additional_labels, dtype=np.int64))}')

	ds = VectorisedCrossLingualDataset(cross_validation_data=cross_val_data,
									   cross_validation_labels=cross_val_labels,
									   additional_training_data=additional_data,
									   additional_training_labels=additional_labels,
									   label_encoder=le,
									   embeddings=(inv_idx, X))
	ds.to_file(output_file)


# Script entry point: set up logging, create the output directories, load the (optional) experiment file and
# dispatch to the function selected by `--action`.
if (__name__ == '__main__'):
	args = parser.parse_args()

	log_formatter = logging.Formatter(fmt='%(asctime)s: %(levelname)s - %(message)s', datefmt='[%d/%m/%Y %H:%M:%S %p]')
	root_logger = logging.getLogger()
	root_logger.setLevel(getattr(logging, args.log_level))

	console_handler = logging.StreamHandler(sys.stdout)
	console_handler.setFormatter(log_formatter)
	root_logger.addHandler(console_handler)

	if (args.log_path is not None):
		# Every run logs into its own timestamped folder below the log path
		timestamped_foldername = path_utils.timestamped_foldername()
		log_path = os.path.join(args.log_path, timestamped_foldername)
		os.makedirs(log_path, exist_ok=True)

		file_handler = logging.FileHandler(os.path.join(log_path, f'corpora_preprocess_{args.action}.log'))
		file_handler.setFormatter(log_formatter)
		root_logger.addHandler(file_handler)

	# The second output path is written to by `process_cornell_movie_dialogues`, so it needs to exist as well
	for out_path in (args.output_path, args.output_path_2):
		if (out_path is not None):
			os.makedirs(out_path, exist_ok=True)

	# Load experiment id file; blank rows are dropped because they cannot be unpacked into an experiment below
	experiments = None
	if (args.experiment_file is not None):
		with open(os.path.join(PACKAGE_PATH, 'resources', 'preprocessing', args.experiment_file), 'r', newline='') as csv_file:
			experiments = [row for row in csv.reader(csv_file) if row]

	# Fail with a clear message instead of a NameError when an experiment driven action has no experiment file
	if (args.action in ('tokenize_sentences', 'prepare_cross_lingual_telicity_setup') and experiments is None):
		raise ValueError(f'Action "{args.action}" requires an experiment file (-ef/--experiment-file)!')

	if (args.action == 'sample_sentences'):
		sample_sentences(input_file=os.path.join(args.input_path, args.input_file),
						 output_file=os.path.join(args.output_path, args.output_file),
						 max_tokens=args.max_tokens, min_tokens=args.min_tokens,
						 sample_size=args.num_samples, random_seed=args.random_seed)
	elif (args.action == 'process_cornell_movie_dialogues'):
		process_cornell_movie_dialogues(movie_conversations_file=os.path.join(args.input_path, args.input_file),
										movie_lines_file=os.path.join(args.input_path_2, args.input_file_2),
										output_file_conversation=os.path.join(args.output_path, args.output_file),
										output_file_sentence=os.path.join(args.output_path_2, args.output_file_2))
	elif (args.action == 'process_walking_around_corpus'):
		process_walking_around_corpus(input_path=args.input_path, output_file=os.path.join(args.output_path, args.output_file))
	elif (args.action == 'process_ubuntu_chat_corpus'):
		process_ubuntu_chat_corpus(input_path=args.input_path, output_file=os.path.join(args.output_path, args.output_file))
	elif (args.action == 'tokenize_sentences'):
		# Every experiment row is an (input file, output file) pair
		for input_file, output_file in experiments:
			tokenize_sentences(input_file=os.path.join(args.input_path, input_file),
							   output_file=os.path.join(args.output_path, output_file))
	elif (args.action == 'process_fasttext_embeddings'):
		process_fasttext_embeddings(fasttext_file=os.path.join(args.input_path, args.input_file),
									output_file=os.path.join(args.output_path, args.output_file),
									expected_dim=args.expected_dim, encoding=args.encoding,
									skip_first_line=args.skip_first_line)
	elif (args.action == 'prepare_cross_lingual_telicity_setup'):
		# Every experiment row has 11 columns; list-valued columns use '==' as separator
		for embedding_files, dataset_files, dataset_load_functions, data_indices, target_indices, skip_headers, \
			lowercase, data_groups, output_file, drop_target, delimiter in experiments:
			prepare_cross_lingual_telicity_setup(embedding_path=args.input_path, dataset_path=args.input_path_2,
												 output_file=os.path.join(args.output_path, output_file),
												 embedding_files=embedding_files, dataset_files=dataset_files,
												 dataset_load_functions=dataset_load_functions, drop_target=drop_target,
												 data_indices=[int(x) for x in data_indices.split('==')],
												 target_indices=[int(x) for x in target_indices.split('==')],
												 skip_headers=[x == 'True' for x in skip_headers.split('==')],
												 lowercase=lowercase == 'True', data_groups=data_groups,
												 delimiter=delimiter if not delimiter == 'tab' else '\t')

	else:
		raise ValueError(f'Action "{args.action}" not supported!')