from helper import *
from proofwriter_classes import *
import random
random.seed(10)

def get_pw_fname(args, subdir, split, is_staged=True, equivalence=False):
	"""Build the path of a raw ProofWriter jsonl file.

	Args:
		args: namespace providing `world_assump` and `dataset`.
		subdir: sub-directory under the world-assumption folder (e.g. 'depth-3', 'NatLang').
		split: split name inserted in the file name.
		is_staged: pick the 'meta-stage-<split>' file when True, 'meta-<split>' otherwise.
		equivalence: when True, append the '-equiv_<name><version>' suffix derived from
			the last '_'-separated token of `args.dataset`.

	Returns:
		The file path as a string.

	Raises:
		NotImplementedError: if `equivalence` is set and the dataset suffix is not one of
			'name', 'attr', 'rel', 'name$attr' (after removing an optional 'v2'/'v3').
	"""
	def add_suffix(path, suffix):
		# every '.jsonl' occurrence gets the suffix inserted in front of it
		return path.replace('.jsonl', f'{suffix}.jsonl')

	stem = 'meta-stage' if is_staged else 'meta'
	fname = f'../data/raw/proofwriter/{args.world_assump}/{subdir}/{stem}-{split}.jsonl'

	# pararules files carry a '-processed' marker
	if subdir == 'NatLang':
		fname = add_suffix(fname, '-processed')

	if equivalence:
		equiv_name = args.dataset.split('_')[-1]

		# strip an optional version tag; 'v2' is checked before 'v3'
		version = ''
		for tag in ('v2', 'v3'):
			if tag in equiv_name:
				version = tag
				equiv_name = equiv_name.replace(tag, '')
				break

		# dataset suffix -> name used in the file on disk
		on_disk_name = {
			'name': 'name$comm',
			'attr': 'attrs',
			'rel': 'rel',
			'name$attr': 'name$attr',
		}
		if equiv_name not in on_disk_name:
			raise NotImplementedError
		fname = add_suffix(fname, f'-equiv_{on_disk_name[equiv_name]}{version}')

	# birds electricity: e.g. meta-test.jsonl -> meta-test-B2.jsonl for dataset '..._B2'
	if subdir == 'birds-electricity':
		fname = add_suffix(fname, f'-{args.dataset[-2:]}')

	return fname

def get_out_fname(args, split, key, is_staged=True, return_folder=False):
	"""Build the output location under ../data/processed for a dataset split.

	Staged outputs are nested under `<pw_model>/<arch>`; unstaged outputs are not
	(and `args.pw_model` / `args.arch` are not read in that case).

	Args:
		args: namespace providing `dataset`, `world_assump` and, for staged output,
			`pw_model` and `arch`.
		split: split name used as the last folder component.
		key: name of the pickle file (without extension); ignored when `return_folder` is True.
		is_staged: whether to include the model/architecture folders.
		return_folder: return the split folder instead of the '<key>.pkl' file path.

	Returns:
		The folder or file path as a string.
	"""
	base = f'../data/processed/{args.dataset}/{args.world_assump}'
	if is_staged:
		base = f'{base}/{args.pw_model}/{args.arch}'
	folder = f'{base}/{split}'
	if return_folder:
		return folder
	return f'{folder}/{key}.pkl'

def get_pw_subdir(dataset):
	"""Map a dataset name to the raw ProofWriter sub-directory it is read from.

	Pararules datasets (name contains 'pararules') map to 'NatLang'. This is checked
	first because every such name also contains '_leq_' or '_eq_', which previously
	sent them to the depth branch and produced e.g. 'depth-leq'.

	Any other name containing '_leq_' or '_eq_' (e.g. 'pw_leq_3', 'pwu_leq_5_eq_2')
	maps to 'depth-<X>', where <X> is the third '_'-separated token of the name.

	Args:
		dataset: dataset name string.

	Returns:
		The sub-directory name.

	Raises:
		ValueError: if the name matches neither pattern (previously this surfaced as
			an UnboundLocalError on the return statement).
	"""
	if 'pararules' in dataset: # handles, 'pw_pararules', 'pwu_pararules', 'pwq_pararules'
		return 'NatLang'
	if '_leq_' in dataset or '_eq_' in dataset: # handles, 'pw_leq_', 'pwu_leq_', 'pwq_leq_', 'pwur_leq_', 'pwi_leq_', ...
		return f'depth-{dataset.split("_")[2]}'
	raise ValueError(f'cannot infer a ProofWriter sub-directory from dataset name: {dataset!r}')

def get_keys(pw_model):
	"""Return the names of the output fields stored for the given model type.

	'pw_rule' and 'pw_fact' store three fields; every other model stores two.
	"""
	if pw_model in ('pw_rule', 'pw_fact'):
		return ['input_ids', 'token_labels', 'token_mask']
	return ['input_ids', 'output_ids']

def is_valid_row(args, row_id):
	"""Tell whether a row should be used under the configured world assumption.

	Under 'OWA' every row is accepted. Under 'CWA' only rows whose id starts with
	'AttNoneg' or 'RelNoneg' are accepted. Any other assumption rejects all rows.
	"""
	if args.world_assump == 'OWA':
		return True
	if args.world_assump != 'CWA':
		return False
	return row_id.startswith(('AttNoneg', 'RelNoneg'))

def make_data_from_instance(data, output, keys):
	"""Append one tokenized example to the per-key lists in `data`.

	`output` is positional: its i-th element is stored under the i-th name in `keys`.
	"""
	# TODO: change output[0], output[1] to dictionary, ie make it output[key1], output[key2]
	for position, field in enumerate(keys):
		bucket = data[field]
		bucket.append(output[position])

def pickle_dump_file(keys, data, split, args, is_staged=True):
	"""Pickle each `data[key]` list to its own file for the given split.

	The destination of every key comes from get_out_fname(); progress is printed
	before each file is written.
	"""
	print(args)
	for field in keys:
		print(f'Dumping {len(data[field])} lines for dataset: {args.dataset} split: {split} key: {field}')
		destination = get_out_fname(args, split, field, is_staged=is_staged)
		print(f'file writing to = {destination}')
		with open(destination, 'wb') as handle:
			pickle.dump(data[field], handle)

def get_row_chunks(stagefile, non_stagefile, use_equiv_id = False):
	"""Group the rows of a stage file under the non-stage row they belong to.

	Both files are read fully, then scanned in lock-step: for every non-stage row,
	the consecutive stage rows whose (mapped) id equals the non-stage row's id are
	attached to it. Stage rows are never revisited, so the two files are expected
	to list their rows in the same order.

	Args:
		stagefile: path of the jsonl file with the staged (-add0/1/2 ...) rows.
		non_stagefile: path of the jsonl file with one row per original example.
		use_equiv_id: compare on the 'equiv_id' field instead of 'id'.

	Returns:
		A list of the form [[non-stagefile1, stagefile1-add0, stagefile1-add1 ...], [], []]
	"""
	print(f'stagefile = {stagefile}, non_stagefile = {non_stagefile}')

	id_field = 'equiv_id' if use_equiv_id else 'id'

	with jsonlines.open(stagefile) as reader:
		staged_rows = [row for row in tqdm(reader)]
	with jsonlines.open(non_stagefile) as reader:
		plain_rows = [row for row in tqdm(reader)]

	if use_equiv_id:
		def owner_id(stage_id):
			# part before the last '-', then '_', then everything after the first '_'
			return stage_id.rsplit('-', 1)[0] + '_' + stage_id.split('_', 1)[1]
	else:
		def owner_id(stage_id):
			# same id except the trailing -add0/1/2 part
			return stage_id.rsplit('-', 1)[0]

	row_chunks = []
	cursor = 0 # position of the next unconsumed stage row
	for plain_row in plain_rows:
		chunk = [plain_row]
		while cursor < len(staged_rows) and owner_id(staged_rows[cursor][id_field]) == plain_row[id_field]:
			chunk.append(staged_rows[cursor])
			cursor += 1
		row_chunks.append(chunk)

	return row_chunks

# Dataset-name prefixes that are processed as unstaged datasets (written with is_staged=False).
_UNSTAGED_PREFIXES = ('pwu_', 'pwur_', 'pwuc_', 'pwurc_')
# Dataset-name prefixes whose data is written with the keys returned by get_keys().
_STAGED_PREFIXES = ('pwq_', 'pw_', 'pwi_', 'pwqs_', 'pwqf_', 'pwqfs_', 'pwqm_', 'pwqr_', 'pwr_', 'pwnoq_')
# Question-augmented dataset families that share one processing branch in main().
_PWQ_LEQ_PREFIXES = ('pwq_leq', 'pwqs_leq', 'pwqf_leq', 'pwqfs_leq', 'pwqm_leq', 'pwqr_leq', 'pwnoq_leq')


def _load_tokenizer(arch):
	"""Load the pretrained tokenizer for a supported architecture name, else raise NotImplementedError."""
	checkpoints = {
		'roberta_large_race': "LIAMF-USP/roberta-large-finetuned-race",
		't5_base': "t5-base",
		't5_large': "t5-large",
		'roberta_large': "roberta-large",
	}
	if arch not in checkpoints:
		print('Token type ids not implemented in tokenize call, will not work for bert models')
		raise NotImplementedError
	return AutoTokenizer.from_pretrained(checkpoints[arch])


def _load_row_chunks(args, split, combine_depths, equivalence=False, use_equiv_id=False, report_sampled=True):
	"""Return row chunks for the dataset, optionally mixing 20% of depths 0-2 with all of depth 3.

	Chunks (not individual stage rows) are sampled, because each chunk corresponds to
	exactly one row (id) of the non-stage file.
	Format is a list of lists [[non-stagefile1, stagefile1-add0, stagefile1-add1 ...], [], []]
	"""
	if not combine_depths:
		subdir        = get_pw_subdir(args.dataset)
		stagefile     = get_pw_fname(args, subdir, split, is_staged=True, equivalence=equivalence)
		non_stagefile = get_pw_fname(args, subdir, split, is_staged=False, equivalence=equivalence)
		return get_row_chunks(stagefile, non_stagefile, use_equiv_id)

	# COMBINE DEPTH 0, 1, 2 AND SAMPLE 20%. ADD DEPTH 3 TO THIS TO MAKE THE TRAIN/VALID/TEST DATA
	subdirs        = [get_pw_subdir(f'pw_leq_{depth}') for depth in range(4)] # subdirs are common to all dataset families
	stagefiles     = [get_pw_fname(args, subdir, split, is_staged=True, equivalence=equivalence) for subdir in subdirs]
	non_stagefiles = [get_pw_fname(args, subdir, split, is_staged=False, equivalence=equivalence) for subdir in subdirs]
	per_depth      = [get_row_chunks(staged, plain, use_equiv_id) for staged, plain in zip(stagefiles, non_stagefiles)]

	shallow = []
	for chunks in per_depth[:3]:
		shallow.extend(chunks)
	mixed = random.sample(shallow, len(shallow)//5)
	if report_sampled:
		print(f'Length of row_chunks for dataset row_chunks012 is = {len(mixed)}')
	mixed.extend(per_depth[3])
	print(f'Length of row_chunks for dataset {args.dataset} is = {len(mixed)}')
	return mixed


def _iter_valid_chunks(args, row_chunks):
	"""Yield (non_stage_row, stage_rows) for every chunk whose non-stage row passes is_valid_row()."""
	for row_chunk in tqdm(row_chunks):
		# row chunk is of the form [non-stagefile1, stagefile1-add0, stagefile1-add1 ...]
		non_stage_row, stage_rows = row_chunk[0], row_chunk[1:]
		if is_valid_row(args, non_stage_row['id']):
			yield non_stage_row, stage_rows


def _tokenize_into(data, instances, tokenizer, args, split, **tokenize_kwargs):
	"""Tokenize every instance and append its output to `data` under the model's keys."""
	keys = get_keys(args.pw_model)
	for instance in instances:
		output = instance.tokenize(tokenizer, args.arch, split, **tokenize_kwargs)
		make_data_from_instance(data, output, keys)


def _append_unstaged_instance(data, instance, with_conclusions=True):
	"""Append the raw fields of one unstaged instance to `data`."""
	data['rules'].append(instance.rules)
	data['facts'].append(instance.facts)
	data['ques'].append(instance.ques)
	data['answer'].append(instance.answer)
	data['proof'].append(instance.proofs)
	data['qdep'].append(instance.qdep)
	data['equiv_id'].append(instance.equiv_id)
	if with_conclusions:
		data['all_conclusions'].append(instance.all_conclusions)


def _keep_fraction(data, keys, frac):
	"""Keep the first int(frac * N) examples collected so far, in shuffled order."""
	length_data_full = len(data['input_ids'])
	indexs = list(range(int(frac*length_data_full)))
	random.shuffle(indexs) # shuffling the indexs
	print(f'sampled num datapts = {len(indexs)}, original num datapoints = {length_data_full}')
	for key in keys:
		data[key] = [data[key][i] for i in indexs]


def main(args):
	"""Preprocess the dataset named by `args.dataset` and pickle every split.

	For each of the train/dev/test splits the output folder is created, the raw jsonl
	files selected by the dataset name are read, examples are built with the instance
	class matching `args.pw_model`, and the collected lists are written with
	pickle_dump_file().

	Args:
		args: parsed command line namespace; this function reads `arch`, `dataset`,
			`world_assump`, `pw_model` and `upsampling`.

	Raises:
		NotImplementedError: if `args.arch` is not one of the supported architectures.
	"""
	# load tokenizer
	tokenizer = _load_tokenizer(args.arch)

	# check if the dataset is staged dataset or unstaged dataset
	is_staged = not args.dataset.startswith(_UNSTAGED_PREFIXES)

	# load data
	for split in ['train', 'dev', 'test']:
		# pos -> some rule is selected, good_neg -> stop examples, bad_neg -> no rule possible to select
		pos_count, good_neg_count, bad_neg_count = 0, 0, 0

		print(f'Processing {split} split...')

		# make folder if not exists
		pathlib.Path(get_out_fname(args, split, None, is_staged=is_staged, return_folder=True)).mkdir(exist_ok=True, parents=True)

		data = ddict(list)

		if args.dataset.startswith('pw_leq'):
			# only the stage rows of each chunk are used for the pw_leq_X datasets
			row_chunks = _load_row_chunks(args, split, combine_depths=(args.dataset == 'pw_leq_0to3'))

			if args.pw_model == 'pw_rule':
				for _, stage_rows in _iter_valid_chunks(args, row_chunks):
					for row in stage_rows:
						instance = PWRuleSelectorInstance.from_json(row, tokenizer)
						_tokenize_into(data, [instance], tokenizer, args, split)

			elif args.pw_model == 'pw_fact':
				for _, stage_rows in _iter_valid_chunks(args, row_chunks):
					# consecutive stage rows of the same chunk form one (current, next) pair
					for current_row, next_row in zip(stage_rows, stage_rows[1:]):
						# we get a list of datapoints from one json line, for the fact selector
						instances = PWFactSelectorInstance.from_json(current_row, next_row, tokenizer)
						_tokenize_into(data, instances, tokenizer, args, split)

			elif args.pw_model == 'pw_reasoner':
				for _, stage_rows in _iter_valid_chunks(args, row_chunks):
					for row in stage_rows:
						instances = PWReasonerInstance.from_json(row)
						_tokenize_into(data, instances, tokenizer, args, split)

		# data for pararules + D3
		elif args.dataset.startswith('pw_pararules'):
			print(f'processing the {args.dataset} dataset')
			# pararules is always loaded; depth-3 is added for the train split only
			all_subdirs = ['NatLang', 'depth-3'] if split == 'train' else ['NatLang']

			if args.pw_model == 'pw_rule':
				for subdir in all_subdirs:
					print(f'loading {subdir} dataset')
					with jsonlines.open(get_pw_fname(args, subdir, split)) as f:
						for row in tqdm(f):
							if is_valid_row(args, row['id']):
								instance = PWRuleSelectorInstance.from_json(row, tokenizer)
								_tokenize_into(data, [instance], tokenizer, args, split)

			elif args.pw_model == 'pw_fact':
				for subdir in all_subdirs:
					print(f'loading {subdir} dataset')
					with jsonlines.open(get_pw_fname(args, subdir, split)) as f:
						rows = [row for row in tqdm(f) if is_valid_row(args, row['id'])]
						for i in tqdm(range(len(rows)-1)):
							# we are currently not worried about [Attnnoneg1_add0, Attnnoneg1_add1, Attnnoneg1_add2, Relnoneg_add0, ...]
							# since Attnnoneg1_add2 will anyway not have any rule to be selected, so no data point will be generated using
							# it and we hence are not worried about the following json dict (Relnoneg_add0)
							instances = PWFactSelectorInstance.from_json(rows[i], rows[i+1], tokenizer)
							_tokenize_into(data, instances, tokenizer, args, split)

			# Note: For low resource the first dataset that is appended will be pararules, and we sample the fraction of data from it
			# after this D3 is appended as it is
			# Only doing this for pw_reasoner
			elif args.pw_model == 'pw_reasoner':
				for subdir in all_subdirs:
					with jsonlines.open(get_pw_fname(args, subdir, split)) as f:
						for row in tqdm(f):
							if is_valid_row(args, row['id']):
								instances = PWReasonerInstance.from_json(row)
								_tokenize_into(data, instances, tokenizer, args, split)

					# low resource pararules (x% of total), applied only to the train split
					if args.dataset.startswith('pw_pararules0.') and subdir == 'NatLang' and split == 'train':
						frac = float(args.dataset[12:15]) # 0.1/0.3 etc
						_keep_fraction(data, get_keys(args.pw_model), frac)

#########################################

		elif args.dataset.startswith('pwu_leq') or args.dataset.startswith('pwuc_leq'):
			subdir = get_pw_subdir(args.dataset)
			# the depth which we want, when the name asks for one
			qdep_required = int(args.dataset.split("_")[-1]) if '_eq_' in args.dataset else None
			lowercase = not args.dataset.startswith('pwuc_leq') # the cased variant keeps the original casing
			with jsonlines.open(get_pw_fname(args, subdir, split, is_staged=False)) as f:
				for row in tqdm(f):
					for instance in PWInstance.from_json(row, qdep_required, lowercase=lowercase):
						_append_unstaged_instance(data, instance)

		elif args.dataset.startswith(('pwu_pararules', 'pwuc_pararules')):
			if '_eq_' not in args.dataset:
				all_subdirs = ['NatLang', 'depth-3'] if split == 'train' else ['NatLang']
				lowercase = not args.dataset.startswith('pwuc_')
				for subdir in all_subdirs:
					with jsonlines.open(get_pw_fname(args, subdir, split, is_staged=False)) as f:
						for row in tqdm(f):
							for instance in PWInstance.from_json(row, lowercase=lowercase):
								_append_unstaged_instance(data, instance)

			# NOTE(review): for the '_eq_' variants nothing is loaded for the train split, and
			# 'all_conclusions' is never filled although it is among the dumped keys, so those
			# pickles contain empty lists — confirm this is intended.
			elif split == 'test' or split == 'dev':
				qdep_required = int(args.dataset.split("_")[-1]) # the depth which we want
				subdir = 'NatLang'
				with jsonlines.open(get_pw_fname(args, subdir, split, is_staged=False)) as f:
					for row in tqdm(f):
						for instance in PWInstance.from_json(row, qdep_required, dataset=args.dataset):
							_append_unstaged_instance(data, instance, with_conclusions=False)

		# Birds electricity datasets
		elif args.dataset.startswith('pwu_B') or args.dataset.startswith('pwu_E'):
			if split == 'test': # birds electricity only has test datasets
				subdir = 'birds-electricity'
				with jsonlines.open(get_pw_fname(args, subdir, split, is_staged=False)) as f:
					for row in tqdm(f):
						for instance in PWInstance.from_json(row):
							_append_unstaged_instance(data, instance)

#########################################

		elif args.dataset.startswith('pwur_') or args.dataset.startswith('pwurc_'):
			if split == 'test':
				subdir = get_pw_subdir(args.dataset)
				# the depth which we want (NOTE This is -2 only for this dataset)
				qdep_required = int(args.dataset.split("_")[-2]) if '_eq_' in args.dataset else None
				lowercase = not args.dataset.startswith('pwurc_')
				with jsonlines.open(get_pw_fname(args, subdir, split, is_staged=False, equivalence=True)) as f:
					for row in tqdm(f):
						for instance in PWInstance.from_json(row, qdep_required, lowercase=lowercase):
							_append_unstaged_instance(data, instance)

#########################################

		elif args.dataset.startswith('pwr_'):
			# equivalence and use_equiv_id are true everywhere in this part since it is specifically for pwr_
			# TODO: combine this part with the pw_ dataset generation code above, like we did for pwq_ and pwqr_

			assert args.pw_model == 'pw_reasoner' # just for the time being

			row_chunks = _load_row_chunks(
				args, split,
				combine_depths=args.dataset.startswith('pwr_leq_0to3'),
				equivalence=True, use_equiv_id=True,
			)

			# only the stage rows of each chunk are used here
			for _, stage_rows in _iter_valid_chunks(args, row_chunks):
				for row in stage_rows:
					instances = PWReasonerInstance.from_json(row)
					_tokenize_into(data, instances, tokenizer, args, split)

#########################################

		elif args.dataset.startswith(_PWQ_LEQ_PREFIXES):
			equiv        = args.dataset.startswith('pwqr_leq')
			use_equiv_id = equiv
			noq          = args.dataset.startswith('pwnoq_leq')
			nostop       = args.dataset.endswith('nostop')
			combine      = args.dataset == 'pwq_leq_0to3' or args.dataset.startswith('pwqr_leq_0to3') or args.dataset.startswith('pwnoq_leq_0to3')
			multirule    = args.dataset == 'pwqm_leq_3' # depth 3 multirule dataset

			# load the relevant original file(s) and select all the data
			row_chunks = _load_row_chunks(
				args, split, combine_depths=combine,
				equivalence=equiv, use_equiv_id=use_equiv_id, report_sampled=False,
			)

			if args.pw_model == 'pw_rule':
				keys = get_keys(args.pw_model)
				# the filtered families drop every instance that is not a proof / inv-proof one
				proof_only = args.dataset.startswith('pwqf_leq') or args.dataset.startswith('pwqfs_leq')

				for non_stage_row, stage_rows in _iter_valid_chunks(args, row_chunks):
					# both constructors return a list of instances [instance1, instance2, ....]
					if multirule:
						instances = PWQRuleInstance.from_json_multirule(non_stage_row, stage_rows, nostop=nostop)
					else:
						instances = PWQRuleInstance.from_json(non_stage_row, stage_rows, nostop=nostop)

					for instance in instances:
						if proof_only and instance.strategy not in ['proof', 'inv-proof']:
							continue

						output = instance.tokenize(tokenizer, args.arch, split, noq = noq)
						make_data_from_instance(data, output, keys)

						# Do some accounting and upsampling if required
						if sum(instance.labels) > 0.0:
							pos_count += 1
						else:
							assert sum(instance.labels) == 0.0
							if instance.strategy in ['proof', 'inv-proof']:
								good_neg_count += 1

								# upsample -ve instances (in train split) if upsampling ratio defined
								if args.upsampling > 1 and split == 'train':
									assert sum(output[1]) == 0.0
									for _ in range(args.upsampling - 1):
										good_neg_count += 1
										make_data_from_instance(data, output, keys)
							else:
								bad_neg_count += 1

				print(f'Total: {pos_count + good_neg_count + bad_neg_count}, Pos: {pos_count}, Good Neg: {good_neg_count}, Bad Neg: {bad_neg_count}')
				# a split can legitimately contain no good negatives; do not crash on the ratio
				if good_neg_count:
					print(f'Ratio: {pos_count / good_neg_count}')
				else:
					print('Ratio: undefined (no good negative examples)')

			elif args.pw_model == 'pw_fact':
				for non_stage_row, stage_rows in _iter_valid_chunks(args, row_chunks):
					# both constructors return a list of instances [instance1, instance2, ....]
					if multirule:
						instances = PWQFactInstance.from_json_multirule(non_stage_row, stage_rows)
					else:
						instances = PWQFactInstance.from_json(non_stage_row, stage_rows)
					_tokenize_into(data, instances, tokenizer, args, split, noq = noq)

		elif args.dataset.startswith('pwq_pararules'):
			# pararules is always loaded; depth-3 is added for the train split only
			all_subdirs = ['NatLang', 'depth-3'] if split == 'train' else ['NatLang']

			for subdir in all_subdirs:
				# load the relevant original file and select all the data
				if args.pw_model == 'pw_rule' or args.pw_model == 'pw_fact':
					instance_cls  = PWQRuleInstance if args.pw_model == 'pw_rule' else PWQFactInstance
					stagefile     = get_pw_fname(args, subdir, split, is_staged=True)
					non_stagefile = get_pw_fname(args, subdir, split, is_staged=False)
					row_chunks    = get_row_chunks(stagefile, non_stagefile) # gives back a list of lists [[non-stagefile1, stagefile1-add0, stagefile1-add1 ...], [], []]

					for non_stage_row, stage_rows in _iter_valid_chunks(args, row_chunks):
						instances = instance_cls.from_json(non_stage_row, stage_rows) # returns a list of instances [instance1, instance2, ....]
						_tokenize_into(data, instances, tokenizer, args, split)

				# Note the first dataset that is appended will be pararules, and we sample the fraction of data from it
				# after this D3 is appended as it is
				# low resource pararules (x% of total), applied only to the train split
				if args.dataset.startswith('pwq_pararules0.') and subdir == 'NatLang' and split == 'train':
					frac = float(args.dataset[13:16]) # 0.1/0.3 etc
					_keep_fraction(data, get_keys(args.pw_model), frac)

#########################################

		elif args.dataset.startswith('pwi_'):
			subdir        = get_pw_subdir(args.dataset)
			stagefile     = get_pw_fname(args, subdir, split, is_staged=True)
			non_stagefile = get_pw_fname(args, subdir, split, is_staged=False)
			row_chunks    = get_row_chunks(stagefile, non_stagefile)
			tokenizer     = T5Tokenizer.from_pretrained('t5-large')

			keys        = get_keys(args.pw_model)
			single_step = args.dataset.endswith('_del')
			has_del     = 'del' in args.dataset
			delete_arg  = args.dataset.split('_')[-1]

			for row_chunk in tqdm(row_chunks):
				stage_rows = row_chunk[1:]
				if not stage_rows:
					# nothing to build an example from (previously this hit an undefined/stale index)
					continue

				# every stage row is paired with its successor; the last row is paired with None,
				# which generates the stop criteria example
				for current_row, next_row in zip(stage_rows, stage_rows[1:] + [None]):
					if single_step:
						instance = PWIterativeInstance.from_json(current_row, next_row, tokenizer)
					else:
						instance = PWIterativeInstance.from_json_multiple(current_row, next_row, tokenizer)
					output = instance.tokenize(tokenizer, args.arch, split, delete=delete_arg, multiple=not has_del)
					if has_del:
						for inp, out in zip(*output):
							make_data_from_instance(data, [inp, out], keys)
					else:
						make_data_from_instance(data, output, keys)

#########################################
# write the data in pickle format to processed folder

		if args.dataset.startswith(_UNSTAGED_PREFIXES):
			keys = ['rules', 'facts', 'ques', 'answer', 'proof', 'qdep', 'equiv_id', 'all_conclusions']
			is_birds_electricity = args.dataset.startswith('pwu_B') or args.dataset.startswith('pwu_E')
			# birds electricity only has a test split; everything else is dumped for every split
			if not is_birds_electricity or split == 'test':
				pickle_dump_file(keys, data, split, args, is_staged=False)

		elif args.dataset.startswith(_STAGED_PREFIXES):
			pickle_dump_file(get_keys(args.pw_model), data, split, args, is_staged=True)

	print('\n***********************NOTICE*******************\n')
	print('Did you create any new folders and want to give access?')
	print('Please run: chmod -R 777 <folder>')
	print('\n***********************NOTICE-END*******************\n')


# Dataset names accepted by --dataset. Meaning of the name prefixes:
#   pw_    --> staged
#   pwu_   --> unstaged (evaluation)
#   pwq_   --> question augmented
#   pwur_  --> robustness
#   pwi_   --> original proofwriter iterative
#   pwqs_  --> upsampled negatives
#   pwqf_  --> filtered bad negatives
#   pwuc_  --> unstaged (evaluation) with cased
#   pwurc_ --> robustness with cased
#   pwqr_  --> question augmented robustness dataset
#   pwr_   --> question not augmented robustness dataset
#
# NOTE: every entry must be followed by a comma. Adjacent string literals are
# silently concatenated by Python, which previously merged 'pwr_leq_3_rel' and
# 'pwqr_leq_5_attr' into one bogus choice and made both real names unusable.
DATASET_CHOICES = [
	'pw_leq_0', 'pw_leq_1', 'pw_leq_2', 'pw_leq_3', 'pw_leq_5',
	'pw_pararules_leq_3', 'pw_pararules0.1_leq_3', 'pw_pararules0.5_leq_3',
	'pw_leq_0to3',

	'pwu_leq_0', 'pwu_leq_1', 'pwu_leq_2', 'pwu_leq_3', 'pwu_leq_5', 'pwu_leq_0to3',
	'pwu_leq_5_eq_0', 'pwu_leq_5_eq_1', 'pwu_leq_5_eq_2', 'pwu_leq_5_eq_3', 'pwu_leq_5_eq_4', 'pwu_leq_5_eq_5', 'pwu_leq_5_eq_100',
	'pwu_leq_3_eq_0', 'pwu_leq_3_eq_1', 'pwu_leq_3_eq_2', 'pwu_leq_3_eq_3', 'pwu_leq_3_eq_100',
	'pwu_pararules_leq_3', 'pwu_B1', 'pwu_B2', 'pwu_E1', 'pwu_E2', 'pwu_E3', 'pwu_E4',

	'pwu_pararules_eq_4', 'pwu_pararules_eq_3', 'pwu_pararules_eq_2', 'pwu_pararules_eq_1', 'pwu_pararules_eq_0',
	'pwuc_pararules_eq_4', 'pwuc_pararules_eq_3', 'pwuc_pararules_eq_2', 'pwuc_pararules_eq_1', 'pwuc_pararules_eq_0',
	'pwuc_leq_5', 'pwuc_leq_5_eq_0', 'pwuc_leq_5_eq_1', 'pwuc_leq_5_eq_2', 'pwuc_leq_5_eq_3', 'pwuc_leq_5_eq_4', 'pwuc_leq_5_eq_5', 'pwuc_pararules_leq_3',
	'pwuc_leq_3_eq_0', 'pwuc_leq_3_eq_1', 'pwuc_leq_3_eq_2', 'pwuc_leq_3_eq_3', 'pwuc_leq_3_eq_100',

	'pwq_leq_0', 'pwq_leq_1', 'pwq_leq_2', 'pwq_leq_3', 'pwq_leq_5',
	'pwq_pararules_leq_3',
	'pwq_leq_0to3',
	'pwq_pararules0.1_leq_3', 'pwq_pararules0.3_leq_3', 'pwq_pararules0.5_leq_3',

	# Not augmenting question but using the same data/files/code as the q augmented case
	'pwnoq_leq_0to3',

	'pwq_leq_3_nostop',
	# Multiple rules can be selected in the following dataset. The corresponding fact dataset
	# will have more datapoints, corresponding to these multiple rules.
	'pwqm_leq_3',

	'pwqs_leq_3', 'pwqs_leq_5',
	'pwqf_leq_3', 'pwqf_leq_5',
	'pwqfs_leq_3', 'pwqfs_leq_5',

	'pwur_leq_5_name', 'pwur_leq_5_attr', 'pwur_leq_5_rel',
	'pwurc_leq_5_name', 'pwurc_leq_5_attr', 'pwurc_leq_5_rel',
	'pwur_leq_3_name', 'pwur_leq_3_eq_0_name', 'pwur_leq_3_eq_1_name', 'pwur_leq_3_eq_2_name', 'pwur_leq_3_eq_3_name', 'pwur_leq_3_eq_100_name',
	'pwur_leq_3_attr', 'pwur_leq_3_eq_0_attr', 'pwur_leq_3_eq_1_attr', 'pwur_leq_3_eq_2_attr', 'pwur_leq_3_eq_3_attr', 'pwur_leq_3_eq_100_attr',
	'pwur_leq_3_rel', 'pwur_leq_3_eq_0_rel', 'pwur_leq_3_eq_1_rel', 'pwur_leq_3_eq_2_rel', 'pwur_leq_3_eq_3_rel', 'pwur_leq_3_eq_100_rel',
	'pwurc_leq_3_name', 'pwurc_leq_3_eq_0_name', 'pwurc_leq_3_eq_1_name', 'pwurc_leq_3_eq_2_name', 'pwurc_leq_3_eq_3_name', 'pwurc_leq_3_eq_100_name',
	'pwurc_leq_3_attr', 'pwurc_leq_3_eq_0_attr', 'pwurc_leq_3_eq_1_attr', 'pwurc_leq_3_eq_2_attr', 'pwurc_leq_3_eq_3_attr', 'pwurc_leq_3_eq_100_attr',
	'pwurc_leq_3_rel', 'pwurc_leq_3_eq_0_rel', 'pwurc_leq_3_eq_1_rel', 'pwurc_leq_3_eq_2_rel', 'pwurc_leq_3_eq_3_rel', 'pwurc_leq_3_eq_100_rel',

	'pwur_leq_3_namev2', 'pwur_leq_3_attrv2', 'pwur_leq_3_attrv3', 'pwur_leq_3_name$attr',
	'pwurc_leq_3_namev2', 'pwurc_leq_3_attrv2', 'pwurc_leq_3_attrv3', 'pwurc_leq_3_name$attr',

	'pwr_leq_5_name', 'pwr_leq_5_attr',
	'pwr_leq_0to3_name', 'pwr_leq_0to3_attr', 'pwr_leq_0to3_rel',
	'pwr_leq_3_name', 'pwr_leq_3_attr', 'pwr_leq_3_rel',

	'pwqr_leq_5_attr', 'pwqr_leq_5_name',
	'pwqr_leq_3_attr', 'pwqr_leq_3_name',
	'pwqr_leq_3_rel',

	'pwqr_leq_0to3_attr', 'pwqr_leq_0to3_name',

	'pwi_leq_3', 'pwi_leq_5', 'pwi_leq_5_del', 'pwi_leq_5_delfact', 'pwi_leq_5_delrule',
]


def build_arg_parser():
	"""Build the command-line parser for the preprocessing script.

	Options:
		--dataset      one of DATASET_CHOICES (no default)
		--pw_model     one of 'pw_rule', 'pw_fact', 'pw_reasoner', 'pw_iter' (no default)
		--world_assump 'CWA' or 'OWA' (default 'OWA')
		--arch         one of 'roberta_large_race', 't5_base', 'roberta_large', 't5_large'
		               (default 'roberta_large')
		--upsampling   integer (default 1)
		--subsample    boolean flag, False unless given

	Returns:
		The configured argparse.ArgumentParser.
	"""
	parser = argparse.ArgumentParser(description='Preprocess data')
	parser.add_argument('--dataset', choices=DATASET_CHOICES)
	parser.add_argument('--pw_model', choices=['pw_rule', 'pw_fact', 'pw_reasoner', 'pw_iter'])
	parser.add_argument('--world_assump', default='OWA', choices=['CWA', 'OWA'])
	parser.add_argument('--arch', default='roberta_large', choices=['roberta_large_race', 't5_base', 'roberta_large', 't5_large'])
	parser.add_argument('--upsampling', default=1, type=int)
	parser.add_argument('--subsample', action='store_true')
	return parser


if __name__ == '__main__':
	# `parser` and `args` stay bound at module level, exactly as before.
	parser = build_arg_parser()
	args = parser.parse_args()

	main(args)
