from helper import *
from data import DataModule
from proofwriter_ruleselector_model import ProofWriterRuleSelector
from proofwriter_factselector_model import ProofWriterFactSelector
from proofwriter_reasoner_model import ProofWriterReasoner
from proof_inference import ProofWriterInference

# Registry mapping the configured ``model`` name to the model class that main() instantiates
# (and whose ``add_model_specific_args`` parse_args() uses).
model_dict = dict(
	proofwriter_ruleselector=ProofWriterRuleSelector,
	proofwriter_factselector=ProofWriterFactSelector,
	proofwriter_reasoner=ProofWriterReasoner,
	proofwriter_inference=ProofWriterInference,
)

# Maps each model name to the (metric to monitor, mode) pair used by the checkpoint and
# early-stopping callbacks built in get_callbacks(). The ``__main__`` block may replace the
# ruleselector entry at runtime (``--minimize_loss`` switches it to minimizing valid loss).
monitor_dict = dict(
	proofwriter_ruleselector=('valid_macro_f1_epoch', 'max'),
	proofwriter_factselector=('valid_macro_f1_epoch', 'max'),
	proofwriter_reasoner=('valid_acc_epoch', 'max'),
	proofwriter_inference=('valid_acc_epoch', 'max'),
)


def generate_hydra_overrides(config_dir='./configs/'):
	"""Translate ``--override a,b,...`` into Hydra ``group=option`` override strings.

	Every sub-directory of ``config_dir`` is treated as a config group and every
	``<name>.yaml`` file inside it as an option ``<name>`` of that group. Each
	comma-separated value given to ``--override`` is matched against the options of
	all groups, so the order on the command line does not matter, e.g.
	``--override rtx_8000,fixed`` -> ``['<grp>=rtx_8000', '<grp>=fixed']``.
	Values that match no option are reported with a warning instead of being
	silently dropped.

	Args:
		config_dir: Root config directory to scan (default ``./configs/``).

	Returns:
		(parser, overrides): the ``ArgumentParser`` defining ``--override`` (reused as a
		parent parser by parse_args) and the list of override strings.
	"""
	# TODO find a better way to override if possible? Maybe we need to use @hydra.main() for this?
	parser = ArgumentParser()
	parser.add_argument('--override', help='Comma-separated config names overriding the default hydra config, in any order. E.g., --override rtx_8000,fixed')
	args, _ = parser.parse_known_args()

	overrides = []
	if args.override is None:
		return parser, overrides

	# Tolerate spaces around commas ("rtx_8000, fixed") and empty items.
	requested = [val.strip() for val in args.override.split(',') if val.strip()]
	matched = set()

	groups = [x for x in os.listdir(config_dir) if os.path.isdir(os.path.join(config_dir, x))]
	for grp in groups:
		grp_dir = os.path.join(config_dir, grp)
		# Only "<name>.yaml" files are options; strip exactly the trailing suffix.
		confs = {
			os.path.splitext(fname)[0]
			for fname in os.listdir(grp_dir)
			if fname.endswith('.yaml') and os.path.isfile(os.path.join(grp_dir, fname))
		}
		for val in requested:
			if val in confs:
				overrides.append(f'{grp}={val}')
				matched.add(val)

	unmatched = [val for val in requested if val not in matched]
	if unmatched:
		print(f'WARNING: --override values not found in any config group under {config_dir}: {unmatched}')

	return parser, overrides

def load_hydra_cfg(overrides):
	"""Compose the Hydra config named ``config`` from ``./configs/`` with ``overrides`` applied.

	The composed config is printed as YAML and then returned.

	NOTE: ``initialize`` sets up Hydra's global state, so this is presumably meant to be
	called once per process — confirm before reusing it elsewhere.
	"""
	initialize(config_path="./configs/")
	composed = compose("config", overrides=overrides)
	as_yaml = OmegaConf.to_yaml(composed)
	print('Composed hydra config:\n\n', as_yaml)
	return composed

def parse_args(args=None):
	"""Build the full CLI parser, seeded with Hydra config values, and parse ``args``.

	Defaults come from the composed Hydra config. Nested ``DictConfig`` entries are
	flattened into one namespace, and top-level scalar keys are used as-is. Later
	entries win when a key repeats. Explicit command-line flags override these
	defaults. Model-specific arguments are added for the model named by the config's
	``model`` key.

	Args:
		args: Argument list to parse; ``None`` (default) parses ``sys.argv[1:]``.
			``--override`` is still read from ``sys.argv`` when the Hydra config is composed.

	Returns:
		argparse.Namespace with Trainer, model, data and script options.
	"""
	override_parser, overrides = generate_hydra_overrides()
	hydra_cfg                  = load_hydra_cfg(overrides)
	defaults                   = dict()
	for k, v in hydra_cfg.items():
		if isinstance(v, DictConfig):
			# flatten a nested config group into the top-level defaults
			defaults.update(v)
		else:
			defaults[k] = v

	parser = argparse.ArgumentParser(parents=[override_parser], add_help=False)
	parser = pl.Trainer.add_argparse_args(parser)
	parser = model_dict[defaults['model']].add_model_specific_args(parser)
	parser = DataModule.add_data_specific_args(parser)

	parser.add_argument('--seed', 				default=42, 					type=int,)
	parser.add_argument('--name', 				default='test', 				type=str,)
	parser.add_argument('--log_db', 			default='manual_runs', 			type=str,)
	parser.add_argument('--tag_attrs', 			default='model,dataset,arch', 	type=str,)
	parser.add_argument('--ckpt_path', 			default='', 					type=str,)
	parser.add_argument('--eval_splits', 		default='', 					type=str,)
	parser.add_argument('--debug', 				action='store_true')
	parser.add_argument('--offline', 			action='store_true')
	parser.add_argument('--save_checkpoint', 	action='store_true')
	parser.add_argument('--resume_training', 	action='store_true')
	parser.add_argument('--evaluate_ckpt', 		action='store_true')
	parser.add_argument('--save_predictions', 	default=None, 					type=str,)
	parser.add_argument('--restore_configs', 	action='store_true')

	# Hydra values become the defaults; anything given on the command line still wins.
	parser.set_defaults(**defaults)

	# Honour an explicitly supplied argument list (previously ignored); None -> sys.argv[1:].
	return parser.parse_args(args)

def get_callbacks(args):
	"""Create the checkpointing and early-stopping callbacks for ``args.model``.

	Both callbacks watch the ``(metric, mode)`` pair registered for the model in
	``monitor_dict``. ``ModelCheckpoint`` keeps only the single best checkpoint in
	``<args.root_dir>/checkpoints`` (no separate "last" checkpoint). ``EarlyStopping``
	uses ``min_delta=0`` and a patience of 5.

	Returns:
		``[checkpoint_callback, early_stop_callback]`` in that order.
	"""
	metric_name, metric_mode = monitor_dict[args.model]
	checkpoint_dir = os.path.join(args.root_dir, 'checkpoints')

	checkpointer = ModelCheckpoint(
		monitor=metric_name,
		dirpath=checkpoint_dir,
		save_top_k=1,
		mode=metric_mode,
		verbose=True,
		save_last=False,
	)
	stopper = EarlyStopping(
		monitor=metric_name,
		min_delta=0.00,
		patience=5,
		verbose=False,
		mode=metric_mode,
	)
	return [checkpointer, stopper]

def restore_config_params(model, args):
	"""Overwrite selected settings of a loaded inference model with the current config values.

	Only acts when ``args.restore_configs`` is set. For ``args.model ==
	'proofwriter_inference'`` it copies the following values from ``args``:

	- ``stopcls``, ``stopsep``, ``celoss``, ``rule_thresh``, ``use_sigmoid`` and ``topk``
	  onto ``model.rule_selector.p``;
	- ``fact_thresh`` and ``use_sigmoid`` onto ``model.fact_selector.p``.

	Returns:
		The same ``model`` object (mutated in place when applicable).
	"""
	# TODO Maybe we might require this in future if we want to overwrite some of the ckpt-loaded params
	if not args.restore_configs:
		return model

	print('*************Restore configs is True******************')
	if args.model != 'proofwriter_inference':
		return model

	rule_p = model.rule_selector.p
	for attr in ('stopcls', 'stopsep', 'celoss', 'rule_thresh', 'use_sigmoid', 'topk'):
		setattr(rule_p, attr, getattr(args, attr))

	fact_p = model.fact_selector.p
	for attr in ('fact_thresh', 'use_sigmoid'):
		setattr(fact_p, attr, getattr(args, attr))

	return model


def _common_model_kwargs(args):
	"""Return the optimisation/batching constructor kwargs shared by the trainable models."""
	return dict(
		arch=args.arch,
		train_batch_size=args.train_batch_size,
		eval_batch_size=args.eval_batch_size,
		accumulate_grad_batches=args.accumulate_grad_batches,
		learning_rate=args.learning_rate,
		max_epochs=args.max_epochs,
		optimizer=args.optimizer,
		adam_epsilon=args.adam_epsilon,
		weight_decay=args.weight_decay,
		lr_scheduler=args.lr_scheduler,
		warmup_updates=args.warmup_updates,
		freeze_epochs=args.freeze_epochs,
		gpus=args.gpus,
	)


def _inference_model_kwargs(args):
	"""Return the constructor kwargs shared by the ``proofwriter_inference*`` models."""
	return dict(
		ruleselector_ckpt=args.ruleselector_ckpt,
		factselector_ckpt=args.factselector_ckpt,
		reasoner_ckpt=args.reasoner_ckpt,
		ques_augmented=args.ques_augmented,
		**_common_model_kwargs(args),
		evaluate_pw_iter=args.evaluate_pw_iter,
		rule_stopsep=args.stopsep,
		rule_stopcls=args.stopcls,
		stop_priority=args.stop_priority,
		root_dir=args.root_dir,
		dumptext=args.dumptext,
		eval_pararules=('pararules' in args.dataset),
	)


def _build_model(args, dm):
	"""Instantiate the ``model_dict`` class for ``args.model`` with its constructor arguments.

	NOTE(review): ``proofwriter_iterative``, ``proofwriter_iterative_eval``,
	``proofwriter_inference_analysis``, ``ruletaker``, ``qasc`` and ``mac`` have
	branches here but are not registered in ``model_dict``, so selecting them
	currently raises ``KeyError``.

	Raises:
		ValueError: if ``args.model`` has no construction branch.
	"""
	model_name = args.model

	if model_name == 'proofwriter_ruleselector':
		return model_dict[model_name](
			**_common_model_kwargs(args),
			hf_name=args.hf_name,
			celoss=args.celoss,
			select_thresh=args.rule_thresh,
			stopcls=args.stopcls,
			stopsep=args.stopsep,
			multitask=args.multitask,
			bert_init=args.bert_init,
			cls_dropout=args.cls_dropout,
			use_sigmoid=args.use_sigmoid,
			topk=args.topk,
			cls_thresh=args.cls_thresh,
			num_logit_layers=args.num_logit_layers,
		)

	if model_name == 'proofwriter_factselector':
		return model_dict[model_name](
			**_common_model_kwargs(args),
			hf_name=args.hf_name,
			select_thresh=args.fact_thresh,
			use_sigmoid=args.use_sigmoid,
		)

	if model_name in ('proofwriter_iterative', 'proofwriter_reasoner', 'ruletaker', 'qasc'):
		return model_dict[model_name](**_common_model_kwargs(args), hf_name=args.hf_name)

	if model_name == 'proofwriter_iterative_eval':
		return model_dict[model_name](
			train_batch_size=args.train_batch_size,
			eval_batch_size=args.eval_batch_size,
			accumulate_grad_batches=args.accumulate_grad_batches,
			max_epochs=args.max_epochs,
			gpus=args.gpus,
			hf_name=args.hf_name,
			root_dir=args.root_dir,
		)

	if model_name == 'proofwriter_inference':
		return model_dict[model_name](**_inference_model_kwargs(args))

	if model_name == 'proofwriter_inference_analysis':
		return model_dict[model_name](**_inference_model_kwargs(args), analysis=args.analysis)

	if model_name == 'mac':
		# find the vocab size and decide the number of input dimensions
		# NOTE: If using GloVe initializations, we also need to ensure to maintain the word
		# order according to the word2idx mapping.
		n_vocab = len(dm.word2idx)
		return model_dict[model_name](
			n_vocab,
			args.inp_dim,
			**_common_model_kwargs(args),
			embed_hidden=args.embed_hidden,
			max_step=args.max_step,
			self_attention=args.self_attention,
			memory_gate=args.memory_gate,
			pos_aware=args.pos_aware,
			dropout=args.dropout,
		)

	# Previously an unknown name fell through and crashed with UnboundLocalError on `model`.
	raise ValueError(f'Unsupported model: {model_name!r}')


def main(args, splits='all'):
	"""Build the datamodule, model and Lightning trainer described by ``args``.

	Side effects on ``args`` and the environment:

	- seeds the RNGs via ``pl.seed_everything(args.seed)``;
	- with ``--debug``, limits train/val/test to 10 batches and ``max_epochs`` to 2;
	- sets ``args.root_dir = ../saved/<args.name>`` and, for ``proofwriter_inference``
	  (whose constructor receives it), creates that directory.

	Args:
		args: Namespace produced by ``parse_args``.
		splits: Splits passed to ``DataModule.setup``: ``'all'`` or a list such as ``['dev', 'test']``.

	Returns:
		(dm, model, trainer)

	Raises:
		ValueError: if ``args.model`` has no construction branch.
	"""
	pl.seed_everything(args.seed)

	if args.debug:
		# for DEBUG purposes only
		args.limit_train_batches = 10
		args.limit_val_batches   = 10
		args.limit_test_batches  = 10
		args.max_epochs          = 2

	args.root_dir = f'../saved/{args.name}'
	if args.model == 'proofwriter_inference':
		# makedirs also creates a missing '../saved' parent (os.mkdir failed there); without
		# exist_ok it still refuses to reuse an existing run directory.
		os.makedirs(args.root_dir)
	print(f'Saving to {args.root_dir}')

	print('Building trainer...')
	trainer = pl.Trainer.from_argparse_args(
		args,
		callbacks=get_callbacks(args),
		num_sanity_val_steps=0,
	)

	# The ruleselector's dataloader asserts on the tokenizer's SEP/CLS ids; -1 is a placeholder
	# for every other model.
	if args.model == 'proofwriter_ruleselector':
		tokenizer    = AutoTokenizer.from_pretrained(args.hf_name)
		sep_token_id = tokenizer.sep_token_id
		cls_token_id = tokenizer.cls_token_id
	else:
		sep_token_id = -1
		cls_token_id = -1

	print(f'Loading {args.dataset} dataset')
	dm = DataModule(
		args.dataset,
		args.train_dataset,
		args.dev_dataset,
		args.test_dataset,
		args.arch,
		train_batch_size=args.train_batch_size,
		eval_batch_size=args.eval_batch_size,
		num_workers=args.num_workers,
		pad_idx=args.padding,
		stopcls=args.stopcls,
		stopsep=args.stopsep,
		sep_token_id=sep_token_id,
		cls_token_id=cls_token_id,
	)
	dm.setup(splits=splits)

	print(f'Loading {args.model} - {args.arch} model...')
	model = _build_model(args, dm)

	return dm, model, trainer


if __name__ == '__main__':
	start_time         = time.time()
	args               = parse_args()
	args.name          = f'{args.model}_{args.dataset}_{args.arch}_{time.strftime("%d_%m_%Y")}_{str(uuid.uuid4())[: 8]}'

	# Models evaluated directly as constructed by main(), i.e. without loading --ckpt_path.
	ckpt_free_models  = ('proofwriter_inference', 'proofwriter_inference_analysis', 'proofwriter_iterative_eval')
	# Splits the evaluation loop below knows how to handle.
	known_eval_splits = ('train', 'dev', 'test')

	# sanity check
	if args.resume_training:
		assert args.ckpt_path != ''
	if args.evaluate_ckpt:
		if args.model in ckpt_free_models:
			pass
		elif args.arch == 'roberta_large_race' and args.model == 'qasc':
			assert args.hf_name != ''
		else:
			assert args.ckpt_path != ''
		assert args.eval_splits != ''
		# Fail fast: unknown splits were previously skipped silently (and stale predictions saved).
		unknown_splits = [s for s in args.eval_splits.split(',') if s not in known_eval_splits]
		assert not unknown_splits, f'Unsupported eval splits {unknown_splits}; expected a subset of {known_eval_splits}'
	if args.multitask:
		# ruleselector has two separate classifiers for stop prediction and rule prediction
		# CE loss cannot be used. Also, using SEP for stop is hard for this. So, default to using CLS for stop
		assert not args.celoss
		assert not args.stopsep
		assert args.stopcls
	if args.analysis != 'default':
		assert args.model == 'proofwriter_inference_analysis'

	# Update trainer specific args that are used internally by Trainer (which is initialized from_argparse_args)
	args.precision = 16 if args.fp16 else 32
	if args.resume_training:
		args.resume_from_checkpoint = args.ckpt_path

	# update the monitor dict (if applicable)
	if args.minimize_loss:
		# change the metric to minimize loss instead of maximizing accuracy/F1
		monitor_dict['proofwriter_ruleselector'] = ('valid_loss_epoch', 'min')

	# Load the datamodule, model, and trainer used for training (or evaluation)
	eval_splits = args.eval_splits.split(',')
	if not args.evaluate_ckpt:
		dm, model, trainer = main(args)
	else:
		dm, model, trainer = main(args, splits=eval_splits)

	print(vars(args))

	if not args.evaluate_ckpt:
		# train the model from scratch (or resume training from the checkpoint)
		trainer.fit(model, dm)
		print('Testing the best model...')
		trainer.test(ckpt_path='best')
		best_model_path = trainer.checkpoint_callback.best_model_path
		# best_model_path is '' when no checkpoint was saved; os.remove('') would crash here.
		if not args.save_checkpoint and best_model_path and os.path.isfile(best_model_path):
			os.remove(best_model_path)
	else:
		# evaluate the pretrained model on the provided splits
		if (args.arch == 'roberta_large_race' and args.model == 'qasc') or args.model in ckpt_free_models:
			model_ckpt = model
		else:
			model_ckpt = model.load_from_checkpoint(args.ckpt_path)
		model_ckpt = restore_config_params(model_ckpt, args)
		print('Testing the best model...')
		for split in eval_splits:
			print(f'Evaluating on split: {split}')
			if split == 'train':
				# NOTE(review): a train loader is built but no validate/test call runs on it, so
				# nothing is evaluated for this split (and any saved predictions are stale).
				loader = dm.train_dataloader(shuffle = False)
			elif split == 'dev':  # shuffle is false by default for dev and test
				loader = dm.val_dataloader()
				trainer.validate(model=model_ckpt, val_dataloaders=loader)
			elif split == 'test':
				loader = dm.test_dataloader()
				trainer.test(model=model_ckpt, test_dataloaders=loader)

			if args.save_predictions:
				out_path = args.save_predictions
				if len(eval_splits) > 1:
					# one file per split; otherwise each split overwrote the previous split's file
					root, ext = os.path.splitext(out_path)
					out_path = f'{root}_{split}{ext}'
				pred_targets = model_ckpt.predictions
				df = pd.DataFrame(pred_targets.cpu().numpy())
				# save the predictions and target dataframe as a csv file
				df.columns = ["preds", "targets"]
				df.to_csv(out_path, index = False)

	print(f'Time Taken for experiment : {(time.time()-start_time) / 3600}h')
