from helper import *

from proofwriter_classes import PWRuleSelectorInstance, PWFactSelectorInstance, PWReasonerInstance, PWQRuleInstance, PWQFactInstance, PWIterativeInstance
from basemodel import BaseModel
from proofwriter_ruleselector_model import ProofWriterRuleSelector
from proofwriter_factselector_model import ProofWriterFactSelector
from proofwriter_reasoner_model import ProofWriterReasoner
import csv


class ProofWriterInference(BaseModel):

	# number of times proof generation has failed
	count_error_graphs = 0

	# local accounting of proof accuracy
	local_proof_accuracy = []
	local_step = 0

	# maps to store results required for robustness evaluation
	equiv_ans_map, equiv_prf_map = dict(), dict()

	def __init__(
		self, ruleselector_ckpt, factselector_ckpt, reasoner_ckpt, ques_augmented, arch='',
		train_batch_size=1, eval_batch_size=1, accumulate_grad_batches=1, learning_rate=1e-5,
		max_epochs=1, optimizer='adamw', adam_epsilon=1e-8, weight_decay=0.0, lr_scheduler='fixed',
		warmup_updates=0.0, freeze_epochs=-1, gpus=1, evaluate_pw_iter=False, rule_stopsep=False,
		rule_stopcls=False, stop_priority='rule', root_dir=None, dumptext=False, eval_pararules=False
	):
		"""Store the hyperparameters and load the models needed for inference.

		All keyword hyperparameters are copied onto the namespace ``self.p``.
		Depending on ``evaluate_pw_iter`` one of two setups is loaded:

		* ``evaluate_pw_iter=True``: the ProofWriter Iterative baseline is evaluated
		  instead of our models. ``reasoner_ckpt`` is loaded as a T5 model and both
		  selector checkpoints must be the empty string.
		* otherwise: the rule selector, fact selector and reasoner are restored from
		  their respective checkpoints, and the stop settings of the loaded rule
		  selector are checked against ``rule_stopcls`` / ``rule_stopsep``.
		"""
		super().__init__(train_batch_size=train_batch_size, max_epochs=max_epochs, gpus=gpus)
		self.save_hyperparameters()

		self.p = types.SimpleNamespace(
			ques_augmented=ques_augmented,
			arch=arch,
			train_batch_size=train_batch_size,
			eval_batch_size=eval_batch_size,
			accumulate_grad_batches=accumulate_grad_batches,
			learning_rate=learning_rate,
			max_epochs=max_epochs,
			optimizer=optimizer,
			adam_epsilon=adam_epsilon,
			weight_decay=weight_decay,
			lr_scheduler=lr_scheduler,
			warmup_updates=warmup_updates,
			freeze_epochs=freeze_epochs,
			gpus=gpus,
			evaluate_pw_iter=evaluate_pw_iter,
			rule_stopsep=rule_stopsep,
			rule_stopcls=rule_stopcls,
			stop_priority=stop_priority,
			root_dir=root_dir,
			dumptext=dumptext,
			eval_pararules=eval_pararules,
		)

		# NOTE(review): set unconditionally and independent of self.p.dumptext;
		# no reader of this attribute is visible in this file.
		self.dumptext = True

		if self.p.evaluate_pw_iter:
			# baseline setup: a single generative T5 model, no selectors
			assert ruleselector_ckpt == ''
			assert factselector_ckpt == ''

			self.reasoner = T5ForConditionalGeneration.from_pretrained(reasoner_ckpt)
			self.reasoner_tokenizer = T5Tokenizer.from_pretrained("t5-large")
			self.generator_options = dict(
				min_length=1, max_length=128, num_beams=1, num_return_sequences=1, do_sample=False,
				top_k=50, top_p=1.0, temperature=1.0, length_penalty=1.0, repetition_penalty=1.0,
			)
		else:
			# modular setup: each component brings its own tokenizer
			self.rule_selector = ProofWriterRuleSelector().load_from_checkpoint(ruleselector_ckpt)
			self.rule_tokenizer = self.rule_selector.tokenizer

			self.fact_selector = ProofWriterFactSelector().load_from_checkpoint(factselector_ckpt)
			self.fact_tokenizer = self.fact_selector.tokenizer

			self.reasoner = ProofWriterReasoner().load_from_checkpoint(reasoner_ckpt)
			self.reasoner_tokenizer = self.reasoner.tokenizer

			# the stop flags given here must cover what the rule selector was trained with
			if self.rule_selector.p.stopcls:
				assert self.p.rule_stopcls
			if self.rule_selector.p.stopsep:
				assert self.p.rule_stopsep

	def forward(self, batch, ques_augmented=False):
		"""Run iterative inference over a batch and answer every question in it.

		Starting from the given facts, new conclusions are derived repeatedly and
		appended to the fact lists until nothing new is produced. Two modes exist,
		chosen by ``self.p.evaluate_pw_iter``:

		* baseline mode: a single generative model emits
		  ``$answer$ = ... ; $proof$ = ...`` strings that are parsed into a
		  conclusion plus the rule/facts it was derived from;
		* modular mode: rule selector, fact selector and reasoner are applied in
		  sequence for every selected rule.

		Every derivation is recorded in a per-instance ``proof_dict`` mapping a
		conclusion to ``([facts], rule)`` tuples; ``self.solver`` then answers the
		question from the final facts.

		Args:
			batch: dict providing the keys ``all_facts``, ``all_rules``, ``all_ques``,
				``all_proof``, ``all_answer``, ``all_equiv_id`` and ``all_qdep``.
				The fact lists in ``all_facts`` are appended to in place.
			ques_augmented: if true, the question-augmented tokenizers are used for
				the rule and fact selectors (modular mode only).

		Returns:
			A list with one ``(answer, proofs)`` tuple per batch instance, as
			returned by ``self.solver``.

		Side effects:
			When ``self.p.dumptext`` is set, appends to ``save.txt``, ``errors.txt``
			and ``pred.csv`` under ``self.p.root_dir``. Several failure paths drop
			into ``pdb``.
		"""
		facts         = batch['all_facts']
		rules         = batch['all_rules']
		ques          = batch['all_ques']
		batch_size    = len(facts)
		proof         = batch['all_proof']
		targets       = batch['all_answer']
		device        = self.reasoner.device
		count         = 0
		stop          = False
		output_dict   = [dict() for _ in range(batch_size)]
		proof_dict    = [ddict(list) for _ in range(batch_size)]
		gold_ans	  = batch['all_answer']
		equiv_id	  = batch['all_equiv_id']
		qdep		  = batch['all_qdep']
		save_errors_list = [[] for i in range(batch_size)]
		# prefill the proof_dict with single triples
		for idx in range(batch_size):
			for fact in facts[idx]:
				proof_dict[idx][fact].append(([fact], ''))		# value format: ([facts], rule)

		if self.p.evaluate_pw_iter:
			# 1 while an instance may still produce new conclusions, 0 once it is done or failed
			valid_mask = [1] * batch_size
			use_lowercase = False

			while not stop:
				input_ids  = PWIterativeInstance.tokenize_batch(self.reasoner_tokenizer, rules, facts, lowercase=use_lowercase)
				output_ids = self.reasoner.generate(torch.LongTensor(input_ids).to(device), **self.generator_options)
				output_str = self.reasoner_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
				for batch_idx in range(batch_size):
					if valid_mask[batch_idx]:
						try:
							# parse the output
							conclusion, proof_text = output_str[batch_idx].replace('$answer$ = ', '').split(' ; $proof$ = ')

							if use_lowercase:
								conclusion = conclusion.lower()
								proof_text = proof_text.lower()

							# check if conclusion is nothing or if the conclusion is already added to facts (case of repeated generation)
							if conclusion == 'nothing.' or conclusion == 'Nothing.' or conclusion in facts[batch_idx]:
								valid_mask[batch_idx] = 0

								if sum(valid_mask) == 0:
									stop = True

								continue

							# figure out the proof rule and facts: the proof text holds 1-based
							# "sentN" ids indexing into rules followed by facts
							proof_text   = proof_text.replace('# ', '')
							proof_text   = proof_text.replace('& ', '')
							prf_ids      = [int(x.replace('sent', '')) - 1 for x in proof_text.split(' ')]
							prf_rule_id  = prf_ids[0]
							prf_fact_ids = prf_ids[1:]
							context      = rules[batch_idx] + facts[batch_idx]
							prf_rule     = context[prf_rule_id]
							prf_facts    = [context[x] for x in prf_fact_ids]

							# update proof_dict
							proof_dict[batch_idx][conclusion].append((prf_facts, prf_rule))

							# add generated conclusion to previous facts
							facts[batch_idx].append(conclusion)

						except Exception as e:
							input_string  = self.reasoner_tokenizer.batch_decode(input_ids, skip_special_tokens=True)[batch_idx]
							output_string = output_str[batch_idx]
							print('Exception Cause: {}'.format(e.args[0]))
							print('Input string: ', input_string)
							print('Output string: ', output_string, '\n')

							# save errors (only the first one per instance); each field goes on its own line
							if len(save_errors_list[batch_idx]) == 0:
								save_errors_list[batch_idx].append(f'input_string = {input_string}' + '\n' + f'output_string = {output_string}' + '\n' + f'error = {e.args[0]}' + '\n' + '*'*100 + '\n')

							valid_mask[batch_idx] = 0
							if sum(valid_mask) == 0:
								stop = True

				if sum(valid_mask) == 0:
					stop = True

				# fail-safe against runaway generation: count the failure and actually
				# leave the loop (previously the message was printed but the loop went on)
				count += 1
				if count == 200:
					print('\nStop hit!\n')
					self.count_error_graphs += 1
					stop = True

		else:
			try:
				while not stop:
					# process data for rule selector and select rule
					if ques_augmented:
						input_ids, attn_mask, token_mask = PWQRuleInstance.tokenize_batch(self.rule_tokenizer, rules, facts, ques, self.p.rule_stopsep, self.p.rule_stopcls, not self.p.ques_augmented)
					else:
						input_ids, attn_mask, token_mask = PWRuleSelectorInstance.tokenize_batch(self.rule_tokenizer, rules, facts, self.p.rule_stopsep, self.p.rule_stopcls)
					rule_ids, rule_mask = self.rule_selector.predict(input_ids.to(device), token_mask.to(device), attn_mask.to(device), self.p.stop_priority)

					# loop break condition
					if rule_mask.sum().item() == 0:
						stop = True
						break

					for idx in range(rule_ids.shape[1]):
						selected_rules = [rules[x][y] for x,y in zip(range(batch_size), rule_ids[:, idx])]

						# this will be used to determine which inferences to keep and which ones to reject (batching trick)
						valid_mask     = rule_mask[:, idx]

						# process data for fact selector and select facts for the selected rule
						if ques_augmented:
							input_ids, attn_mask, token_mask = PWQFactInstance.tokenize_batch(self.fact_tokenizer, selected_rules, facts, ques, not self.p.ques_augmented)
						else:
							input_ids, attn_mask, token_mask = PWFactSelectorInstance.tokenize_batch(self.fact_tokenizer, selected_rules, facts)
						fact_ids, fact_mask = self.fact_selector.predict(input_ids.to(device), token_mask.to(device), attn_mask.to(device))

						# update valid_mask to account for cases when no facts are selected (batching trick)
						valid_mask     = valid_mask * fact_mask

						# if nothing is valid then stop
						if valid_mask.sum() == 0:
							stop = True
							break

						# -1 entries in fact_ids are skipped
						selected_facts = [[facts[x][y] for y in fact_ids[x] if y != -1] for x in range(batch_size)]

						# generate intermediate conclusion
						input_ids  = PWReasonerInstance.tokenize_batch(self.reasoner_tokenizer, selected_rules, selected_facts)
						conclusions = self.reasoner.predict_and_decode(torch.LongTensor(input_ids).to(device))

						new_conc = False	# This flag checks if any new intermediate conclusion was generated in this round for any of the instance in the batch
						for batch_idx in range(batch_size):
							if valid_mask[batch_idx]:
								# add proof to output_dict and increase count
								out_key   = ' '.join(selected_facts[batch_idx]) + '::' + selected_rules[batch_idx] + '::' + conclusions[batch_idx].lower()
								proof_key = conclusions[batch_idx].lower()

								if out_key not in output_dict[batch_idx]:
									new_conc = True
									output_dict[batch_idx][out_key] = 1
									facts[batch_idx].append(conclusions[batch_idx].lower())

									if len(selected_facts[batch_idx]) == 0:
										sys.stdout = sys.__stdout__; import pdb; pdb.set_trace()

									# update proof_dict
									proof_dict[batch_idx][proof_key].append((selected_facts[batch_idx], selected_rules[batch_idx]))
								else:
									output_dict[batch_idx][out_key] += 1

						# drop duplicate facts (this rebinds `facts` to new lists; order is not preserved)
						facts = [list(set(x)) for x in facts]

						# if there are no new conclusions in the batch and all selected rules have been tried, then stop
						if not new_conc and (idx + 1 == rule_ids.shape[1]):
							stop = True

						# fail-safe to check for infinite loops cases, if any
						count += 1
						if count == 1000:
							print('Stop hit!')
							sys.stdout = sys.__stdout__; import pdb; pdb.set_trace()

			except Exception as e:
				print('Exception Cause: {}'.format(e.args[0]))
				print(traceback.format_exc())
				sys.stdout = sys.__stdout__; import pdb; pdb.set_trace()

		# solve each instance in batch
		results = []
		for idx in range(batch_size):
			if self.p.eval_pararules:
				ans, prf = self.solver(facts[idx], ques[idx], dict(proof_dict[idx]), gold_proof=proof[idx], gold_ans=targets[idx], eval_pararules=self.p.eval_pararules)
			else:
				ans, prf = self.solver(facts[idx], ques[idx], dict(proof_dict[idx]))
			results.append((ans, prf))

			if self.p.dumptext:
				_, match_proof_element = self.match_proof([prf], [proof[idx]], np.array(gold_ans[idx]==ans))
				try:
					assert len(match_proof_element) == 1
				except AssertionError:
					# re-run the comparison with the debugger enabled
					self.match_proof([prf], [proof[idx]], np.array(gold_ans[idx])==np.array(ans), debug=True)
				match_proof = (match_proof_element[0] == True)

				pathlib.Path(self.p.root_dir).mkdir(exist_ok=True, parents=True)

				# one text record per instance, shared by save.txt and errors.txt
				record = '\n'.join([
					f'qdep = {str(qdep[idx])}',
					f'equiv_id = {equiv_id[idx]}',
					f'ans_solver = {str(ans)}',
					f'prf_solver = {str(prf)}',
					f'facts = {str(facts[idx])}',
					f'ques = {str(ques[idx])}',
					f'prf_dict = {pformat(dict(proof_dict[idx]))}',
					f'gold_ans = {str(gold_ans[idx])}',
					f'gold_proof = {pformat(str(proof[idx]))}',
					f'Match_ans = {str(gold_ans[idx]==ans)}',
				]) + '\n'
				separator = '*'*100 + '\n'

				with open(os.path.join(self.p.root_dir, 'save.txt'), "a") as f:
					f.write(record + separator)

				try:
					# errors.txt gets wrong answers and instances whose generation raised
					if (not gold_ans[idx]==ans) or len(save_errors_list[idx])>0:
						with open(os.path.join(self.p.root_dir, 'errors.txt'), "a") as f:
							if len(save_errors_list[idx])>0:
								f.write(record + f'{save_errors_list[idx][0]}' + '\n' + separator)
							else:
								f.write(record + separator)

				except Exception:
					import pdb; pdb.set_trace()

				with open(os.path.join(self.p.root_dir, 'pred.csv'), "a") as csvfile:
					writer = csv.writer(csvfile, delimiter=',')
					writer.writerow([str(qdep[idx]), equiv_id[idx], str(ans), str(prf), str(facts[idx]), str(ques[idx]), str(dict(proof_dict[idx])),\
						str(gold_ans[idx]), str(proof[idx]), str(gold_ans[idx]==ans), str(match_proof)])

		return results

	def solver(self, facts, ques, proof_dict, gold_proof=None, gold_ans=None, eval_pararules=False):
		if eval_pararules:
			# Find if there is a 0-depth proof (a gold proof with length 1 which is not 'None'). Then assume found. This is a hack to bypass matching question
			# to fact which is difficult for ParaRules dataset unless the sentence mapping is used.
			proof_lens = [len(x) for x in gold_proof]
			if 1 in proof_lens:
				check_idx = proof_lens.index(1)
				if gold_proof[check_idx][0] != 'None' and type(gold_proof[check_idx][0]) == str:
					return (gold_ans, [[gold_proof[check_idx][0]]])

		try:
			# check if question is already in facts
			if ques in facts:
				proofs = generate_proof(ques, proof_dict)
				return (1, proofs)
			else:
				# try to negate the ques and see if its present
				ques_neg = negate(ques)
				if ques_neg in facts:
					proofs = generate_proof(ques_neg, proof_dict)
					return (-1, proofs)
				else:
					# no proof exists.
					return (0, [['None']])
		except Exception as e:
			self.count_error_graphs += 1
			print('Graph error count: ', self.count_error_graphs)
			return (0, [['None']])

	def solver_debug(self, facts, ques, proof_dict, gold_proof, gold_ans, last_str=None, eval_pararules=False):
		"""Debugging variant of ``solver`` that stops in ``pdb`` on every mismatch.

		Follows the same decision procedure as ``solver`` (question in facts -> 1,
		negated question in facts -> -1, otherwise 0) and returns the same
		``(answer, proofs)`` tuple, but additionally prints the question and the
		gold proof and opens the debugger when

		* the ParaRules short-circuit applies,
		* a generated proof list contains ``None``,
		* none of the generated proofs matches a gold proof, or
		* an exception is raised.

		Proofs are compared by their set of distinct elements (order and
		multiplicity are ignored). ``last_str`` is accepted but not used here.
		"""
		print(ques)
		print(gold_proof)
		if eval_pararules:
			# Find if there is a 0-depth proof (a gold proof with length 1 which is not 'None'). Then assume found. This is a hack to bypass matching question
			# to fact which is difficult for ParaRules dataset unless the sentence mapping is used.
			proof_lens = [len(x) for x in gold_proof]
			if 1 in proof_lens:
				check_idx = proof_lens.index(1)
				if gold_proof[check_idx][0] != 'None' and type(gold_proof[check_idx][0]) == str:
					sys.stdout = sys.__stdout__; import pdb; pdb.set_trace()
					return (gold_ans, [[gold_proof[check_idx][0]]])
		try:
			counter_gold_proof = [Counter(x) for x in gold_proof]

			# collapse all counts to 1 so that only the distinct elements of each gold proof matter
			counter_gold_proof = [Counter({y:1 for y in x}) for x in counter_gold_proof]

			# check if question is already in facts
			if ques in facts:
				# print('\n*******************')
				proofs = generate_proof(ques, proof_dict)
				# print('*******************\n')
				if None in proofs:
					sys.stdout = sys.__stdout__; import pdb; pdb.set_trace()
				# look for a generated proof with the same distinct elements as some gold proof
				found = False
				proof = ''
				for prf in proofs:
					if Counter({y:1 for y in Counter(prf)}) in counter_gold_proof:
						found = True
						proof = prf
						break
				if not found:
					sys.stdout = sys.__stdout__; import pdb; pdb.set_trace()
				return (1, proofs)
			else:
				# try to negate the ques and see if its present
				ques_neg = negate(ques)
				if ques_neg in facts:
					# print('\n*******************')
					proofs = generate_proof(ques_neg, proof_dict)
					# print('*******************\n')
					if None in proofs:
						sys.stdout = sys.__stdout__; import pdb; pdb.set_trace()
					# same gold-proof comparison as in the positive case
					found = False
					proof = ''
					for prf in proofs:
						if Counter({y:1 for y in Counter(prf)}) in counter_gold_proof:
							found = True
							proof = prf
							break
					if not found:
						sys.stdout = sys.__stdout__; import pdb; pdb.set_trace()
					return (-1, proofs)
				else:
					# neither the question nor its negation is known: the gold proof should be ['None']
					found = False
					proof = ''
					for prf in [['None']]:
						if Counter({y:1 for y in Counter(prf)}) in counter_gold_proof:
							found = True
							proof = prf
							break
					if not found:
						sys.stdout = sys.__stdout__; import pdb; pdb.set_trace()

					# no proof exists. TODO: Change this for CWA case
					return (0, [['None']])
		except Exception as e:
			# count the failure, then stop in the debugger before returning the "no proof" result
			self.count_error_graphs += 1
			print('Graph error count: ', self.count_error_graphs)
			sys.stdout = sys.__stdout__; import pdb; pdb.set_trace()
			return (0, [['None']])
			# print('Exception Cause: {}'.format(e.args[0]))
			# import pdb; pdb.set_trace()

	def calc_acc(self, preds, targets):
		matched = np.array(preds) == np.array(targets)
		return 100 * np.mean(matched), matched

	def match_proof(self, all_proofs, all_gold_proofs, ans_match, debug = False):
		if debug:
			import pdb; pdb.set_trace()
		res = []
		for idx in range(len(all_proofs)):
			proofs = all_proofs[idx]
			gold_proofs = all_gold_proofs[idx]

			gold_proofs_counter = [Counter(x) for x in gold_proofs]
			gold_proofs_counter = [Counter({y:1 for y in x}) for x in gold_proofs_counter]

			found = False
			for prf in proofs:
				if Counter({y:1 for y in Counter(prf)}) in gold_proofs_counter:
					found = True
					break

			res.append(found)

		final_res = res * ans_match

		return 100 * np.mean(final_res), final_res

	def run_step(self, batch, split):
		"""Run inference on one batch and log answer / proof accuracy for ``split``.

		Besides logging, this appends the batch proof accuracy to
		``self.local_proof_accuracy`` (a running mean is printed every 20 steps)
		and stores the per-instance answer and proof matches under their
		``all_equiv_id`` keys for the later consistency computation.

		Returns a dict with the keys ``ans_acc``, ``prf_acc`` and ``loss``; the
		loss is a constant zero tensor.
		"""
		out         = self(batch, ques_augmented=self.p.ques_augmented)
		gold_answers = batch['all_answer']
		gold_proofs  = batch['all_proof']

		# question entailment accuracy
		pred_answers        = [pair[0] for pair in out]
		ans_acc, ans_match  = self.calc_acc(pred_answers, gold_answers)
		ans_acc             = torch.FloatTensor([ans_acc]).to(self.reasoner.device)

		# proof match accuracy (a proof only counts if the answer matched too)
		pred_proofs         = [pair[1] for pair in out]
		prf_acc, prf_match  = self.match_proof(pred_proofs, gold_proofs, ans_match)
		self.local_proof_accuracy.append(prf_acc)
		prf_acc = torch.FloatTensor([prf_acc]).to(self.reasoner.device)

		self.local_step += 1
		if self.local_step % 20 == 0:
			print(f'\nProof Accuracy: {np.mean(self.local_proof_accuracy)}\n')

		# remember the outcome of every instance for robustness evaluation
		self.equiv_ans_map.update(zip(batch['all_equiv_id'], ans_match))
		self.equiv_prf_map.update(zip(batch['all_equiv_id'], prf_match))

		# only non-train splits are synchronised across processes
		log_kwargs = {'prog_bar': True}
		if split != 'train':
			log_kwargs['sync_dist'] = True
		self.log(f'{split}_ans_acc_step', ans_acc, **log_kwargs)
		self.log(f'{split}_prf_acc_step', prf_acc, **log_kwargs)

		zero_loss = torch.FloatTensor([0]).to(self.reasoner.device)
		return {'ans_acc': ans_acc, 'prf_acc': prf_acc, 'loss': zero_loss}

	def aggregate_epoch(self, outputs, split):
		"""Log epoch-level accuracies and, for non-train splits, consistency metrics.

		The accuracies are the means of the per-step values in ``outputs``.
		Consistency groups the stored per-instance outcomes by the part of the
		equivalence id before the first ``'_'``; a group is consistent when all of
		its members share the same outcome. Three variants are computed: answer
		match, proof match, and the (answer, proof) pair.
		"""
		ans_acc = torch.stack([step['ans_acc'] for step in outputs]).mean()
		prf_acc = torch.stack([step['prf_acc'] for step in outputs]).mean()

		def row_of(equiv_key):
			# group id = everything before the first underscore
			return equiv_key.split('_')[0]

		def consistency(groups):
			# percentage of groups whose members all have one and the same outcome
			inconsistent = sum(1 for outcomes in groups.values() if len(outcomes) > 1)
			return 100 * (1 - inconsistent / len(groups))

		# for robustness dataset, collect the distinct outcomes seen per group
		ans_groups, prf_groups, joint_groups = ddict(set), ddict(set), ddict(set)
		for equiv_key, matched in self.equiv_ans_map.items():
			ans_groups[row_of(equiv_key)].add(matched)
		for equiv_key, matched in self.equiv_prf_map.items():
			prf_groups[row_of(equiv_key)].add(matched)

		assert len(self.equiv_ans_map) == len(self.equiv_prf_map)
		for equiv_key in self.equiv_ans_map:
			outcome = (self.equiv_ans_map[equiv_key], self.equiv_prf_map[equiv_key])
			joint_groups[row_of(equiv_key)].add(outcome)

		ans_consistency     = consistency(ans_groups)
		prf_consistency     = consistency(prf_groups)
		ans_prf_consistency = consistency(joint_groups)

		# the train split logs the two accuracies only, without cross-process sync
		sync = {} if split == 'train' else {'sync_dist': True}
		self.log(f'{split}_ans_acc_epoch', ans_acc.item(), **sync)
		self.log(f'{split}_prf_acc_epoch', prf_acc.item(), **sync)
		if split != 'train':
			self.log(f'{split}_ans_consistency', ans_consistency, sync_dist=True)
			self.log(f'{split}_prf_consistency', prf_consistency, sync_dist=True)
			self.log(f'{split}_ans_prf_consistency', ans_prf_consistency, sync_dist=True)
			self.log(f'Graph Errors: ', self.count_error_graphs, sync_dist=True)

	def configure_optimizers(self):
		"""Build the optimizer (and optionally the LR scheduler) for all three models.

		For each of rule selector, fact selector and reasoner two parameter groups
		are created: one using ``self.p.weight_decay`` and one without weight
		decay for parameters whose name contains ``bias`` or ``LayerNorm.weight``.

		Returns ``[optimizer]`` for the ``'fixed'`` schedule, or
		``([optimizer], [scheduler_config])`` for ``'linear_with_warmup'``.

		Raises:
			NotImplementedError: for any other optimizer or scheduler name.
		"""
		no_decay = ('bias', 'LayerNorm.weight')

		def is_exempt(param_name):
			return any(tag in param_name for tag in no_decay)

		param_groups = []
		for module in (self.rule_selector, self.fact_selector, self.reasoner):
			with_decay, without_decay = [], []
			for param_name, param in module.named_parameters():
				(without_decay if is_exempt(param_name) else with_decay).append(param)
			param_groups.append({'params': with_decay, 'weight_decay': self.p.weight_decay})
			param_groups.append({'params': without_decay, 'weight_decay': 0.0})

		if self.p.optimizer != 'adamw':
			raise NotImplementedError
		optimizer = AdamW(param_groups, lr=self.p.learning_rate)

		if self.p.lr_scheduler == 'fixed':
			return [optimizer]
		if self.p.lr_scheduler != 'linear_with_warmup':
			raise NotImplementedError

		# warmup_updates > 1 is an absolute step count, otherwise a fraction of all steps
		if self.p.warmup_updates > 1.0:
			warmup_steps = int(self.p.warmup_updates)
		else:
			warmup_steps = int(self.total_steps * self.p.warmup_updates)
		print(f'\nTotal steps: {self.total_steps} with warmup steps: {warmup_steps}\n')

		linear_schedule = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=warmup_steps, num_training_steps=self.total_steps)
		scheduler = {'scheduler': linear_schedule, 'interval': 'step', 'frequency': 1}

		return [optimizer], [scheduler]


class InfiniteRecursionError(OverflowError):
    '''Signals a cycle in the proof graph that would make proof generation recurse forever.

    Raised by generate_proof when a fact appears among the facts of one of its
    own rule-based derivations.
    '''

def get_verb(sent):
	"""Return the known third-person verb that occurs in ``sent``.

	A verb only counts when it appears as a whole word surrounded by spaces.
	If several known verbs occur, the one listed first in ``known_verbs`` wins.

	Args:
		sent: the sentence to inspect.

	Returns:
		The matching verb, e.g. ``'sees'``.

	Raises:
		ValueError: if none of the known verbs occurs in ``sent``. (Previously
			this case opened an interactive debugger and then returned ``None``.)
	"""
	# order matters: it decides which verb is reported when more than one is present
	known_verbs = (
		'visits', 'sees', 'likes', 'eats', 'chases', 'needs', 'wants', 'forgets',
		'humiliates', 'treats', 'serves', 'abandons', 'hates', 'loves', 'kills',
		'doubts', 'runs',
	)
	for verb in known_verbs:
		if f' {verb} ' in sent:
			return verb

	raise ValueError(f'No known verb found in sentence: {sent!r}')

def negate(sent):
	"""Return the negation of ``sent`` using simple whole-word string rewriting.

	Four patterns are handled, checked in this order:

	* ``... is not ...``     -> ``... is ...``
	* ``... is ...``         -> ``... is not ...``
	* ``... does not X ...`` -> ``... Xs ...``
	* ``... Xs ...``         -> ``... does not X ...`` (``Xs`` found via ``get_verb``)

	All replacements are anchored on surrounding spaces so that an ``is`` inside
	another word (e.g. a name ending in "is") is left untouched.

	Raises:
		ValueError: if ``does not`` is not followed by a word.
		Whatever ``get_verb`` raises when no known verb is present.
	"""
	if ' is not ' in sent:
		# is not --> is
		return sent.replace(' is not ', ' is ')

	if ' is ' in sent:
		# is --> is not
		return sent.replace(' is ', ' is not ')

	if ' does not ' in sent:
		# does not visit --> visits
		# find the word that directly follows "does not"
		all_words = sent.split()
		for pos in range(len(all_words) - 2):
			if all_words[pos] == 'does' and all_words[pos + 1] == 'not':
				next_word = all_words[pos + 2]
				return sent.replace(f'does not {next_word}', next_word + 's')
		raise ValueError(f'No word follows "does not" in sentence: {sent!r}')

	# visits --> does not visit
	verb = get_verb(sent)
	new_verb = verb[:-1] # removes the s in the last place TODO: Fix this later using https://pypi.org/project/inflect/
	return sent.replace(f' {verb} ', f' does not {new_verb} ')

def generate_proof(last_fact, proof_dict, last_rule=None):
	"""Enumerate the proofs of ``last_fact`` recorded in ``proof_dict``.

	``proof_dict`` maps a fact to a list of ``(facts, rule)`` derivations; a
	derivation with ``rule == ''`` marks a given fact. A proof is a flat list:
	given facts appear as ``(fact, rule)`` pairs (or as the bare fact when there
	is no consuming rule), and chained rule applications as ``(rule, last_rule)``
	pairs.

	Args:
		last_fact: the fact to prove.
		proof_dict: derivations per fact, see above.
		last_rule: the rule that consumes ``last_fact`` one level up, or ``None``
			at the top level.

	Returns:
		A list of proofs. As soon as a given-fact derivation is encountered, the
		single leaf proof is returned immediately.

	Raises:
		InfiniteRecursionError: if ``last_fact`` is among the facts of one of
			its own rule-based derivations.
	"""
	collected = []
	for premises, rule in proof_dict[last_fact]:
		if rule == '':
			# a given fact: it is its own (only) premise by construction
			assert len(premises) == 1
			leaf = premises[0]
			if last_rule is None:
				return [[leaf]]
			return [[(leaf, last_rule)]]

		# hack to handle an infinite recursion issue - the fact would be derived from itself
		if last_fact in premises:
			raise InfiniteRecursionError('Cycle in proof graph!')

		if len(premises) == 1:
			candidates = generate_proof(premises[0], proof_dict, rule)
		elif len(premises) >= 2:
			# every combination of one proof per premise, flattened into a single proof
			per_premise = [generate_proof(premise, proof_dict, rule) for premise in premises]
			candidates = [
				list(itertools.chain.from_iterable(combo))
				for combo in itertools.product(*per_premise)
			]
		else:
			# a rule without premises contributes nothing
			continue

		# link this rule to the rule that consumes its conclusion
		if last_rule is not None:
			for candidate in candidates:
				candidate.append((rule, last_rule))
		collected.extend(candidates)

	return collected
