import csv
import os

from dataclasses import dataclass
import numpy as np
import spacy_udpipe

from telicity.models.embedding import FastTextEmbedding
from telicity.util import io
from telicity.util.tokeniser import tokenise_udpipe


@dataclass
class VectorisedCrossLingualDataset:
	"""Vectorised cross-lingual dataset together with its label encoder and embeddings.

	The data is held in two splits: a cross-validation split and an additional
	training split. ``from_file`` / ``to_file`` restore / persist every member
	from / into a single directory.
	"""
	cross_validation_data:		object
	cross_validation_labels:	object
	additional_training_data:	object
	additional_training_labels:	object
	label_encoder:				object
	embeddings:					object

	# HACK: the generic train/test accessors are reused with shifted semantics. "train" returns the
	# cross-validation split and "test" returns the *additional training* split, not a held-out test set.
	def train_data(self):
		return self.cross_validation_data

	def train_labels(self):
		return self.cross_validation_labels

	def test_data(self):
		return self.additional_training_data

	def test_labels(self):
		return self.additional_training_labels

	@staticmethod
	def _load_split(path, name, return_as_np_arrays):
		"""Load one split stored under ``path`` as ``<name>.hdf`` (preferred) or ``<name>.dill``."""
		if os.path.exists(os.path.join(path, name + '.hdf')):
			return io.hdf_to_numpy(path, name + '.hdf')
		# NOTE(review): .dill files are deserialised (pickle-like); only load directories you trust.
		split = io.load_structured_resource(os.path.join(path, name + '.dill'))
		return np.array(split) if return_as_np_arrays else split

	@staticmethod
	def from_file(path, return_as_np_arrays=True):
		"""Restore a dataset from the directory layout written by ``to_file``.

		:param path: directory containing the serialised members.
		:param return_as_np_arrays: wrap splits loaded from ``.dill`` files in ``np.array``.
			Splits loaded from ``.hdf`` are returned as produced by ``io.hdf_to_numpy``.
		:return: a new ``VectorisedCrossLingualDataset``.
		"""
		load = VectorisedCrossLingualDataset._load_split
		return VectorisedCrossLingualDataset(
			cross_validation_data=load(path, 'cross_validation_data', return_as_np_arrays),
			cross_validation_labels=load(path, 'cross_validation_labels', return_as_np_arrays),
			additional_training_data=load(path, 'additional_training_data', return_as_np_arrays),
			additional_training_labels=load(path, 'additional_training_labels', return_as_np_arrays),
			label_encoder=io.load_structured_resource(os.path.join(path, 'label_encoder.dill')),
			embeddings=FastTextEmbedding(embedding_path=os.path.join(path, 'embeddings.kvec'))
		)

	def to_file(self, path, store_as_np_arrays=False):
		"""Write all members into directory ``path``, creating it if necessary.

		:param path: target directory.
		:param store_as_np_arrays: write the four splits as ``.hdf`` (via ``io.numpy_to_hdf``)
			instead of ``.dill`` (via ``io.save_structured_resource``).
		"""
		# exist_ok avoids the race between an exists() check and makedirs().
		os.makedirs(path, exist_ok=True)
		for name in ('cross_validation_data', 'cross_validation_labels',
					 'additional_training_data', 'additional_training_labels'):
			split = getattr(self, name)
			if store_as_np_arrays:
				io.numpy_to_hdf(split, path, name + '.hdf')
			else:
				io.save_structured_resource(split, os.path.join(path, name + '.dill'))
		io.save_structured_resource(self.label_encoder, os.path.join(path, 'label_encoder.dill'))

		# Embeddings are either an (inverted_index, X) tuple or an object exposing inv_idx_ / X_.
		if isinstance(self.embeddings, tuple):
			inv_idx, X = self.embeddings[0], self.embeddings[1]
		else:
			inv_idx, X = self.embeddings.inv_idx_, self.embeddings.X_
		embedding_path = os.path.join(path, 'embeddings.kvec')
		# NOTE(review): it is not visible here whether the io helpers create missing parent
		# directories, so ensure the embeddings sub-directory exists before writing into it.
		os.makedirs(embedding_path, exist_ok=True)
		io.save_structured_resource(inv_idx, os.path.join(embedding_path, 'inverted_index.dill'))
		io.numpy_to_hdf(X, embedding_path, 'X.hdf')


def get_telicity_dataset_en(dataset_file, data_idx, target_idx, skip_header=True):
	"""Load the English telicity dataset and tokenise each example with UDPipe.

	:param dataset_file: path to a tab-separated file.
	:param data_idx: column holding the example text.
	:param target_idx: column holding the label; rows labelled ``error`` are skipped.
	:param skip_header: if True, the first row is discarded.
	:return: ``(data, targets)``: a list of token lists and the whitespace-stripped labels.
	"""
	try:
		nlp = spacy_udpipe.load('en')
	except Exception:
		# Presumably the model is not installed yet: download it and retry.
		spacy_udpipe.download('en')
		nlp = spacy_udpipe.load('en')

	data = []
	targets = []
	# newline='' is what the csv module expects for file objects it reads.
	with open(dataset_file, newline='') as in_file:
		csv_reader = csv.reader(in_file, delimiter='\t')
		if skip_header:
			# Skip through the reader (handles quoted multi-line headers) and tolerate an empty file.
			next(csv_reader, None)

		for line in csv_reader:
			if not line:
				continue  # blank rows would otherwise raise IndexError below
			# Compare the stripped label so that e.g. 'error ' is filtered like 'error'.
			target = line[target_idx].strip()
			if target == 'error':
				continue

			doc = nlp(line[data_idx].strip())
			data.append([token.text for sent in doc.sents for token in sent])
			targets.append(target)

	return data, targets


def get_telicity_dataset_tr(dataset_file, data_idx, target_idx, skip_header=True, drop_target=None, delimiter='\t',
							tokenize=True):
	"""Load the Turkish telicity dataset, running each example through UDPipe.

	:param dataset_file: path to a delimited text file.
	:param data_idx: column holding the example text.
	:param target_idx: column holding the label.
	:param skip_header: if True, the first physical line is discarded.
	:param drop_target: rows whose raw label equals this value are skipped.
	:param delimiter: field separator passed to ``csv.reader``.
	:param tokenize: if True each example is a list of token strings, otherwise ``doc.text``.
	:return: ``(data, targets)`` with whitespace-stripped labels.
	"""
	try:
		nlp = spacy_udpipe.load('tr')
	except Exception:
		# Loading failed: fetch the model and load it again.
		spacy_udpipe.download('tr')
		nlp = spacy_udpipe.load('tr')

	examples, labels = [], []
	with open(dataset_file) as handle:
		if skip_header:
			next(handle)

		for row in csv.reader(handle, delimiter=delimiter):
			# Empty rows (e.g. trailing blank lines) carry no example.
			if not row:
				continue
			label = row[target_idx]
			if drop_target == label:
				continue

			doc = nlp(row[data_idx].strip())
			token_texts = [tok.text for sentence in doc.sents for tok in sentence]
			examples.append(token_texts if tokenize else doc.text)
			labels.append(label.strip())

	return examples, labels


def get_telicity_dataset_ar(dataset_file, data_idx, target_idx, skip_header=True, drop_target=None, delimiter='\t',
							tokenize=True):
	"""Load the Arabic telicity dataset, running each example through UDPipe.

	:param dataset_file: path to a delimited UTF-8 text file.
	:param data_idx: column holding the example text.
	:param target_idx: column holding the label; rows labelled ``NA`` are skipped.
	:param skip_header: if True, the first row is discarded.
	:param drop_target: rows whose (stripped) label equals this value are skipped.
	:param delimiter: field separator passed to ``csv.reader``.
	:param tokenize: if True each example is a list of token strings, otherwise ``doc.text``.
	:return: ``(data, targets)`` with whitespace-stripped labels.
	"""
	try:
		nlp = spacy_udpipe.load('ar')
	except Exception:
		# Presumably the model is not installed yet: download it and retry.
		spacy_udpipe.download('ar')
		nlp = spacy_udpipe.load('ar')

	data = []
	targets = []
	# Explicit UTF-8 (matching the zh/ru loaders in this module) so Arabic text decodes
	# regardless of the platform default; newline='' is what the csv module expects.
	with open(dataset_file, encoding="utf8", newline='') as in_file:
		csv_reader = csv.reader(in_file, delimiter=delimiter, quoting=csv.QUOTE_NONE)
		if skip_header:
			# Skip through the reader and tolerate an empty file.
			next(csv_reader, None)

		for line in csv_reader:
			if not line:
				continue  # blank rows would otherwise raise IndexError below
			# Compare the stripped label so that e.g. 'NA ' is filtered like 'NA'.
			target = line[target_idx].strip()
			if target == 'NA' or target == drop_target:
				continue

			doc = nlp(line[data_idx].strip())
			tokens = [token.text for sent in doc.sents for token in sent]
			data.append(tokens if tokenize else doc.text)
			targets.append(target)

	return data, targets


def get_telicity_dataset_fa(dataset_file, data_idx, target_idx, skip_header=True, drop_target=None, delimiter=',',
							tokenize=True):
	"""Load the Farsi telicity dataset, running each example through UDPipe.

	:param dataset_file: path to a delimited UTF-8 text file.
	:param data_idx: column holding the example text.
	:param target_idx: column holding the label; rows labelled ``NA`` are skipped.
	:param skip_header: if True, the first row is discarded.
	:param drop_target: rows whose (stripped) label equals this value are skipped.
	:param delimiter: field separator passed to ``csv.reader``.
	:param tokenize: if True each example is a list of token strings, otherwise ``doc.text``.
	:return: ``(data, targets)`` with whitespace-stripped labels.
	"""
	try:
		nlp = spacy_udpipe.load('fa')
	except Exception:
		# Presumably the model is not installed yet: download it and retry.
		spacy_udpipe.download('fa')
		nlp = spacy_udpipe.load('fa')

	data = []
	targets = []
	# Explicit UTF-8 (matching the zh/ru loaders in this module) so Farsi text decodes
	# regardless of the platform default; newline='' is what the csv module expects.
	with open(dataset_file, encoding="utf8", newline='') as in_file:
		csv_reader = csv.reader(in_file, delimiter=delimiter, quoting=csv.QUOTE_NONE)
		if skip_header:
			# Skip through the reader and tolerate an empty file.
			next(csv_reader, None)

		for line in csv_reader:
			if not line:
				continue  # blank rows would otherwise raise IndexError below
			# Compare the stripped label so that e.g. 'NA ' is filtered like 'NA'.
			target = line[target_idx].strip()
			if target == 'NA' or target == drop_target:
				continue

			doc = nlp(line[data_idx].strip())
			tokens = [token.text for sent in doc.sents for token in sent]
			data.append(tokens if tokenize else doc.text)
			targets.append(target)

	return data, targets


def get_telicity_dataset_de(dataset_file, data_idx, target_idx, skip_header=True, drop_target=None, delimiter='\t',
							tokenize=True):
	"""Load the German telicity dataset, running each example through UDPipe.

	:param dataset_file: path to a delimited text file (read with ``QUOTE_NONE``).
	:param data_idx: column holding the example text.
	:param target_idx: column holding the label; rows whose raw label is ``error`` are skipped.
	:param skip_header: if True, the first physical line is discarded.
	:param drop_target: rows whose raw label equals this value are skipped.
	:param delimiter: field separator passed to ``csv.reader``.
	:param tokenize: if True each example is a list of token strings, otherwise ``doc.text``.
	:return: ``(data, targets)`` with whitespace-stripped labels.
	"""
	try:
		nlp = spacy_udpipe.load('de')
	except Exception:
		# Loading failed: fetch the model and load it again.
		spacy_udpipe.download('de')
		nlp = spacy_udpipe.load('de')

	examples, labels = [], []
	with open(dataset_file) as handle:
		rows = csv.reader(handle, delimiter=delimiter, quoting=csv.QUOTE_NONE)
		if skip_header:
			next(handle)

		for row in rows:
			label = row[target_idx]
			# Skip annotation errors as well as the explicitly dropped label.
			if label == 'error' or drop_target == label:
				continue

			doc = nlp(row[data_idx].strip())
			token_texts = [tok.text for sentence in doc.sents for tok in sentence]
			examples.append(token_texts if tokenize else doc.text)
			labels.append(label.strip())

	return examples, labels


def get_telicity_dataset_zh(dataset_file, data_idx, target_idx, skip_header=True, drop_target=None, delimiter=',',
							tokenize=True):
	"""Load the Chinese telicity dataset (UTF-8), running each example through UDPipe.

	:param dataset_file: path to a delimited UTF-8 file (read with ``QUOTE_NONE``).
	:param data_idx: column holding the example text.
	:param target_idx: column holding the label.
	:param skip_header: if True, the first physical line is discarded.
	:param drop_target: rows whose raw label equals this value are skipped.
	:param delimiter: field separator passed to ``csv.reader``.
	:param tokenize: if True each example is a list of token strings, otherwise ``doc.text``.
	:return: ``(data, targets)`` with whitespace-stripped labels.
	"""
	try:
		nlp = spacy_udpipe.load('zh')
	except Exception:
		# Loading failed: fetch the model and load it again.
		spacy_udpipe.download('zh')
		nlp = spacy_udpipe.load('zh')

	examples, labels = [], []
	with open(dataset_file, encoding="utf8") as handle:
		rows = csv.reader(handle, delimiter=delimiter, quoting=csv.QUOTE_NONE)
		if skip_header:
			next(handle)

		for row in rows:
			label = row[target_idx]
			if drop_target == label:
				continue

			doc = nlp(row[data_idx].strip())
			token_texts = [tok.text for sentence in doc.sents for tok in sentence]
			examples.append(token_texts if tokenize else doc.text)
			labels.append(label.strip())

	return examples, labels


def get_telicity_dataset_ru_multi(dataset_file, data_idx, target_idx, skip_header=True, drop_target=None, delimiter='\t'):
	"""Load a Russian (multi-label) telicity dataset as UDPipe token lists.

	:param dataset_file: path to a delimited UTF-8 file (read with ``QUOTE_NONE``).
	:param data_idx: column holding the example text.
	:param target_idx: column holding the label.
	:param skip_header: if True, the first physical line is discarded.
	:param drop_target: rows whose raw label equals this value are skipped.
	:param delimiter: field separator passed to ``csv.reader``.
	:return: ``(data, targets)``: a list of token lists and the whitespace-stripped labels.
	"""
	try:
		nlp = spacy_udpipe.load('ru')
	except Exception:
		# Loading failed: fetch the model and load it again.
		spacy_udpipe.download('ru')
		nlp = spacy_udpipe.load('ru')

	examples, labels = [], []
	with open(dataset_file, encoding="utf8") as handle:
		rows = csv.reader(handle, delimiter=delimiter, quoting=csv.QUOTE_NONE)
		if skip_header:
			next(handle)

		for row in rows:
			label = row[target_idx]
			if drop_target == label:
				continue

			doc = nlp(row[data_idx].strip())
			examples.append([tok.text for sentence in doc.sents for tok in sentence])
			labels.append(label.strip())

	return examples, labels


def get_telicity_dataset_ru(dataset_file, data_idx, target_idx, skip_header=True, drop_target=None, delimiter=',',
							tokenize=True):
	"""Load the Russian telicity dataset, keeping only rows labelled ``A``, ``S`` or ``T``.

	:param dataset_file: path to a delimited UTF-8 file (read with ``QUOTE_NONE``).
	:param data_idx: column holding the example text.
	:param target_idx: column holding the label.
	:param skip_header: if True, the first row is discarded.
	:param drop_target: rows whose (stripped) label equals this value are also skipped.
	:param delimiter: field separator passed to ``csv.reader``.
	:param tokenize: if True each example is a list of token strings, otherwise ``doc.text``.
	:return: ``(data, targets)`` with whitespace-stripped labels.
	"""
	try:
		nlp = spacy_udpipe.load('ru')
	except Exception:
		# Presumably the model is not installed yet: download it and retry.
		spacy_udpipe.download('ru')
		nlp = spacy_udpipe.load('ru')

	# Whitelist of kept labels; it already excludes the error label 'E', so no separate check is needed.
	valid_targets = {'A', 'S', 'T'}

	data = []
	targets = []
	# newline='' is what the csv module expects for file objects it reads.
	with open(dataset_file, encoding="utf8", newline='') as in_file:
		csv_reader = csv.reader(in_file, delimiter=delimiter, quoting=csv.QUOTE_NONE)
		if skip_header:
			# Skip through the reader and tolerate an empty file.
			next(csv_reader, None)

		for line in csv_reader:
			if not line:
				continue  # blank rows would otherwise raise IndexError below
			# Compare the stripped label so e.g. 'A ' is kept like 'A'.
			target = line[target_idx].strip()
			if target not in valid_targets or target == drop_target:
				continue

			doc = nlp(line[data_idx].strip())
			tokens = [token.text for sent in doc.sents for token in sent]
			data.append(tokens if tokenize else doc.text)
			targets.append(target)

	return data, targets


def get_telicity_dataset_ru2(dataset_file, data_idx, target_idx, skip_header=True, drop_target=None, delimiter=','):
	"""Load a pre-tokenised Russian telicity dataset (tokens are whitespace-separated).

	:param dataset_file: path to a delimited UTF-8 file (read with ``QUOTE_NONE``).
	:param data_idx: column holding the whitespace-tokenised example.
	:param target_idx: column holding the label; rows labelled ``E`` are skipped.
	:param skip_header: if True, the first row is discarded.
	:param drop_target: rows whose (stripped) label equals this value are skipped.
	:param delimiter: field separator passed to ``csv.reader``.
	:return: ``(data, targets)``: a list of token lists and the whitespace-stripped labels.
	"""
	data = []
	targets = []
	# newline='' is what the csv module expects for file objects it reads.
	with open(dataset_file, encoding="utf8", newline='') as in_file:
		csv_reader = csv.reader(in_file, delimiter=delimiter, quoting=csv.QUOTE_NONE)
		if skip_header:
			# Skip through the reader and tolerate an empty file.
			next(csv_reader, None)

		for line in csv_reader:
			if not line:
				continue  # blank rows would otherwise raise IndexError below
			target = line[target_idx].strip()
			if target == 'E' or target == drop_target:
				continue
			data.append(line[data_idx].strip().split())
			targets.append(target)

	return data, targets


def get_telicity_train_de_test_en_dataset(dataset_file, data_idx, target_idx, skip_header=True):
	"""Split a mixed German/English TSV into a German train split and a non-German test split.

	The language tag is read from the last column; rows tagged ``de`` go to the training split,
	all others to the test split. Rows labelled ``error`` are skipped. Examples are
	whitespace-tokenised.

	:return: ``((data_train, data_test), (targets_train, targets_test))``.
	"""
	data_train = []
	data_test = []
	targets_train = []
	targets_test = []
	# newline='' is what the csv module expects for file objects it reads.
	with open(dataset_file, newline='') as in_file:
		csv_reader = csv.reader(in_file, delimiter='\t', quoting=csv.QUOTE_NONE)
		if skip_header:
			# Skip through the reader and tolerate an empty file.
			next(csv_reader, None)

		for line in csv_reader:
			if not line:
				continue  # blank rows would otherwise raise IndexError below
			target = line[target_idx].strip()
			if target == 'error':
				continue
			tokens = line[data_idx].strip().split()
			# Strip the tag so trailing whitespace doesn't misroute German rows into the test split.
			if line[-1].strip() == 'de':
				data_train.append(tokens)
				targets_train.append(target)
			else:
				data_test.append(tokens)
				targets_test.append(target)

	return (data_train, data_test), (targets_train, targets_test)


def get_telicity_train_en_test_de_dataset(dataset_file, data_idx, target_idx, skip_header=True):
	"""Split a mixed German/English TSV into an English train split and a non-English test split.

	The language tag is read from the last column; rows tagged ``en`` go to the training split,
	all others to the test split. Rows labelled ``error`` are skipped. Examples are
	whitespace-tokenised.

	:return: ``((data_train, data_test), (targets_train, targets_test))``.
	"""
	data_train = []
	data_test = []
	targets_train = []
	targets_test = []
	# newline='' is what the csv module expects for file objects it reads.
	with open(dataset_file, newline='') as in_file:
		csv_reader = csv.reader(in_file, delimiter='\t', quoting=csv.QUOTE_NONE)
		if skip_header:
			# Skip through the reader and tolerate an empty file.
			next(csv_reader, None)

		for line in csv_reader:
			if not line:
				continue  # blank rows would otherwise raise IndexError below
			target = line[target_idx].strip()
			if target == 'error':
				continue
			tokens = line[data_idx].strip().split()
			# Strip the tag so trailing whitespace doesn't misroute English rows into the test split.
			if line[-1].strip() == 'en':
				data_train.append(tokens)
				targets_train.append(target)
			else:
				data_test.append(tokens)
				targets_test.append(target)

	return (data_train, data_test), (targets_train, targets_test)


def get_telicity_train_de_en_test_en_dataset(dataset_file, data_idx, target_idx, skip_header=True):
	"""Split a mixed German/English TSV into an English cross-validation split and additional data.

	The language tag is read from the last column; rows tagged ``en`` form the cross-validation
	split, all others the additional training split. Rows labelled ``error`` are skipped.
	Examples are whitespace-tokenised.

	:return: ``((cross_val_data, data_additional), (targets_cross_val, targets_additional))``.
	"""
	cross_val_data = []
	data_additional = []
	targets_cross_val = []
	targets_additional = []
	# newline='' is what the csv module expects for file objects it reads.
	with open(dataset_file, newline='') as in_file:
		csv_reader = csv.reader(in_file, delimiter='\t', quoting=csv.QUOTE_NONE)
		if skip_header:
			# Skip through the reader and tolerate an empty file.
			next(csv_reader, None)

		for line in csv_reader:
			if not line:
				continue  # blank rows would otherwise raise IndexError below
			target = line[target_idx].strip()
			if target == 'error':
				continue
			tokens = line[data_idx].strip().split()
			# Strip the tag so trailing whitespace doesn't misroute English rows.
			if line[-1].strip() == 'en':
				cross_val_data.append(tokens)
				targets_cross_val.append(target)
			else:
				data_additional.append(tokens)
				targets_additional.append(target)

	return (cross_val_data, data_additional), (targets_cross_val, targets_additional)


def get_telicity_de_en_mixed_dataset(dataset_file, data_idx, target_idx, skip_header=True):
	"""Load a mixed German/English TSV as one dataset of whitespace-tokenised examples.

	Rows labelled ``error`` are skipped.

	:return: ``(data, targets)``: a list of token lists and the whitespace-stripped labels.
	"""
	data = []
	targets = []
	# newline='' is what the csv module expects for file objects it reads.
	with open(dataset_file, newline='') as in_file:
		csv_reader = csv.reader(in_file, delimiter='\t', quoting=csv.QUOTE_NONE)
		if skip_header:
			# Skip through the reader and tolerate an empty file.
			next(csv_reader, None)

		for line in csv_reader:
			if not line:
				continue  # blank rows would otherwise raise IndexError below
			target = line[target_idx].strip()
			if target == 'error':
				continue
			data.append(line[data_idx].strip().split())
			targets.append(target)

	return data, targets


def get_domain_dataset_de(dataset_file, data_idx, target_idx, skip_header=True):
	"""Load the German domain dataset as space-joined UDPipe tokens.

	:param dataset_file: path to a tab-separated file (read with ``QUOTE_NONE``).
	:param data_idx: column holding the example text.
	:param target_idx: column holding the label; rows whose raw label is ``error`` are skipped.
	:param skip_header: if True, the first physical line is discarded.
	:return: ``(data, targets)``: token strings joined by single spaces, and stripped labels.
	"""
	try:
		nlp = spacy_udpipe.load('de')
	except Exception:
		# Loading failed: fetch the model and load it again.
		spacy_udpipe.download('de')
		nlp = spacy_udpipe.load('de')

	sentences, labels = [], []
	with open(dataset_file) as handle:
		rows = csv.reader(handle, delimiter='\t', quoting=csv.QUOTE_NONE)
		if skip_header:
			next(handle)

		for row in rows:
			label = row[target_idx]
			if label == 'error':
				continue
			doc = nlp(row[data_idx].strip())
			sentences.append(' '.join(tok.text for sentence in doc.sents for tok in sentence))
			labels.append(label.strip())

	return sentences, labels


def get_domain_dataset_zh(dataset_file, data_idx, target_idx, skip_header=True):
	"""Load the Chinese domain dataset as space-joined UDPipe tokens.

	:param dataset_file: path to a comma-separated UTF-8 file (read with ``QUOTE_NONE``).
	:param data_idx: column holding the example text.
	:param target_idx: column holding the label; rows labelled ``error`` are skipped.
	:param skip_header: if True, the first row is discarded.
	:return: ``(data, targets)``: token strings joined by single spaces, and stripped labels.
	"""
	try:
		nlp = spacy_udpipe.load('zh')
	except Exception:
		# Presumably the model is not installed yet: download it and retry.
		spacy_udpipe.download('zh')
		nlp = spacy_udpipe.load('zh')

	data = []
	targets = []
	# Explicit UTF-8, matching get_telicity_dataset_zh, so Chinese text doesn't depend on the
	# platform's default encoding; newline='' is what the csv module expects.
	with open(dataset_file, encoding="utf8", newline='') as in_file:
		csv_reader = csv.reader(in_file, quoting=csv.QUOTE_NONE)
		if skip_header:
			# Skip through the reader and tolerate an empty file.
			next(csv_reader, None)

		for line in csv_reader:
			if not line:
				continue  # blank rows would otherwise raise IndexError below
			target = line[target_idx].strip()
			if target == 'error':
				continue
			doc = nlp(line[data_idx].strip())
			data.append(' '.join(token.text for sent in doc.sents for token in sent))
			targets.append(target)

	return data, targets


def get_domain_dataset_ar(dataset_file, data_idx, target_idx, skip_header=True):
	"""Load the Arabic domain dataset as space-joined UDPipe tokens.

	:param dataset_file: path to a tab-separated file (read with ``QUOTE_NONE``).
	:param data_idx: column holding the example text.
	:param target_idx: column holding the label; rows whose raw label is ``NA`` are skipped.
	:param skip_header: if True, the first physical line is discarded.
	:return: ``(data, targets)``: token strings joined by single spaces, and stripped labels.
	"""
	try:
		nlp = spacy_udpipe.load('ar')
	except Exception:
		# Loading failed: fetch the model and load it again.
		spacy_udpipe.download('ar')
		nlp = spacy_udpipe.load('ar')

	sentences, labels = [], []
	with open(dataset_file) as handle:
		rows = csv.reader(handle, delimiter='\t', quoting=csv.QUOTE_NONE)
		if skip_header:
			next(handle)

		for row in rows:
			label = row[target_idx]
			if label == 'NA':
				continue
			doc = nlp(row[data_idx].strip())
			sentences.append(' '.join(tok.text for sentence in doc.sents for tok in sentence))
			labels.append(label.strip())

	return sentences, labels


def get_domain_dataset_fa(dataset_file, data_idx, target_idx, skip_header=True):
	"""Load the Farsi domain dataset as space-joined UDPipe tokens.

	:param dataset_file: path to a comma-separated file (read with ``QUOTE_NONE``).
	:param data_idx: column holding the example text.
	:param target_idx: column holding the label; rows whose raw label is ``error`` are skipped.
	:param skip_header: if True, the first physical line is discarded.
	:return: ``(data, targets)``: token strings joined by single spaces, and stripped labels.
	"""
	try:
		nlp = spacy_udpipe.load('fa')
	except Exception:
		# Loading failed: fetch the model and load it again.
		spacy_udpipe.download('fa')
		nlp = spacy_udpipe.load('fa')

	sentences, labels = [], []
	with open(dataset_file) as handle:
		rows = csv.reader(handle, quoting=csv.QUOTE_NONE)
		if skip_header:
			next(handle)

		for row in rows:
			label = row[target_idx]
			# NOTE(review): get_telicity_dataset_fa filters 'NA' instead of 'error';
			# confirm which marker the Farsi domain file actually uses.
			if label == 'error':
				continue
			doc = nlp(row[data_idx].strip())
			sentences.append(' '.join(tok.text for sentence in doc.sents for tok in sentence))
			labels.append(label.strip())

	return sentences, labels


def get_domain_dataset_ru(dataset_file, data_idx, target_idx, skip_header=True):
	"""Load the Russian domain dataset as space-joined UDPipe tokens.

	:param dataset_file: path to a comma-separated UTF-8 file (read with ``QUOTE_NONE``).
	:param data_idx: column holding the example text.
	:param target_idx: column holding the label; rows labelled ``E`` are skipped.
	:param skip_header: if True, the first row is discarded.
	:return: ``(data, targets)``: token strings joined by single spaces, and stripped labels.
	"""
	try:
		nlp = spacy_udpipe.load('ru')
	except Exception:
		# Presumably the model is not installed yet: download it and retry.
		spacy_udpipe.download('ru')
		nlp = spacy_udpipe.load('ru')

	data = []
	targets = []
	# Explicit UTF-8, matching the other Russian loaders in this module, so Cyrillic text doesn't
	# depend on the platform's default encoding; newline='' is what the csv module expects.
	with open(dataset_file, encoding="utf8", newline='') as in_file:
		csv_reader = csv.reader(in_file, quoting=csv.QUOTE_NONE)
		if skip_header:
			# Skip through the reader and tolerate an empty file.
			next(csv_reader, None)

		for line in csv_reader:
			if not line:
				continue  # blank rows would otherwise raise IndexError below
			target = line[target_idx].strip()
			if target == 'E':
				continue
			doc = nlp(line[data_idx].strip())
			data.append(' '.join(token.text for sent in doc.sents for token in sent))
			targets.append(target)

	return data, targets