from torch.nn.functional import pad
from helper import *


class DataModule(pl.LightningDataModule):
	"""Lightning data module that reads pre-processed pickle files and wraps them in task-specific datasets.

	The dataset name (``dataset``) decides both the folder layout under
	``../data/processed/`` that is read by :meth:`load_dataset` and the ``Dataset``
	subclass instantiated per split by :meth:`setup`.

	NOTE: all data is read with ``pickle.load``; only point this module at trusted files.
	"""

	# Dataset-name prefixes that select a loading / wrapping strategy.
	_RULETAKER_PREFIXES = ('rt_', 'rica', 'pararules')
	_PW_PREFIXES        = ('pw_', 'pwr_')
	_PWQ_PREFIXES       = ('pwq_', 'pwqs_', 'pwqf_', 'pwqfs_', 'pwqm_', 'pwqr_', 'pwnoq_')
	_PWU_PREFIXES       = ('pwu_', 'pwur_', 'pwuc_', 'pwurc_')

	def __init__(self, dataset, train_dataset, dev_dataset, test_dataset, arch, train_batch_size=32, eval_batch_size=32, num_workers=10, pad_idx=0, stopcls = False, stopsep = False, sep_token_id = -1, cls_token_id = -1):
		"""Store the configuration on ``self.p``; no data is read here.

		Args:
			dataset: dataset name; its prefix/suffix select the loader and dataset class.
			train_dataset, dev_dataset, test_dataset: comma-separated processed folder
				names, read by load_dataset() for the rt_/rica/pararules family.
			arch: architecture sub-folder of the processed data ('lstm' also makes
				setup() load ``word2idx.pkl``).
			train_batch_size, eval_batch_size: batch sizes of the train / dev+test loaders.
			num_workers: number of DataLoader worker processes.
			pad_idx: padding id forwarded to the datasets.
			stopcls, stopsep, sep_token_id, cls_token_id: forwarded to
				ProofWriterRuleSelectorDataset.
		"""
		super().__init__()
		self.p                  = types.SimpleNamespace()
		self.p.dataset          = dataset
		self.p.train_dataset    = train_dataset		# used in load_dataset()
		self.p.dev_dataset      = dev_dataset		# used in load_dataset()
		self.p.test_dataset     = test_dataset		# used in load_dataset()
		self.p.arch             = arch
		self.p.train_batch_size = train_batch_size
		self.p.eval_batch_size  = eval_batch_size
		self.p.num_workers      = num_workers
		self.p.pad_idx          = pad_idx
		self.p.stopcls          = stopcls
		self.p.stopsep          = stopsep
		self.p.sep_token_id     = sep_token_id
		self.p.cls_token_id     = cls_token_id

	# ------------------------------------------------------------------ loading helpers

	@staticmethod
	def _read_pickle(path, echo=False):
		"""Unpickle and return the object stored at ``path`` (optionally printing the path once opened)."""
		with open(path, 'rb') as f:
			if echo:
				print(path)
			return pickle.load(f)

	def _load_pickle_keys(self, folder, keys, echo=False):
		"""Return a ``ddict(list)`` mapping each key to the content of ``folder + '<key>.pkl'``."""
		dataset = ddict(list)
		for key in keys:
			dataset[key] = self._read_pickle(folder + f'{key}.pkl', echo)
		return dataset

	def _load_ruletaker(self, split):
		"""Concatenate the per-key lists of every folder listed in ``self.p.<split>_dataset``."""
		all_folders = [f'../data/processed/{x}/{self.p.arch}/{split}/' for x in getattr(self.p, f'{split}_dataset').split(',')]
		print('allfolders', all_folders)
		dataset = ddict(list)
		for key in ['input_ids', 'label', 'ctx_ids', 'stmt_ids']:
			for folder in all_folders:
				tmp          = self._read_pickle(folder + f'{key}.pkl')
				dataset[key] = dataset[key] + tmp
		return dataset

	def _load_component(self, maindir, world_assump, subdir, split, keys=None, echo=False):
		"""Load one selector/reasoner component from ``<maindir>/<world_assump>/<subdir>/<arch>/<split>/``.

		When ``keys`` is None they are chosen from the dataset name: reasoner data
		holds input/output ids, selector data holds token labels and masks.
		"""
		folder = f'../data/processed/{maindir}/{world_assump}/{subdir}/{self.p.arch}/{split}/'
		print(f'data being used from the folder = {folder}')
		if keys is None:
			if 'reasoner' in self.p.dataset:
				keys = ['input_ids', 'output_ids']
			else:
				keys = ['input_ids', 'token_labels', 'token_mask']
		return self._load_pickle_keys(folder, keys, echo)

	def _load_inference(self, maindir, world_assump, split):
		"""Load the raw inference fields from ``<maindir>/<world_assump>/<split>/``.

		``all_conclusions.pkl`` is optional: when the file does not exist the key is
		simply left out. Every other file is required and its error propagates.
		"""
		folder = f'../data/processed/{maindir}/{world_assump}/{split}/'
		print(f'data being used from the folder = {folder}')
		dataset = self._load_pickle_keys(folder, ['facts', 'rules', 'ques', 'answer', 'proof', 'equiv_id', 'qdep'])
		try:
			dataset['all_conclusions'] = self._read_pickle(folder + 'all_conclusions.pkl')
		except FileNotFoundError:
			print(f'Missing key all_conclusions, skipping!')
		return dataset

	# ------------------------------------------------------------------ dataset-name parsing

	@staticmethod
	def _parse_pw_name(name):
		"""Split a pw_/pwr_ dataset name into ``(maindir, world_assump, subdir)``."""
		parts = name.split('_')
		leq3_prefixes = ('pw_pararules_leq_3', 'pw_pararules0.1_leq_3', 'pw_pararules0.3_leq_3', 'pw_pararules0.5_leq_3')
		# subdir can be pw_rule, pw_fact ...
		if name.startswith('pwr_leq'):
			return "_".join(parts[:-2]), parts[-2], 'pw_' + parts[-1]
		if name.startswith('pw_leq'):
			# everything after the 4th underscore is kept as the component name
			return "_".join(parts[:3]), parts[3], 'pw_' + name.split('_', 4)[-1]
		if name.startswith(leq3_prefixes):
			# must be tested before the generic pararules rule, which it also matches
			return "_".join(parts[:4]), parts[4], 'pw_' + parts[-1]
		if name.startswith('pw_pararules'):
			return "_".join(parts[:2]), parts[2], 'pw_' + parts[-1]
		raise ValueError(f'Cannot derive the data folder from dataset name: {name}')

	@staticmethod
	def _parse_pwq_name(name):
		"""Split a pwq_-family dataset name into ``(maindir, world_assump, subdir)``."""
		parts = name.split('_')
		leq_prefixes  = ('pwq_leq', 'pwqs_leq', 'pwqf_leq', 'pwqfs_leq', 'pwqm_leq', 'pwqr_leq', 'pwnoq_leq')
		leq3_prefixes = ('pwq_pararules_leq_3', 'pwq_pararules0.1_leq_3', 'pwq_pararules0.3_leq_3', 'pwq_pararules0.5_leq_3')
		# subdir can be pw_rule, pw_fact ...
		if name.startswith(leq_prefixes):
			return "_".join(parts[:-2]), parts[-2], 'pw_' + parts[-1]
		if name.startswith(leq3_prefixes):
			# must be tested before the generic pararules rule, which it also matches
			return "_".join(parts[:4]), parts[4], 'pw_' + parts[-1]
		if name.startswith('pwq_pararules'):
			return "_".join(parts[:2]), parts[2], 'pw_' + parts[-1]
		raise ValueError(f'Cannot derive the data folder from dataset name: {name}')

	@staticmethod
	def _parse_pwu_name(name):
		"""Split a pwu_-family dataset name into ``(maindir, world_assump)``."""
		known_prefixes = ('pwu_leq_', 'pwur_leq_', 'pwu_B', 'pwu_E', 'pwuc_leq_', 'pwurc_leq_', 'pwu_pararules', 'pwuc_pararules')
		if not name.startswith(known_prefixes):
			raise ValueError(f'Cannot derive the data folder from dataset name: {name}')

		parts    = name.split('_')
		has_eq   = '_eq_' in name
		is_robust = name.startswith(('pwur_', 'pwurc_'))

		if name.startswith(('pwu_pararules', 'pwuc_pararules')) and has_eq:
			return "_".join(parts[:-1]), parts[-1]

		# NOTE(review): a dedicated rule for 'pwu_pararules_leq_3' / 'pwuc_pararules_leq_3'
		# (maindir = first 4 fields, world_assump = 5th field) used to follow this block but
		# could never run, because 'pwu_pararules' / 'pwuc_pararules' are already accepted
		# above. Such names therefore use the generic 3-field rule below, exactly as before.
		# Confirm which layout is intended before changing the routing.
		if name.startswith(('pwu_B', 'pwu_E')):
			n_fields = 2
		elif has_eq and not is_robust:
			n_fields = 5
		elif is_robust:
			n_fields = 6 if has_eq else 4
		else:
			n_fields = 3
		return "_".join(parts[:n_fields]), parts[n_fields]

	# ------------------------------------------------------------------ public API

	def load_dataset(self, split):
		"""Read the pickled fields of ``split`` for the configured dataset.

		Args:
			split: split folder name, e.g. 'train', 'dev' or 'test'.

		Returns:
			A ``ddict(list)`` mapping field names to the unpickled objects.

		Raises:
			ValueError: if the dataset name matches no known naming scheme.
			OSError: if a required pickle file cannot be opened.
		"""
		name = self.p.dataset

		if name.startswith(self._RULETAKER_PREFIXES):
			return self._load_ruletaker(split)

		if name.startswith('qasc_'):
			folder = f'../data/processed/{name}/{self.p.arch}/{split}/'
			return self._load_pickle_keys(folder, ['input_ids', 'type_ids', 'label'])

		if name.startswith(self._PW_PREFIXES):
			maindir, world_assump, subdir = self._parse_pw_name(name)
			return self._load_component(maindir, world_assump, subdir, split)

		if name.startswith('pwi_'):
			parts = name.split('_')
			return self._load_component("_".join(parts[:-2]), parts[-2], 'pw_' + parts[-1], split, keys=['input_ids', 'output_ids'])

		if name.startswith(self._PWQ_PREFIXES):
			maindir, world_assump, subdir = self._parse_pwq_name(name)
			return self._load_component(maindir, world_assump, subdir, split, echo=True)

		if name.startswith(self._PWU_PREFIXES):
			maindir, world_assump = self._parse_pwu_name(name)
			return self._load_inference(maindir, world_assump, split)

		raise ValueError(f'Unknown dataset name: {name}')

	def _rule_selector(self, split):
		"""Build the rule-selector dataset of ``split`` with the configured stop-token options."""
		return ProofWriterRuleSelectorDataset(self.load_dataset(split), self.p.pad_idx, self.p.stopcls, self.p.stopsep, self.p.sep_token_id, self.p.cls_token_id)

	def _build_dataset(self, split):
		"""Pick the ``Dataset`` class from the dataset name and instantiate it for ``split``."""
		name        = self.p.dataset
		is_rule     = name.endswith('rule')
		is_fact     = name.endswith('fact')
		is_reasoner = name.endswith('reasoner')

		if name.startswith(self._RULETAKER_PREFIXES):
			return RuleTakerDataset(self.load_dataset(split), self.p.pad_idx)
		if name.startswith('qasc_'):
			return QASCDataset(self.load_dataset(split), self.p.pad_idx)
		if name.startswith('pw_') and is_rule:
			return self._rule_selector(split)
		if name.startswith('pw_') and is_fact:
			return ProofWriterFactSelectorDataset(self.load_dataset(split), self.p.pad_idx)
		if name.startswith(('pw_', 'pwr_')) and is_reasoner:
			return ProofWriterReasonerDataset(self.load_dataset(split), self.p.pad_idx)
		if name.startswith('pwi_'):
			return ProofWriterReasonerDataset(self.load_dataset(split), self.p.pad_idx)
		if name.startswith(self._PWU_PREFIXES):
			return ProofWriterInferenceDataset(self.load_dataset(split))
		# pwqr_ / pwnoq_ always use the rule selector, whatever the name ends with
		if (name.startswith(('pwq_', 'pwqs_', 'pwqf_', 'pwqfs_', 'pwqm_')) and is_rule) or name.startswith(('pwqr_', 'pwnoq_')):
			return self._rule_selector(split)
		if name.startswith('pwq_') and is_fact:
			return ProofWriterFactSelectorDataset(self.load_dataset(split), self.p.pad_idx)
		if name.startswith('pwq_') and is_reasoner:
			return ProofWriterReasonerDataset(self.load_dataset(split), self.p.pad_idx)

		raise ValueError(f'No dataset class is registered for dataset name: {name}')

	def setup(self, splits='all'):
		"""Load the requested splits into ``self.data`` (and ``self.word2idx`` for the lstm arch).

		Args:
			splits: 'all' (train, dev and test), a single split name, or an iterable
				of split names.

		Raises:
			ValueError: if the dataset name is not recognised.
		"""
		self.data = ddict(list)
		if splits == 'all':
			splits = ['train', 'dev', 'test']
		elif isinstance(splits, str):
			# a bare split name would otherwise be iterated character by character
			splits = [splits]

		for split in splits:
			self.data[split] = self._build_dataset(split)

		if self.p.arch == 'lstm':
			# load the word index mapping file
			with open(f'../data/processed/{self.p.dataset}/{self.p.arch}/word2idx.pkl', 'rb') as f:
				self.word2idx = dict(pickle.load(f))

	def train_dataloader(self, shuffle=True):
		"""Return the DataLoader over the 'train' split (shuffled unless ``shuffle`` is False)."""
		return DataLoader(
					self.data['train'],
					batch_size=self.p.train_batch_size,
					num_workers=self.p.num_workers,
					collate_fn=self.data['train'].collater,
					shuffle=shuffle,
				)

	def _eval_dataloader(self, split):
		"""Return a DataLoader over ``split`` using the evaluation batch size."""
		return DataLoader(
					self.data[split],
					batch_size=self.p.eval_batch_size,
					num_workers=self.p.num_workers,
					collate_fn=self.data[split].collater,
				)

	def val_dataloader(self):
		"""Return the DataLoader over the 'dev' split."""
		return self._eval_dataloader('dev')

	def test_dataloader(self):
		"""Return the DataLoader over the 'test' split."""
		return self._eval_dataloader('test')

	@staticmethod
	def add_data_specific_args(parent_parser):
		"""Return a parser extending ``parent_parser`` with the data-related command line options."""
		parser = ArgumentParser(parents=[parent_parser], add_help=False)
		parser.add_argument("--dataset", 		 				type=str)
		parser.add_argument("--train_dataset",	default='', 	type=str)
		parser.add_argument("--dev_dataset",	default='', 	type=str)
		parser.add_argument("--test_dataset",	default='', 	type=str)
		parser.add_argument("--num_workers", 	default=10, 	type=int)
		return parser


class RuleTakerDataset(Dataset):
	"""Map-style dataset over examples made of a full input, a context, a statement and a label."""

	def __init__(self, dataset, pad_idx):
		# dataset is indexed as dataset[field][example] for the fields
		# 'input_ids', 'ctx_ids', 'stmt_ids' and 'label'
		self.pad_idx = pad_idx
		self.data    = dataset

	def __len__(self):
		# one label per example
		return len(self.data['label'])

	def __getitem__(self, idx):
		"""Return example ``idx`` as a dict of tensors, including the context/statement lengths."""
		input_ids = self.data['input_ids'][idx]
		ctx_ids   = self.data['ctx_ids'][idx]
		stmt_ids  = self.data['stmt_ids'][idx]
		label     = self.data['label'][idx]

		return {
			'sent'    : torch.LongTensor(input_ids),
			'ctx'     : torch.LongTensor(ctx_ids),
			'stmt'    : torch.LongTensor(stmt_ids),
			'lbl'     : torch.FloatTensor([label]),
			'ctx_len' : torch.LongTensor([len(ctx_ids)]),
			'stmt_len': torch.LongTensor([len(stmt_ids)]),
		}

	def collater(self, items):
		"""Merge examples into a batch: id sequences are padded with ``pad_idx``, scalars are concatenated."""

		def padded(field):
			return pad_sequence([example[field] for example in items], batch_first=True, padding_value=self.pad_idx)

		def joined(field):
			return torch.cat([example[field] for example in items])

		batch = {'all_sents': padded('sent')}
		batch['all_ctxs']  = padded('ctx')
		batch['all_stmts'] = padded('stmt')
		batch['all_lbls']  = joined('lbl')
		batch['ctx_lens']  = joined('ctx_len')
		batch['stmt_lens'] = joined('stmt_len')
		# attention is 1 on every non-padding position of the full input
		batch['attn_mask'] = (batch['all_sents'] != self.pad_idx).long()
		return batch


class QASCDataset(Dataset):
	"""Map-style dataset over multiple-choice examples: one id sequence and one type-id sequence per choice."""

	def __init__(self, dataset, pad_idx):
		# dataset is indexed as dataset[field][example] for 'input_ids', 'type_ids' and 'label';
		# 'input_ids' / 'type_ids' hold one sequence per answer choice
		self.data    = dataset
		self.pad_idx = pad_idx

	def __len__(self):
		return len(self.data['label'])

	def __getitem__(self, idx):
		"""Return example ``idx``: per-choice LongTensors plus a 1-element float label tensor."""
		item = {
			'input_ids': [torch.LongTensor(x) for x in self.data['input_ids'][idx]],
			'type_ids' : [torch.LongTensor(x) for x in self.data['type_ids'][idx]],
			'lbl'      : torch.FloatTensor([self.data['label'][idx]]),
		}

		return item

	def collater(self, items):
		"""Pad every choice of every example to a common length and stack them per example.

		The number of choices is taken from the batch instead of being fixed to 8,
		so padded tensors have shape ``(len(items), num_choices, max_len)``.

		Raises:
			ValueError: if the examples do not all have the same number of choices
				(the flat padded tensor could not be regrouped per example).
		"""
		num_choices = len(items[0]['input_ids'])
		for x in items:
			if len(x['input_ids']) != num_choices or len(x['type_ids']) != num_choices:
				raise ValueError('All examples in a batch must have the same number of choices')

		# flatten to (batch * num_choices) sequences so that a single padding length is used
		all_inps         = [x for y in items for x in y['input_ids']]
		all_types        = [x for y in items for x in y['type_ids']]
		all_inps_padded  = pad_sequence(all_inps, batch_first=True, padding_value=self.pad_idx).reshape(len(items), num_choices, -1)
		all_types_padded = pad_sequence(all_types, batch_first=True, padding_value=self.pad_idx).reshape(len(items), num_choices, -1)

		batch = {
			'all_inps_padded'  : all_inps_padded,
			'all_types_padded' : all_types_padded,
			'all_lbls'         : torch.cat([x['lbl'] for x in items]),
			'attn_mask'        : (all_inps_padded != self.pad_idx).long()
		}

		return batch


class ProofWriterRuleSelectorDataset(Dataset):
	"""Map-style dataset of token-level selection targets, with an optional "stop" position.

	With ``stopcls`` (resp. ``stopsep``) the first (resp. last) token is always made
	selectable and becomes the target whenever no other token is labelled.
	The two options are mutually exclusive.
	"""

	def __init__(self, dataset, pad_idx, stopcls, stopsep, sep_token_id, cls_token_id):
		# dataset is indexed as dataset[field][example] for 'input_ids', 'token_labels', 'token_mask'
		self.data    = dataset
		self.pad_idx = pad_idx
		self.stopcls = stopcls
		self.stopsep = stopsep
		self.sep_token_id = sep_token_id
		self.cls_token_id = cls_token_id

	def __len__(self):
		return len(self.data['token_labels']) # total number of datapoints

	@staticmethod
	def _mark_stop_position(item, pos):
		"""Unmask position ``pos`` and label it when the example has no positive label."""
		item['token_mask'][0][pos] = 1.
		if item['token_labels'].sum() == 0:
			item['token_labels'][0][pos] = 1.

	def __getitem__(self, idx):
		"""Return example ``idx``: 'sent' is a 1-D id tensor, labels and mask have shape (1, seq_len).

		Raises:
			AssertionError: if both stop options are enabled, or if the expected
				cls/sep token is not found at the first/last position.
		"""
		item = {
			'sent'        : torch.LongTensor(self.data['input_ids'][idx]),
			'token_labels': torch.FloatTensor([self.data['token_labels'][idx]]),
			'token_mask'  : torch.FloatTensor([self.data['token_mask'][idx]]),
		}
		# several labels per example are allowed on purpose (pwqm dataset), so no "<= 1" check here
		if self.stopcls:
			assert not self.stopsep
			assert item['sent'][0].item() == self.cls_token_id # is the first token actually the cls token or not
			self._mark_stop_position(item, 0)

		if self.stopsep:
			assert not self.stopcls
			assert item['sent'][-1].item() == self.sep_token_id # is the last token actually the sep token or not
			self._mark_stop_position(item, -1)

		return item

	def collater(self, items):
		"""Pad the examples into batch tensors of shape (batch, max_len)."""
		# note for sent, the pad value can be 1/0 based on roberta/bert
		# for token labels and token masks, the pad value must be 0
		all_sents = pad_sequence([x['sent'] for x in items], batch_first=True, padding_value=self.pad_idx)
		batch = {
			'all_sents'       : all_sents,
			# squeeze(0) drops only the leading singleton dim; a bare squeeze() would also
			# collapse a length-1 sequence to a 0-d tensor that cannot be padded
			'all_token_labels': pad_sequence([x['token_labels'].squeeze(0) for x in items], batch_first=True, padding_value=0),
			'all_token_mask'  : pad_sequence([x['token_mask'].squeeze(0) for x in items], batch_first=True, padding_value=0),
			'attn_mask'       : (all_sents != self.pad_idx).long(),
		}

		return batch

class ProofWriterFactSelectorDataset(Dataset):
	"""Map-style dataset of token-level selection targets (ids, per-token labels, per-token mask)."""

	def __init__(self, dataset, pad_idx):
		# dataset is indexed as dataset[field][example] for 'input_ids', 'token_labels', 'token_mask'
		self.data    = dataset
		self.pad_idx = pad_idx

	def __len__(self):
		return len(self.data['token_labels'])

	def __getitem__(self, idx):
		"""Return example ``idx``: 'sent' is a 1-D id tensor, labels and mask have shape (1, seq_len)."""
		item = {
			'sent'        : torch.LongTensor(self.data['input_ids'][idx]),
			'token_labels': torch.FloatTensor([self.data['token_labels'][idx]]),
			'token_mask'  : torch.FloatTensor([self.data['token_mask'][idx]]),
		}
		return item

	def collater(self, items):
		"""Pad the examples into batch tensors of shape (batch, max_len).

		Ids are padded with ``pad_idx``; labels and masks are padded with 0.
		"""
		all_sents = pad_sequence([x['sent'] for x in items], batch_first=True, padding_value=self.pad_idx)
		batch = {
			'all_sents'       : all_sents,
			# squeeze(0) drops only the leading singleton dim; a bare squeeze() would also
			# collapse a length-1 sequence to a 0-d tensor that cannot be padded
			'all_token_labels': pad_sequence([x['token_labels'].squeeze(0) for x in items], batch_first=True, padding_value=0),
			'all_token_mask'  : pad_sequence([x['token_mask'].squeeze(0) for x in items], batch_first=True, padding_value=0),
			'attn_mask'       : (all_sents != self.pad_idx).long(),
		}

		return batch

class ProofWriterReasonerDataset(Dataset):
	"""Map-style dataset of (input ids, output ids) pairs for sequence-to-sequence training."""

	def __init__(self, dataset, pad_idx):
		# dataset is indexed as dataset[field][example] for 'input_ids' and 'output_ids'
		self.pad_idx = pad_idx
		self.data    = dataset

	def __len__(self):
		return len(self.data['input_ids'])

	def __getitem__(self, idx):
		"""Return example ``idx`` as two 1-D LongTensors."""
		source = self.data['input_ids'][idx]
		target = self.data['output_ids'][idx]
		return {'input': torch.LongTensor(source), 'output': torch.LongTensor(target)}

	def collater(self, items):
		"""Pad inputs/outputs and derive the shifted decoder inputs and training labels.

		``y_ids`` is the padded output without its last position; ``labels`` is the
		padded output without its first position, with padding replaced by -100.
		"""
		sources = [example['input'] for example in items]
		targets = [example['output'] for example in items]

		all_inps = pad_sequence(sources, batch_first=True, padding_value=self.pad_idx)
		all_outs = pad_sequence(targets, batch_first=True, padding_value=self.pad_idx)

		shifted = all_outs[:, 1:]
		# masked_fill returns a new tensor, so all_outs itself keeps its padding values
		labels  = shifted.masked_fill(shifted == self.pad_idx, -100)

		batch = {}
		batch['all_inps']  = all_inps
		batch['all_outs']  = all_outs
		batch['attn_mask'] = (all_inps != self.pad_idx).long()
		batch['y_ids']     = all_outs[:, :-1].contiguous()
		batch['labels']    = labels
		return batch

class ProofWriterInferenceDataset(Dataset):
	"""Map-style dataset that serves the raw (un-tensorised) fields of each inference example."""

	def __init__(self, dataset):
		# dataset is indexed as dataset[field][example]; 'all_conclusions' may be absent or empty
		self.data = dataset

	def __len__(self):
		# one answer per example
		return len(self.data['answer'])

	def __getitem__(self, idx):
		"""Return example ``idx``; 'all_conclusions' is None when that field is missing or empty."""
		required = ('facts', 'rules', 'ques', 'answer', 'proof', 'equiv_id', 'qdep')
		item = {field: self.data[field][idx] for field in required}

		has_conclusions = 'all_conclusions' in self.data and len(self.data['all_conclusions']) > 0
		item['all_conclusions'] = self.data['all_conclusions'][idx] if has_conclusions else None
		return item

	def collater(self, items):
		"""Group the examples field by field into plain Python lists (no padding, no tensors)."""
		key_pairs = (
			('all_facts',       'facts'),
			('all_rules',       'rules'),
			('all_ques',        'ques'),
			('all_answer',      'answer'),
			('all_proof',       'proof'),
			('all_equiv_id',    'equiv_id'),
			('all_qdep',        'qdep'),
			('all_conclusions', 'all_conclusions'),
		)
		return {batch_key: [example[item_key] for example in items] for batch_key, item_key in key_pairs}
