from helper import *
from pw_helper import *


class PWRuleSelectorInstance:
	'''
	A rule-selection example: a context paragraph plus one 0/1 label per rule.

	instance_para is the facts paragraph followed by every rule, each rule
	preceded by the tokenizer's separator token (built in from_json).
	'''

	def __init__(self, instance_para, labels):
		super().__init__()
		self.instance_para = instance_para
		# one entry per rule: 1 if the rule is selected, else 0
		self.labels        = labels

	@classmethod
	def get_rule_from_inference(cls, proofs):
		'''
		Pick the rule of the shortest one-step proof of a single conclusion.

		input:   proofs = [(facts, fact_ids, rule, rule_id), ...], every proof of one conclusion
		returns: (rule_id, rule) of the proof using the fewest facts (the earliest one on ties);
		         ('', '') when no proof qualifies, e.g. when proofs is empty
		'''
		best_id, best_text = '', ''
		best_len = 1000 # stands in for "infinity"
		for facts, fact_ids, rule, rule_id in proofs:
			assert len(facts) == len(fact_ids)
			if len(facts) < best_len:
				best_id, best_text, best_len = rule_id, rule, len(facts)
		return best_id, best_text

	@classmethod
	def select_rules(cls, json_dict):
		'''
		Return the distinct rules used by the one-step inferences of json_dict.

		returns: list like [(rule1, rule1text), (rule3, rule3text), ...], in set-iteration order
		'''
		# each inference is (proofs, conclusion); only the proofs are needed here
		inferences = parse_all_inferences(json_dict, return_text=True, pwq = False)
		chosen = {cls.get_rule_from_inference(proofs) for proofs, _conclusion in inferences}
		return list(chosen)

	@classmethod
	def from_json(cls, json_dict, tokenizer):
		'''
		Build one rule-selection instance from json_dict.

		The paragraph is "facts_para [SEP] rule1 [SEP] rule2 ...", and a rule is labelled 1
		when it appears in select_rules(json_dict).
		'''
		facts_para, _, _         	= get_facts(json_dict)
		_, _, _, _, all_rule_list   = get_rules(json_dict)
		selected_rule_list = cls.select_rules(json_dict)

		pieces = [facts_para]
		labels = []
		for ruletuple in all_rule_list:
			pieces.append(ruletuple[1])
			labels.append(1 if ruletuple in selected_rule_list else 0)
		instance_para = tokenizer.sep_token.join(pieces)

		return PWRuleSelectorInstance(instance_para, labels)

	def tokenize_ptlm(self, tokenizer):
		'''
		Encode as [CLS] factspara [SEP] rule1text [SEP] rule2text [SEP] ... [SEP].

		returns: (input_ids, per-token float labels as a list, per-token 0/1 mask of rule separators)
		'''
		tokens = tokenizer.tokenize(tokenizer.cls_token + self.instance_para + tokenizer.sep_token)
		input_ids = tokenizer.convert_tokens_to_ids(tokens)
		# mark every separator, then clear the final one: it closes the sequence rather than introducing a rule
		token_mask = [int(tok == tokenizer.sep_token) for tok in tokens]
		token_mask[-1] = 0
		sep_positions = [pos for pos, flag in enumerate(token_mask) if flag == 1]
		assert len(self.labels) == len(sep_positions)
		token_labels = np.zeros(len(token_mask))
		token_labels[sep_positions] = self.labels
		return input_ids, token_labels.tolist(), token_mask

	def tokenize(self, tokenizer, arch, split):
		'''Dispatch on model architecture; only the RoBERTa variants are supported.'''
		if arch not in ('roberta_large_race', 'roberta_large'):
			raise NotImplementedError
		return self.tokenize_ptlm(tokenizer)

	@classmethod
	def tokenize_instance(cls, tokenizer, rules, facts):
		'''
		Encode one lowercased (facts, rules) pair for inference.

		returns: (input_ids, 0/1 mask marking every separator id except the last position)
		'''
		lowered_rules = [r.lower() for r in rules]
		lowered_facts = [f.lower() for f in facts]
		sep = tokenizer.sep_token
		text = tokenizer.cls_token + ' '.join(lowered_facts) + sep + sep.join(lowered_rules) + sep
		input_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text, truncation=True))

		token_mask = [1 if tok_id == tokenizer.sep_token_id else 0 for tok_id in input_ids]
		token_mask[-1] = 0
		return input_ids, token_mask

	@classmethod
	def tokenize_batch(cls, tokenizer, batched_rules, batched_facts, stopsep, stopcls):
		'''
		Encode a padded batch of lowercased "facts [SEP] rule1 [SEP] rule2 ..." strings.

		stopsep: when False, separator positions the tokenizer flags as special tokens are unmarked.
		stopcls: when True, position 0 of every row is marked.
		returns: (input_ids, attention_mask, token_mask) as returned/derived from the tokenizer output
		'''
		sep = tokenizer.sep_token
		texts = [
			' '.join(f.lower() for f in facts) + sep + sep.join(r.lower() for r in rules)
			for facts, rules in zip(batched_facts, batched_rules)
		]
		enc = tokenizer(texts, add_special_tokens=True, padding=True, truncation=True, return_tensors='pt', return_special_tokens_mask=True)
		input_ids  = enc['input_ids']
		attn_mask  = enc['attention_mask']
		token_mask = (input_ids == tokenizer.sep_token_id)
		special_seps = enc['special_tokens_mask'] * token_mask

		if not stopsep:
			token_mask[special_seps.bool()] = False
		if stopcls:
			token_mask[:, 0] = 1

		return input_ids, attn_mask, token_mask


class PWFactSelectorInstance:
	'''
	A fact-selection example: a rule followed by all facts, plus one 0/1 label per fact.

	instance_para is "ruletext [SEP] fact1 [SEP] fact2 ..." (built by get_instance_para_labels).
	'''

	def __init__(self, instance_para, labels):
		super().__init__()
		self.instance_para = instance_para
		self.labels        = labels # list of length (number of facts), 1 if fact is selected else 0

	@classmethod
	def get_rule_fact_from_inference(cls, proofs):
		'''
		Select the rule and facts of the shortest one-step proof of a single conclusion.

		input:   proofs = [(facts, fact_ids, rule, rule_id), ...], every proof of one conclusion
		returns: (rule_id, rule_text, fact_ids, fact_texts) of the proof using the fewest facts
		         (the earliest one on ties); ('', '', [], []) when no proof qualifies, e.g. when
		         proofs is empty (previously this raised UnboundLocalError)
		'''
		min_len = 1000 # stands in for "infinity"
		rulenum, ruletext = '', ''
		factnum_list, facttext_list = [], []

		for facts, fact_ids, rule, rule_id in proofs:
			assert len(facts) == len(fact_ids)
			if len(facts) < min_len: # length of a proof ~ number of facts it uses
				rulenum, ruletext = rule_id, rule
				factnum_list, facttext_list = fact_ids, facts
				min_len = len(facts)

		return rulenum, ruletext, factnum_list, facttext_list

	@classmethod
	def get_rule_fact_dict(cls, json_dict):
		'''
		Map each rule used by a one-step inference to the fact-id lists it was used with.

		returns: dict like {rule2: [[fact1, fact2], [fact3]], rule3: [[fact4], [fact2, fact1]]},
		         one fact-id list per inference whose shortest proof uses that rule.
		         Inferences without any proof are skipped (there is no rule to attribute).
		'''
		rulenum_factnum_dict = {}
		# inferences: [([(facts, fact_ids, rule, rule_id), ...], conclusion), ...]
		inferences = parse_all_inferences(json_dict, return_text=True, pwq = False)
		for proofs, _conclusion in inferences:
			if not proofs:
				continue
			rulenum, _, factnum_list, _ = cls.get_rule_fact_from_inference(proofs)
			rulenum_factnum_dict.setdefault(rulenum, []).append(factnum_list)

		return rulenum_factnum_dict

	@classmethod
	def get_instance_para_labels(cls, factnumlist_selected, ruletext, allfacts_list, allfactsnum_list, tokenizer):
		'''
		Build "ruletext [SEP] fact1 [SEP] fact2 ..." and the matching per-fact 0/1 labels.

		allfacts_list[i] is the text of the fact whose id is allfactsnum_list[i]; a fact is
		labelled 1 when its id is in factnumlist_selected.
		'''
		instance_para = ruletext
		labels = []
		for i, factnum in enumerate(allfactsnum_list):
			instance_para = instance_para + tokenizer.sep_token + allfacts_list[i]
			labels.append(1 if factnum in factnumlist_selected else 0)
		return instance_para, labels

	@classmethod
	def from_json(cls, json_dict_1, json_dict_2, tokenizer):
		'''
		Create one fact-selection instance per rule used in json_dict_1's inferences,
		i.e. which facts should be selected for a given rule.

		When a rule was used for several inferences, the fact list chosen is the first one
		that is not listed for that rule in json_dict_2 (the next json line); if every fact
		list is still present there, one is drawn with random.randint.
		'''
		_, facts_list1, factsnum_list1 = get_facts(json_dict_1)
		_, _, _, all_rules_dict1, _ = get_rules(json_dict_1)
		rulenum_factnum_dict1 = cls.get_rule_fact_dict(json_dict_1)
		rulenum_factnum_dict2 = cls.get_rule_fact_dict(json_dict_2)

		instances = [] # one json line yields several data points
		for rulenum, factlists in rulenum_factnum_dict1.items():
			ruletext = all_rules_dict1[rulenum]
			# get_rule_fact_dict only creates a key together with its first fact list
			assert len(factlists) > 0

			if len(factlists) == 1:
				# the rule is used by a single inference, so its fact list is unambiguous
				selected = factlists[0]
			else:
				# the rule may not appear at all in the next json line, hence .get()
				next_factlists = rulenum_factnum_dict2.get(rulenum, [])
				for factlist in factlists:
					if factlist not in next_factlists:
						# present in this json line but not the next one
						selected = factlist
						break
				else:
					selected = factlists[random.randint(0, len(factlists) - 1)]

			instance_para, labels = cls.get_instance_para_labels(selected, ruletext, facts_list1, factsnum_list1, tokenizer)
			instances.append(PWFactSelectorInstance(instance_para, labels))

		return instances

	def tokenize_ptlm(self, tokenizer):
		'''
		Encode as [CLS] rule [SEP] fact1text [SEP] fact2text [SEP] ... [SEP].

		returns: (input_ids, per-token float labels as a list, per-token 0/1 mask of fact separators)
		'''
		input_tokens = tokenizer.cls_token + self.instance_para + tokenizer.sep_token
		input_tokens_tokenized = tokenizer.tokenize(input_tokens)
		input_ids    = tokenizer.convert_tokens_to_ids(input_tokens_tokenized)
		token_mask = [1 if token == tokenizer.sep_token else 0 for token in input_tokens_tokenized] # 1 at every sep token
		token_mask[-1] = 0 # the last sep token closes the sequence and does not precede any fact
		sep_token_indices = [i for i in range(len(token_mask)) if token_mask[i] == 1]
		token_labels = np.zeros(len(token_mask))
		assert len(self.labels) == len(sep_token_indices)
		token_labels[sep_token_indices] = self.labels

		return input_ids, token_labels.tolist(), token_mask

	def tokenize(self, tokenizer, arch, split):
		'''Dispatch on model architecture; only the RoBERTa variants are supported.'''
		if arch == 'roberta_large_race' or arch == 'roberta_large':
			return self.tokenize_ptlm(tokenizer)
		else:
			raise NotImplementedError

	@classmethod
	def tokenize_instance(cls, tokenizer, rule, facts):
		'''
		Encode one lowercased (rule, facts) pair for inference.

		returns: (input_ids, 0/1 mask marking every separator id except the last position)
		'''
		facts         = list(map(str.lower, facts))
		input_tokens  = tokenizer.cls_token + rule.lower() + tokenizer.sep_token + tokenizer.sep_token.join(facts) + tokenizer.sep_token
		input_ids     = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(input_tokens, truncation=True))

		token_mask     = [1 if token == tokenizer.sep_token_id else 0 for token in input_ids]
		token_mask[-1] = 0

		return input_ids, token_mask

	@classmethod
	def tokenize_batch(cls, tokenizer, batched_rules, batched_facts):
		'''
		Encode a padded batch of lowercased "rule [SEP] fact1 [SEP] fact2 ..." strings.

		returns: (input_ids, attention_mask, token_mask); separator positions the tokenizer
		         flags as special tokens are unmarked in token_mask.
		'''
		new_rules        = [rule.lower() for rule in batched_rules]
		new_facts        = [map(str.lower, facts) for facts in batched_facts]
		input_tokens     = [rule + tokenizer.sep_token + tokenizer.sep_token.join(facts) for facts,rule in zip(new_facts, new_rules)]
		tokenized        = tokenizer(input_tokens, add_special_tokens=True, padding=True, truncation=True, return_tensors='pt', return_special_tokens_mask=True)
		input_ids        = tokenized['input_ids']
		attn_mask        = tokenized['attention_mask']
		token_mask       = (input_ids == tokenizer.sep_token_id)
		sep_mask         = tokenized['special_tokens_mask'] * token_mask
		token_mask[sep_mask.bool()] = False

		return input_ids, attn_mask, token_mask


class PWReasonerInstance:
	'''
	A reasoner example: selected rule and facts as input, the inferred conclusion as target.
	'''

	def __init__(self, rule, facts, conclusion):
		self.rule       = rule        # rule text, selected by the rule selector
		self.facts      = facts       # fact texts, selected by the fact selector
		self.conclusion = conclusion  # target sentence

	@classmethod
	def from_json(cls, json_dict):
		'''
		Return one input/output pair per proof of every one-step inference in json_dict:
		input = that proof's facts + rule, output = the inferred conclusion.
		'''
		return [
			PWReasonerInstance(rule, facts, conclusion)
			for proofs, conclusion in parse_all_inferences(json_dict, return_text=True, pwq = False)
			for facts, _, rule, _ in proofs
		]

	def tokenize_ptlm(self, tokenizer):
		'''
		Encode for a seq2seq model.

		input format:  facts rule </s>   (facts rendered by format_facts)
		output format: <pad> conclusion </s>
		returns: (input_ids, output_ids)
		'''
		def _encode(text):
			return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))

		source_ids = _encode(format_facts(self.facts) + self.rule + tokenizer.eos_token)
		target_ids = _encode(tokenizer.pad_token + self.conclusion + tokenizer.eos_token)
		return source_ids, target_ids

	def tokenize(self, tokenizer, arch, split):
		'''Dispatch on model architecture; only the T5 variants are supported.'''
		if arch in ('t5_base', 't5_large'):
			return self.tokenize_ptlm(tokenizer)
		raise NotImplementedError

	@classmethod
	def tokenize_instance(cls, tokenizer, rule, facts):
		'''
		Encode one (rule, facts) pair for inference; returns the list of input ids.

		Facts are space-joined and the rule is appended with no separating space; no lowercasing
		is applied here. NOTE(review): the missing space differs from how tokenize_batch/format_facts
		may render inputs — confirm this is intended before changing it.
		'''
		text = ' '.join(facts) + rule + tokenizer.eos_token
		return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text, truncation=True))

	@classmethod
	def tokenize_batch(cls, tokenizer, batched_rules, batched_facts):
		'''
		Encode a padded batch of lowercased "facts rule</s>" strings (no extra special tokens).

		returns: the tokenizer's 'input_ids' tensor
		'''
		texts = [
			' '.join(fact.lower() for fact in facts) + rule.lower() + tokenizer.eos_token
			for facts, rule in zip(batched_facts, batched_rules)
		]
		encoded = tokenizer(texts, add_special_tokens=False, padding=True, truncation=True, return_tensors='pt')
		return encoded['input_ids']


class PWIterativeInstance:
	'''
	One step of iterative (staged) single-hop inference generation.

	inputs is every rule followed by every fact. from_json stores one conclusion string with
	one (facts, rule) pair per proof of it; from_json_multiple stores parallel lists with one
	conclusion per proof. 'Nothing.' (from_json) / ['Nothing.'] (from_json_multiple) marks the
	last row of a staged file, where the target is "Nothing." with proof "None".
	'''

	# prompt prefix shared by every encoder input (tokenize_batch may lowercase it)
	_PROMPT = '$answer$ ; $proof$ ; $question$ = What is one single-hop inference? ; $context$ ='

	def __init__(self, inputs, conclusion, prf_facts, prf_rule, rules=None, facts=None):
		super().__init__()
		self.inputs     = inputs
		self.all_rules  = rules
		self.all_facts  = facts
		self.conclusion = conclusion
		self.prf_facts  = prf_facts
		self.prf_rule   = prf_rule

	@classmethod
	def from_json(cls, json_dict1, json_dict2, tokenizer):
		'''
		Build the instance for one row of a staged file.

		json_dict2 is the next row (None for the last row). The target conclusion is the single
		inference present in json_dict1 but not in json_dict2, with every proof of it as
		(facts, rule) pairs. tokenizer is unused.
		'''
		all_facts   = get_facts(json_dict1, lowercase=False)[1]
		all_rules   = get_rules(json_dict1, lowercase=False)[1]
		inferences1 = parse_all_inferences(json_dict1, return_text=True, pwq=True, lowercase=False, take_first_proof=False)
		if json_dict2 is None:
			# last row of staged file
			conclusion = 'Nothing.'
			return PWIterativeInstance(all_rules + all_facts, conclusion, [['None']], ['None'], rules=all_rules, facts=all_facts)
		else:
			# intermediate row of staged file
			inferences2 = parse_all_inferences(json_dict2, return_text=True, pwq=True, lowercase=False)
			unique_key  = set(inferences1.keys()) - set(inferences2.keys())
			assert len(unique_key) == 1
			conclusion  = unique_key.pop()
			proofs      = inferences1[conclusion]
			return PWIterativeInstance(all_rules + all_facts, conclusion, [x[0] for x in proofs], [x[2] for x in proofs], rules=all_rules, facts=all_facts)

	@classmethod
	def from_json_multiple(cls, json_dict1, json_dict2, tokenizer):
		'''
		Build an instance whose targets are all proofs of all inferences in json_dict1.

		conclusion/prf_facts/prf_rule become parallel lists with one entry per proof.
		json_dict2 is only checked against None (last row); tokenizer is unused.
		'''
		all_facts   = get_facts(json_dict1, lowercase=False)[1]
		all_rules   = get_rules(json_dict1, lowercase=False)[1]
		inferences1 = parse_all_inferences(json_dict1, return_text=True, pwq=True, lowercase=False, take_first_proof=False)

		if json_dict2 is None:
			# last row of staged file
			conclusion = ['Nothing.']
			return PWIterativeInstance(all_rules + all_facts, conclusion, [['None']], ['None'], rules=all_rules, facts=all_facts)

		all_concs, proof_rules, proof_facts = [], [], []
		for conc, proofs in inferences1.items():
			for proof in proofs:
				all_concs.append(conc)
				proof_rules.append(proof[2])
				proof_facts.append(proof[0])

		return PWIterativeInstance(all_rules + all_facts, all_concs, proof_facts, proof_rules, rules=all_rules, facts=all_facts)

	def join_facts(self, prf_facts, sent_id_map):
		'''Render proof facts as "sentK" for one fact, or "& sentA sentB ..." for several.'''
		if len(prf_facts) == 1:
			return sent_id_map[prf_facts[0]]
		else:
			return '& ' + " ".join([sent_id_map[x] for x in prf_facts])

	def _build_context(self, keep=None):
		'''
		Number the kept inputs sent1, sent2, ... in order and render them as the context.

		keep: optional predicate on an input sentence; None keeps every input.
		returns: (context " sent1: ... sent2: ...", {sentence: 'sentK'}, list of kept sentences)
		'''
		ctx, sent_id_map, kept = '', {}, []
		for inp in self.inputs:
			if keep is not None and not keep(inp):
				continue
			sent_id = f'sent{len(kept) + 1}'
			sent_id_map[inp] = sent_id
			ctx += f' {sent_id}: {inp}'
			kept.append(inp)
		return ctx, sent_id_map, kept

	def _proof_target(self, tokenizer, conclusion, rule, facts, sent_id_map):
		'''Render "<pad>$answer$ = conc ; $proof$ = # <rule id> <fact ids></s>".'''
		return tokenizer.pad_token + \
			f'$answer$ = {conclusion} ; $proof$ = # {sent_id_map[rule]} {self.join_facts(facts, sent_id_map)}' + \
			tokenizer.eos_token

	@staticmethod
	def _nothing_target(tokenizer):
		'''Target used when no further inference exists.'''
		return tokenizer.pad_token + '$answer$ = Nothing. ; $proof$ = None' + tokenizer.eos_token

	@staticmethod
	def _to_ids(tokenizer, text):
		'''Tokenize and cut to tokenizer.model_max_length (the cut may drop the trailing eos).'''
		return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))[:tokenizer.model_max_length]

	def tokenize_ptlm_multiple(self, tokenizer, delimiter='@@'):
		'''
		Encode the full context with every proof target joined by `delimiter`.

		input format:  $answer$ ; $proof$ ; $question$ = What is one single-hop inference? ; $context$ = sent1: rule1. sent2: rule2. sent3: fact1. ...
		output format: <pad>$answer$ = conc1. ; $proof$ = # sent2 sent12</s>@@<pad>...
		               (to make sense of the delimiter, check proofwriter_iterative_model.py)
		returns: (input_ids, output_ids)
		'''
		ctx, sent_id_map, _ = self._build_context()
		input_str = self._PROMPT + ctx + tokenizer.eos_token

		if self.conclusion == ['Nothing.']:
			output_str = self._nothing_target(tokenizer)
		else:
			output_str = delimiter.join([
				self._proof_target(tokenizer, conc, self.prf_rule[conc_idx], self.prf_facts[conc_idx], sent_id_map)
				for conc_idx, conc in enumerate(self.conclusion)
			])

		return self._to_ids(tokenizer, input_str), self._to_ids(tokenizer, output_str)

	def tokenize_ptlm(self, tokenizer, delete=None):
		'''
		Encode one (input, output) pair per proof of self.conclusion.

		input format:  $answer$ ; $proof$ ; $question$ = What is one single-hop inference? ; $context$ = sent1: rule1. sent2: rule2. ...
		output format: <pad>$answer$ = conc1. ; $proof$ = # sent2 sent12</s>
		delete='del' keeps only the proof's rule and facts in the context, except when the
		conclusion is 'Nothing.' (all inputs are kept then).
		returns: (list of input_ids, list of output_ids), one entry per proof
		'''
		if delete is not None:
			assert delete == 'del'
		prune = delete == 'del' and self.conclusion != 'Nothing.'

		input_list, output_list = [], []
		for proof_rule, proof_facts in zip(self.prf_rule, self.prf_facts):
			# the predicate is consumed immediately, so binding the loop variables is safe
			keep = (lambda inp: inp == proof_rule or inp in proof_facts) if prune else None
			ctx, sent_id_map, _ = self._build_context(keep)
			input_list.append(self._PROMPT + ctx + tokenizer.eos_token)

			if self.conclusion == 'Nothing.':
				output_list.append(self._nothing_target(tokenizer))
			else:
				output_list.append(self._proof_target(tokenizer, self.conclusion, proof_rule, proof_facts, sent_id_map))

		input_ids  = [self._to_ids(tokenizer, input_s) for input_s in input_list]
		output_ids = [self._to_ids(tokenizer, output_s) for output_s in output_list]
		return input_ids, output_ids

	def tokenize_ptlm_del(self, tokenizer, delete=None, delimiter='@@'):
		'''
		Encode one (input, output) pair per proof with part of the context deleted.

		delete='delrule' --> keep only the proof's rule plus all facts
		delete='delfact' --> keep only the proof's facts plus all rules
		Nothing is deleted when that proof's conclusion is 'Nothing.'. The output joins, with
		`delimiter`, the targets of every proof whose rule and facts all survived the deletion
		(to make sense of the delimiter, check proofwriter_iterative_model.py).
		returns: (list of input_ids, list of output_ids), one entry per proof
		'''
		if delete is not None:
			assert delete == 'delrule' or delete == 'delfact'

		input_list, output_list = [], []
		for proof_rule, proof_facts, conclusion in zip(self.prf_rule, self.prf_facts, self.conclusion):
			keep = None
			if conclusion != 'Nothing.':
				if delete == 'delrule':
					keep = lambda inp: inp == proof_rule or inp in self.all_facts
				elif delete == 'delfact':
					keep = lambda inp: inp in proof_facts or inp in self.all_rules
			ctx, sent_id_map, added_sents = self._build_context(keep)
			input_list.append(self._PROMPT + ctx + tokenizer.eos_token)

			if conclusion == 'Nothing.':
				output_str = self._nothing_target(tokenizer)
			else:
				out_list = [
					self._proof_target(tokenizer, conc, rule, facts, sent_id_map)
					for conc, rule, facts in zip(self.conclusion, self.prf_rule, self.prf_facts)
					if rule in added_sents and all(f in added_sents for f in facts)
				]
				output_str = delimiter.join(out_list)

			output_list.append(output_str)

		input_ids  = [self._to_ids(tokenizer, input_s) for input_s in input_list]
		output_ids = [self._to_ids(tokenizer, output_s) for output_s in output_list]
		return input_ids, output_ids

	def tokenize(self, tokenizer, arch, split, delete=None, multiple=False):
		'''Dispatch to the multiple / deletion / default encoders; only T5 variants are supported.'''
		if arch == 't5_base' or arch == 't5_large':
			if multiple:
				return self.tokenize_ptlm_multiple(tokenizer)
			elif delete == 'delrule' or delete == 'delfact':
				return self.tokenize_ptlm_del(tokenizer, delete=delete)
			else:
				return self.tokenize_ptlm(tokenizer, delete=delete)
		else:
			raise NotImplementedError

	@classmethod
	def tokenize_batch(cls, tokenizer, batched_rules, batched_facts, lowercase=False):
		'''
		Encode a padded batch of prompts, one per (rules, facts) pair, for generation.

		lowercase=True lowercases both the prompt and every context sentence.
		returns: the tokenizer's 'input_ids' tensor
		'''
		prompt    = cls._PROMPT.lower() if lowercase else cls._PROMPT
		input_str = []
		for rules, facts in zip(batched_rules, batched_facts):
			context = list(rules + facts)
			if lowercase:
				context = [sent.lower() for sent in context]
			ctx = ''.join([f' sent{i+1}: {x}' for i, x in enumerate(context)])
			input_str.append(prompt + ctx + tokenizer.eos_token)

		tokenized = tokenizer(input_str, add_special_tokens=False, padding=True, truncation=True, return_tensors='pt')
		return tokenized['input_ids']
