from helper import *
from basemodel import BaseModel


class ProofWriterRuleSelector(BaseModel):
	def __init__(self, arch='roberta_large', train_batch_size=16, eval_batch_size=16, accumulate_grad_batches=1, learning_rate=1e-5, max_epochs=5,\
					optimizer='adamw', adam_epsilon=1e-8, weight_decay=0.0, lr_scheduler='linear_with_warmup', warmup_updates=0.0, freeze_epochs=-1, gpus=1,\
					hf_name='roberta-large', celoss=False, select_thresh=0., stopcls=False, stopsep=False, multitask=False, bert_init=False, cls_dropout=0.1, \
					use_sigmoid=False, topk=0, cls_thresh=0.5, num_logit_layers=0):
		'''Build the rule selector: a pretrained text encoder followed by a per-token linear scoring head.

		Every argument is copied onto the namespace ``self.p`` (``select_thresh`` is stored as
		``self.p.rule_thresh``; ``hf_name`` is only used to load the encoder/tokenizer).

		arch                    : must be 'roberta_large' (asserted below).
		train_batch_size,
		max_epochs, gpus        : also forwarded to ``BaseModel.__init__``.
		eval_batch_size,
		accumulate_grad_batches,
		freeze_epochs           : only stored here; presumably consumed by BaseModel / the trainer -- TODO confirm.
		learning_rate,
		adam_epsilon,
		weight_decay, optimizer : read by ``configure_optimizers`` (only 'adamw' is implemented there).
		lr_scheduler            : 'linear_with_warmup' or 'fixed' (see ``configure_optimizers``).
		warmup_updates          : values > 1.0 are an absolute step count, otherwise a fraction of the total steps.
		hf_name                 : name handed to ``AutoModel.from_pretrained`` / ``AutoTokenizer.from_pretrained``.
		celoss                  : select exactly one position with a cross-entropy objective instead of per-token BCE.
		select_thresh           : threshold applied to rule scores when thresholding is used.
		stopcls                 : the first valid position is treated as a stop (CLS) prediction in ``predict`` / ``run_step``.
		stopsep                 : ``predict`` zeroes the prediction at the last valid position of every row.
		multitask               : adds a separate linear ``stop_classifier`` head that scores position 0.
		bert_init               : initialise the heads from N(0, encoder ``initializer_range``) instead of Xavier-normal.
		cls_dropout             : dropout probability applied to the encoder output before the heads.
		use_sigmoid             : threshold sigmoid(logit) instead of the raw logit.
		topk                    : if non-zero (and ``celoss`` is off), ``predict`` keeps the top-k scoring positions.
		cls_thresh              : threshold for the stop position, used when ``use_sigmoid`` and ``stopcls`` are both set.
		num_logit_layers        : number of extra hidden-size -> hidden-size linear layers created for the scoring head.
		'''
		super().__init__(train_batch_size=train_batch_size, max_epochs=max_epochs, gpus=gpus)
		# NOTE: save_hyperparameters() is called with no arguments; keep it directly inside __init__.
		self.save_hyperparameters()
		assert arch == 'roberta_large'

		# Generic training / optimisation settings.
		self.p                         = types.SimpleNamespace()
		self.p.arch                    = arch
		self.p.train_batch_size        = train_batch_size
		self.p.eval_batch_size         = eval_batch_size
		self.p.accumulate_grad_batches = accumulate_grad_batches
		self.p.learning_rate           = learning_rate
		self.p.max_epochs              = max_epochs
		self.p.optimizer               = optimizer
		self.p.adam_epsilon            = adam_epsilon
		self.p.weight_decay            = weight_decay
		self.p.lr_scheduler            = lr_scheduler
		self.p.warmup_updates          = warmup_updates
		self.p.freeze_epochs           = freeze_epochs
		self.p.gpus                    = gpus

		# Rule-selection specific settings.
		self.p.celoss                  = celoss # use cross-entropy loss in order to select just 1 rule (or stop, when the data contains stopping points)
		self.p.rule_thresh             = select_thresh # threshold (other than 0) for selecting rules when thresholding logits / sigmoid scores
		self.p.stopcls                 = stopcls
		self.p.stopsep                 = stopsep
		self.p.multitask               = multitask
		self.p.bert_init               = bert_init
		self.p.cls_dropout             = cls_dropout
		self.p.use_sigmoid			   = use_sigmoid
		self.p.topk					   = topk
		self.p.cls_thresh			   = cls_thresh
		self.p.num_logit_layers		   = num_logit_layers

		# Pretrained encoder + tokenizer; the heads operate on the encoder's hidden size.
		self.text_encoder    = AutoModel.from_pretrained(hf_name)
		self.tokenizer       = AutoTokenizer.from_pretrained(hf_name)
		out_dim              = self.text_encoder.config.hidden_size
		self.out_dim         = out_dim

		# Optional stack of hidden layers applied (with ReLU, see forward) before the final scoring layer.
		if self.p.num_logit_layers>0:
			self.classifier_layers = nn.ModuleList([nn.Linear(out_dim, out_dim) for i in range(self.p.num_logit_layers)])

		# One logit per token position.
		self.classifier      = nn.Linear(out_dim, 1)
		self.dropout         = torch.nn.Dropout(self.p.cls_dropout)

		# Separate stop head, only present in multitask mode.
		if self.p.multitask:
			self.stop_classifier = nn.Linear(out_dim, 1)

		self.initialize()

	def initialize(self):
		if self.p.bert_init:
			self.classifier.weight.data.normal_(mean=0.0, std=self.text_encoder.config.initializer_range)
		else:
			xavier_normal_(self.classifier.weight)
		self.classifier.bias.data.zero_()

		if self.p.multitask:
			if self.p.bert_init:
				self.stop_classifier.weight.data.normal_(mean=0.0, std=self.text_encoder.config.initializer_range)
			else:
				xavier_normal_(self.stop_classifier.weight)
			self.stop_classifier.bias.data.zero_()

	def forward(self, input_ids, attn_mask):
		last_hidden_state = self.text_encoder(input_ids=input_ids, attention_mask=attn_mask)['last_hidden_state'] #shape (batchsize, seqlen, hiddensize)
		last_hidden_state = self.dropout(last_hidden_state)
		if self.p.multitask:
			stop_state, token_state = last_hidden_state[:, 0, :], last_hidden_state[:, 1:, :]
			stop_logits             = self.stop_classifier(stop_state).squeeze()
			token_logits            = self.classifier(token_state).squeeze()
			# TODO: We can also merge these here and take combined loss
			return token_logits, stop_logits
		else:
			# for deep FFN as a classifier
			if self.p.num_logit_layers>0:
				for i, layer in enumerate(self.classifier_layers):
					last_hidden_state = F.relu(layer(last_hidden_state))

			logits = self.classifier(last_hidden_state).squeeze()
			return logits

	def predict(self, input_ids, token_mask, attn_mask, stop_priority='rule'):
		device  = input_ids.device
		outputs = self(input_ids, attn_mask)

		if self.p.multitask:
			token_logits, stop_logits = outputs
			logits = torch.cat([stop_logits.unsqueeze(1), token_logits], dim=1)
		else:
			logits = outputs

		# First filter out the logits corresponding to the valid tokens
		mask_len          = token_mask.sum(1) # (batchsize) eg [8,3,2,1]
		mask_nonzero      = torch.nonzero(token_mask) # (z, 2) size tensor, having x, y coordinates of non zero elements. z = no. of non zero elements
		y_indices         = torch.cat([torch.arange(x) for x in mask_len]).to(device)
		x_indices         = mask_nonzero[:, 0]
		filtered_logits   = torch.full((input_ids.shape[0], mask_len.max()), -1000.0).to(device)
		filtered_logits[x_indices, y_indices] = torch.masked_select(logits, token_mask.bool())

		# Then compute the predictions for each of the logit
		if self.p.celoss:
			argmax_filtered_logits	= torch.argmax(filtered_logits, dim=1)
			preds 					= (F.one_hot(argmax_filtered_logits, num_classes=filtered_logits.shape[1])).int()

		elif self.p.topk != 0:
			assert not self.p.celoss
			topk_filtered_logits, _  	  = torch.topk(filtered_logits, min(self.p.topk, filtered_logits.shape[1]), dim=1, sorted=False)
			min_of_topk, _			  	  = torch.min(topk_filtered_logits, dim = 1)
			min_of_topk					  = min_of_topk.reshape(topk_filtered_logits.shape[0], 1)
			preds 					  	  = torch.ge(filtered_logits, min_of_topk).int()

		else:
			if self.p.use_sigmoid:
				## having different thresholds for cls and rules in case --stopcls is True
				## note: different thresholds this can be used if --use_sigmoid is True
				if self.p.stopcls:
					stop_logits, token_logits = filtered_logits[:, 0], filtered_logits[:, 1:]
					stop_logits_preds, token_logits_preds = torch.sigmoid(stop_logits) > self.p.cls_thresh, torch.sigmoid(token_logits) > self.p.rule_thresh
					preds = torch.cat([stop_logits_preds.unsqueeze(1), token_logits_preds], dim=1)
				else:
					preds = (torch.sigmoid(filtered_logits) > self.p.rule_thresh)

			else:
				preds = (filtered_logits > self.p.rule_thresh)

			# preds_old = (outputs > 0.)
			# print(filtered_logits, torch.sigmoid(filtered_logits))
			# print(preds, preds_old)
			# print(preds.shape, filtered_logits.shape, preds_old.shape, torch.sigmoid(filtered_logits).shape)
			# import pdb; pdb.set_trace()
			# assert(preds==preds_old)

		if self.p.stopcls:
			if stop_priority == 'rule':
				# truncating preds to remove the cls token predictions (prioritize rule selection)
				preds = preds[:, 1:]
			elif stop_priority == 'cls':
				# zero any rule selection if cls prediciton is positive (priority to stop selection)
				stop_preds, preds = preds[:, 0], preds[:, 1:]
				preds = (1 - stop_preds).unsqueeze(1) * preds
			else:
				raise NotImplementedError

		if self.p.stopsep:
			try:
				all_logits_binary_mask 	  					 	 = torch.zeros_like(filtered_logits).to(device)
				all_logits_binary_mask[x_indices, y_indices] 	 = 1
				idx 		  = torch.arange(all_logits_binary_mask.shape[1]).to(device)
				tmp 		  = all_logits_binary_mask * idx
				sep_y_indices = torch.argmax(tmp, 1)
				sep_x_indices = torch.arange(all_logits_binary_mask.shape[0]).to(device)
				preds[sep_x_indices, sep_y_indices] = 0

			except Exception as e:
				print('************Exception************')
				print(traceback.format_exc())
				sys.stdout = sys.__stdout__; import pdb; pdb.set_trace()


		# Finally, save a padded rule matrix with indices of the rules and the corresponding mask
		pred_mask_lengths = preds.sum(1)
		pred_mask_nonzero = torch.nonzero(preds)
		y_indices         = torch.cat([torch.arange(x) for x in pred_mask_lengths]).to(device)
		x_indices         = pred_mask_nonzero[:, 0]
		filtered_rule_ids = torch.full((input_ids.shape[0], pred_mask_lengths.max()), -1).to(device)
		filtered_rule_ids[x_indices, y_indices] = pred_mask_nonzero[:, 1]
		filtered_mask     = (filtered_rule_ids != -1)

		# Make the -1's -> 0 so that we can select some rule. Given the mask we can always prune this later
		# This step is non-intuitive here. To understand this, we need to consider the for loop in the main decoding logic where the rule_ids are used.
		filtered_rule_ids[~filtered_mask] = 0

		# filtered_rule_ids of size (b*maxrule_ids)
		return filtered_rule_ids, filtered_mask

	def calc_loss(self, outputs, targets, token_mask):
		if self.p.celoss:
			# all rows of target are one hot, ie there is only 1 rule that needs to be selected
			assert torch.all(torch.sum(targets * token_mask, dim=1) == torch.ones(targets.shape[0]).to(targets.device))
			exp_logits = torch.exp(outputs)
			assert exp_logits.shape == token_mask.shape
			masked_exp_logits = exp_logits * token_mask
			norm_masked_exp_logits = masked_exp_logits/torch.sum(masked_exp_logits, dim=1).unsqueeze(-1)
			# convert the 0's to 1 in norm_masked_exp_logits so that log makes it 0
			# can be done by setting those indexes in norm_masked_exp_logits to 1., where token_mask = 0
			zeros_mask = (1 - token_mask).bool() #ones_mask is token_mask with 0 and 1 exchanged
			norm_masked_exp_logits[zeros_mask] = 1. # setting those indices of norm_mask = 1. where ones_mask = 1

			# handling log(0) --> log(small_value) for places where token_mask is 1 and norm_mask is 0
			zeros_mask_ = (norm_masked_exp_logits == 0)
			norm_masked_exp_logits[zeros_mask] = 1e-8

			logvals = torch.log(norm_masked_exp_logits)
			logvals_detached = logvals.detach().clone()
			# no value of logvals should be nan/inf/-inf, else we would have to handle this by adding 1e-8 to some elements of the norm_masked tensor
			assert (not torch.any(torch.isnan(logvals_detached))) and (not torch.any(torch.isinf(logvals_detached))) and (not torch.any(torch.isneginf(logvals_detached)))
			assert torch.sum((logvals!=0)*targets).item() <= torch.sum(targets).item() #The number of non zeros in logvals is <= the number of 1's in the target.
			loss_reduced = F.nll_loss(logvals, torch.nonzero(targets)[:, 1], reduction='mean')
		else:
			loss_not_reduced = F.binary_cross_entropy_with_logits(outputs, targets, reduction='none')
			loss_masked      = loss_not_reduced * token_mask
			loss_reduced     = loss_masked.sum() / token_mask.sum()

		return loss_reduced

	def calc_acc(self, preds, targets, token_mask):
		acc_not_reduced = (preds == targets).float()
		acc_masked      = torch.mul(acc_not_reduced, token_mask)
		acc_reduced     = acc_masked.sum()/token_mask.sum()
		acc             = 100 * acc_reduced
		return acc

	def calc_F1(self, preds, targets, token_mask):
		'''F1 scores between preds and targets, restricted to the positions where token_mask equals 1.

		Returns a dict with the binary F1 of class 1 ('f1_class1'), the binary F1 of class 0
		('f1_class0') and the macro / micro averaged F1 ('macro_f1', 'micro_f1').
		'''
		assert preds.shape == targets.shape
		assert preds.shape == token_mask.shape

		# keep only the entries that the mask marks as relevant, moved to the CPU for the metric computation
		keep   = token_mask.eq(1)
		y_pred = preds.masked_select(keep).cpu()
		y_true = targets.masked_select(keep).cpu()

		variants = (
			('f1_class1', {'pos_label': 1, 'average': 'binary'}),
			('f1_class0', {'pos_label': 0, 'average': 'binary'}),
			('macro_f1',  {'average': 'macro'}),
			('micro_f1',  {'average': 'micro'}),
		)
		scores = {}
		for key, options in variants:
			scores[key] = f1_score(y_true=y_true, y_pred=y_pred, **options)
		return scores

	def calc_perf_metrics(self, preds, targets, token_mask):
		acc       = self.calc_acc(preds, targets, token_mask)
		F1_scores = self.calc_F1(preds, targets, token_mask)

		return {'acc':acc, 'f1_class1':F1_scores['f1_class1'], 'f1_class0':F1_scores['f1_class0'], 'macro_f1':F1_scores['macro_f1'], 'micro_f1':F1_scores['micro_f1']}

	def run_step(self, batch, split):
		outputs    = self(batch['all_sents'], batch['attn_mask'])
		token_mask = batch['all_token_mask']
		targets    = batch['all_token_labels']

		if self.p.celoss:
			relevant_outputs        = outputs * token_mask
			argmax_relevant_outputs = torch.argmax(relevant_outputs, dim=1)
			loss                    = self.calc_loss(outputs.squeeze(), targets.squeeze(), token_mask.squeeze())
			preds                   = (F.one_hot(argmax_relevant_outputs, num_classes=outputs.shape[1])).int()
		elif self.p.multitask:
			token_logits, stop_logits   = outputs
			stop_mask, new_token_mask   = token_mask[:, 0], token_mask[:, 1:]
			stop_targets, token_targets = targets[:, 0], targets[:, 1:]
			assert sum(stop_mask).item() == len(stop_mask)

			stop_loss  = self.calc_loss(stop_logits, stop_targets, stop_mask)
			token_loss = self.calc_loss(token_logits, token_targets, new_token_mask)
			loss       = stop_loss + token_loss
			combined   = torch.cat([stop_logits.unsqueeze(1), token_logits], dim=1)

			if(self.p.use_sigmoid):
				## having different thresholds for cls and rules in case --stopcls is True
				## note: different thresholds this can be used if --use_sigmoid is True
				if self.p.stopcls:
					stop_logits, token_logits = combined[:, 0], combined[:, 1:]
					stop_logits_preds, token_logits_preds = (torch.sigmoid(stop_logits) > self.p.cls_thresh).float().squeeze(), (torch.sigmoid(token_logits) > self.p.rule_thresh).float().squeeze()
					preds = torch.cat([stop_logits_preds.unsqueeze(1), token_logits_preds], dim=1)
				else:
					preds = (torch.sigmoid(combined) > self.p.rule_thresh).float().squeeze()
			else:
				preds = (combined > self.p.rule_thresh).float().squeeze()
		else:
			if(self.p.use_sigmoid):
				## having different thresholds for cls and rules in case --stopcls is True
				## note: different thresholds this can be used if --use_sigmoid is True
				if self.p.stopcls:
					stop_logits, token_logits = outputs[:, 0], outputs[:, 1:]
					stop_logits_preds, token_logits_preds = (torch.sigmoid(stop_logits) > self.p.cls_thresh).float().squeeze(), (torch.sigmoid(token_logits) > self.p.rule_thresh).float().squeeze()
					preds = torch.cat([stop_logits_preds.unsqueeze(1), token_logits_preds], dim=1)
				else:
					preds = (torch.sigmoid(outputs) > self.p.rule_thresh).float().squeeze()

			else:
				preds = (outputs > self.p.rule_thresh).float().squeeze()
			loss  = self.calc_loss(outputs.squeeze(), targets.squeeze(), token_mask.squeeze())

		perf_metrics = self.calc_perf_metrics(preds.squeeze(), targets.squeeze(), token_mask.squeeze())

		if split == 'train':
			self.log(f'train_loss_step', loss.item(), prog_bar=True)
			for metric in perf_metrics.keys():
				self.log(f'train_{metric}_step', perf_metrics[metric], prog_bar=True)
		else:
			self.log(f'{split}_loss_step', loss.item(), prog_bar=True, sync_dist=True)
			for metric in perf_metrics.keys():
				self.log(f'{split}_{metric}_step', perf_metrics[metric], prog_bar=True)

		return {'loss': loss, 'preds': preds, 'targets': targets, 'token_mask': token_mask}

	def aggregate_epoch(self, outputs, split):
		preds        = torch.cat([x['preds'].reshape(-1) for x in outputs])
		targets      = torch.cat([x['targets'].reshape(-1) for x in outputs])
		token_mask   = torch.cat([x['token_mask'].reshape(-1) for x in outputs])
		loss         = torch.stack([x['loss'] for x in outputs]).mean()
		perf_metrics = self.calc_perf_metrics(preds.squeeze(), targets.squeeze(), token_mask.squeeze())

		if split == 'train':
			self.log(f'train_loss_epoch', loss.item())
			for metric in perf_metrics.keys():
				self.log(f'train_{metric}_epoch', perf_metrics[metric], prog_bar=True)
		elif split == 'valid':
			self.log(f'valid_loss_epoch', loss.item(), sync_dist=True)
			for metric in perf_metrics.keys():
				self.log(f'valid_{metric}_epoch', perf_metrics[metric], prog_bar=True)
		elif split == 'test':
			self.log(f'test_loss_epoch', loss.item(), sync_dist=True)
			for metric in perf_metrics.keys():
				self.log(f'test_{metric}_epoch', perf_metrics[metric], prog_bar=True)
			self.predictions = torch.stack((preds, targets), dim=1)
			print('predictions tensor in ruletaker class, shape = {}'.format(self.predictions.shape))

	def configure_optimizers(self):
		no_decay = ['bias', 'LayerNorm.weight']
		optimizer_grouped_parameters = [
			{
				'params'      : [p for n, p in self.text_encoder.named_parameters() if not any(nd in n for nd in no_decay)],
				'weight_decay': self.p.weight_decay,
			},
			{
				'params'      : [p for n, p in self.text_encoder.named_parameters() if any(nd in n for nd in no_decay)],
				'weight_decay': 0.0,
			}
		]

		optimizer_grouped_parameters += [
			{
				'params'      : [p for n, p in self.classifier.named_parameters() if not any(nd in n for nd in no_decay)],
				'weight_decay': self.p.weight_decay,
			},
			{
				'params'      : [p for n, p in self.classifier.named_parameters() if any(nd in n for nd in no_decay)],
				'weight_decay': 0.0,
			}
		]

		if self.p.optimizer == 'adamw':
			optimizer = AdamW(optimizer_grouped_parameters, lr=self.p.learning_rate, eps=self.p.adam_epsilon, betas=[0.9, 0.98])
		else:
			raise NotImplementedError

		if self.p.lr_scheduler == 'linear_with_warmup':
			if self.p.warmup_updates > 1.0:
				warmup_steps = int(self.p.warmup_updates)
			else:
				warmup_steps = int(self.total_steps * self.p.warmup_updates)
			print(f'\nTotal steps: {self.total_steps} with warmup steps: {warmup_steps}\n')

			scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=warmup_steps, num_training_steps=self.total_steps)
			scheduler = {
				'scheduler': scheduler,
				'interval': 'step',
				'frequency': 1
			}
		elif self.p.lr_scheduler == 'fixed':
			return [optimizer]
		else:
			raise NotImplementedError

		return [optimizer], [scheduler]
