import argparse
def _str2bool(value):
    """Convert a command-line string into a bool for argparse ``type=``.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--flag False`` used to
    silently *enable* the flag. This converter maps the usual spellings
    explicitly and rejects anything else.

    Args:
        value: The raw command-line string (or an already-boolean default).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "boolean value expected (true/false, yes/no, 1/0), got {!r}".format(value))


# Module-level parser; it is parsed further down in this file and every option
# is re-exported as a module global.
parser = argparse.ArgumentParser()
parser.add_argument("--persona", action="store_true")
parser.add_argument("--hidden_dim", type=int, default=300)
parser.add_argument("--emb_dim", type=int, default=300)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--max_grad_norm", type=float, default=2.0)
parser.add_argument("--max_grad_norm_inner", type=float, default=10.0)
parser.add_argument("--max_enc_steps", type=int, default=400)
parser.add_argument("--max_dec_steps", type=int, default=20)
parser.add_argument("--min_dec_steps", type=int, default=5)
parser.add_argument("--beam_size", type=int, default=5)
parser.add_argument("--save_path", type=str, default="outputs")
parser.add_argument("--save_tb", type=str, default="logs_tensorboard")
parser.add_argument("--pt_model", type=str, default="pt_model", help="path of models saved")
parser.add_argument("--save_pt_data", type=str, default="pt_data")
parser.add_argument("--cuda", action="store_true")
parser.add_argument("--pointer_gen", type=_str2bool, default=True)
parser.add_argument("--is_coverage", action="store_true")
parser.add_argument("--use_oov_emb", action="store_true")
parser.add_argument("--pretrain_emb", action="store_true")
parser.add_argument("--print_params", action="store_true")
# NOTE(review): the settings section below currently overrides this with a
# hard-coded value, so this option has no effect.
parser.add_argument("--texar_config_model", type=str, default="config_model")
parser.add_argument("--test", action="store_true")  # Means meta-testing.
parser.add_argument("--model", type=str, default="trs")
parser.add_argument("--weight_sharing", type=_str2bool, default=True)
parser.add_argument("--label_smoothing", action="store_true")
parser.add_argument("--noam", action="store_true")
parser.add_argument("--universal", type=_str2bool, default=True)
parser.add_argument("--act", action="store_true")
parser.add_argument("--act_loss_weight", type=float, default=0.001)
parser.add_argument("--test_print_lossppl", action="store_true")
parser.add_argument("--check_sentence", type=_str2bool, default=False)
parser.add_argument("--top_k", type=int, default=20)

# * Transformer
parser.add_argument("--hop", type=int, default=6)
parser.add_argument("--heads", type=int, default=4)
parser.add_argument("--depth", type=int, default=40)
parser.add_argument("--filter", type=int, default=50)
parser.add_argument("--ori_trs", action='store_true')  # Use the original transformer, not texar.
parser.add_argument("--softmax_temperature_for_topk_decoding", type=float, default=0.7)
parser.add_argument("--use_beam_search", action='store_true')
parser.add_argument("--ft_iters", type=int, default=11)  # Total epochs for the texar fine-tune part.

# * Meta-learning
parser.add_argument("--fix_dialnum_train", type=_str2bool, default=False)
parser.add_argument("--dialnum", type=int, default=1)
parser.add_argument("--forget_test", action="store_true")
parser.add_argument("--meta_lr", type=float, default=0.1)
# "--meta_interation" is the historical (misspelled) name; "--meta_iteration"
# is accepted as an alias. Both write to the same dest.
parser.add_argument("--meta_interation", "--meta_iteration", dest="meta_interation", type=int, default=1)
parser.add_argument("--base_optimizer", type=str, default="sgd")
parser.add_argument('--meta_batch_size', type=int, default=1)
parser.add_argument("--meta_optimizer", type=str, default="sgd")
parser.add_argument("--load_frompretrain", type=str, default=None)
parser.add_argument("--k_shot", type=int, default=20)
parser.add_argument("--use_sgd", action="store_true")

# * Memory
parser.add_argument("--use_memory", type=_str2bool, default=True)
parser.add_argument("--memory_path", type=str, default=None)
parser.add_argument("--adapt_num", type=int, default=10)  # Number of adaptation epochs for training the local adaptation model.
parser.add_argument("--neighbor_num", type=int, default=3)  # KNN hyperparameter.
parser.add_argument("--meta_testing", action="store_true")  # Differentiates meta-training vs meta-testing.
parser.add_argument("--fix_decoder", action="store_true")  # Whether to freeze the decoder during backprop.
parser.add_argument("--min_memory", type=int, default=500)  # Memory length constraint: exceeding min_memory satisfies one condition for local adaptation.
# parser.add_argument("--LAper_step",type=int,default=20) # Run local adaptation per xx steps. If it's 1 then run it every time.
parser.add_argument("--mm_save_per", type=int, default=500)  # Save memory every 500 steps.
# parser.add_argument("--mm_refresh_per",type=int,default=200) # Refresh memory per 200 steps.
parser.add_argument("--min_iter", type=int, default=0)  # Iteration constraint: after this many, adapt for every sample.
parser.add_argument("--use_l2", type=_str2bool, default=True)  # Use the L2 norm to compute distance when binding.
parser.add_argument("--store_ratio", type=float, default=0.5)  # Store a certain ratio of memory for each task.
parser.add_argument("--Tau", type=float, default=0.001)  # Variance-difference threshold for deciding whether to push pairs into memory.
parser.add_argument("--all_tasks", type=_str2bool, default=False)  # Whether to use other tasks' information to help predict one task.
parser.add_argument("--rnn_hidden", type=int, default=300)  # Sentence embedding output dim, same as the LM encoder output size.
parser.add_argument("--adapt_query", type=_str2bool, default=True)  # Whether to also do memory adaptation on the meta-training query set.
parser.add_argument("--adapt_support", type=_str2bool, default=True)  # Whether to also do memory adaptation on the meta-training support set.
parser.add_argument("--store_query", type=_str2bool, default=False)  # Whether to store meta-training query set information.
parser.add_argument("--use_retrieval", type=_str2bool, default=False)  # Whether to use KNN retrieval results directly for prediction.
parser.add_argument("--step_apart", type=int, default=1)  # Interval steps between memory storing and memory reading.
parser.add_argument("--load_memory", type=str, default=None)  # Presumably a path to a saved memory to load (None = don't load) — TODO confirm.
parser.add_argument("--print_time", type=_str2bool, default=False)  # Whether to print time consumption.
parser.add_argument("--memory_refresh", type=_str2bool, default=False)  # Whether to rewrite all memory every 1k steps.
parser.add_argument("--add_noise", type=_str2bool, default=False)  # Whether to add noise to the queries' encoder outputs.
parser.add_argument("--noise_sigma", type=float, default=1)  # Presumably the scale of the noise used with --add_noise — TODO confirm.
parser.add_argument("--only_train_query", type=_str2bool, default=False)  # Presumably: train only on query sets — TODO confirm.
parser.add_argument("--train_evaluate", type=_str2bool, default=False)  # Evaluate consistency and BLEU scores during meta-training.
parser.add_argument("--patience", type=int, default=150)  # Patience controlling early-stop time.
parser.add_argument("--add_regularizer", type=_str2bool, default=False)  # Whether to use meta-regularization.
parser.add_argument("--use_gui_wordemb", type=_str2bool, default=False)  # Use guidance word embedding with global parameters.
parser.add_argument("--no_gui", type=_str2bool, default=False)  # Do not use guidance embedding with global parameters.
parser.add_argument("--no_overlap_tr_val", type=_str2bool, default=False)  # No overlap between support sets and query sets.
parser.add_argument("--setting", type=str, default=None)  # Switch between different settings of modified functions.
parser.add_argument("--fix_max_enc_len", type=_str2bool, default=False)  # Fix max encoder length for feasible mixup of support and query.
parser.add_argument("--fix_max_dec_len", type=_str2bool, default=False)  # Fix max decoder length for feasible mixup of support and query.
parser.add_argument("--mixup", type=_str2bool, default=False)  # Presumably: whether to mix up support and query sets — TODO confirm.
# NOTE(review): described as a mix *ratio* but parsed as a boolean — confirm
# whether this should be a float.
parser.add_argument("--mix_ratio", type=_str2bool, default=False)  # Mix ratio for mixup of support and query.
parser.add_argument("--weighted_value", type=_str2bool, default=False)  # Weight retrieved values by their similarity score.
parser.add_argument("--reconstruct", type=_str2bool, default=False)  # Whether to add a reconstruction loss for guidance learning.
parser.add_argument("--inner_alter_update", type=_str2bool, default=False)  # Whether to alternately update other parameters in the inner loop.
parser.add_argument("--concat_mode", type=str, default='src_emb')  # How guidance embeddings are concatenated.
parser.add_argument("--other_lr", type=float, default=0.005)  # Learning rate for the RNN / NN optimization.
parser.add_argument("--from_pretrained", type=_str2bool, default=False)  # Whether to initialize the transformer with no-persona pretrained weights.
parser.add_argument("--only_store_cur_task", type=_str2bool, default=False)  # If true, only store current support-set information into memory.
parser.add_argument("--data_mode", type=str, default='sup_b_que')  # Ways to read support and query datasets in the balanced function.
parser.add_argument("--only_binding", type=_str2bool, default=False)  # Skip local adaptation; use the embedding predicted by the binding network.
parser.add_argument("--no_neighbors", type=_str2bool, default=False)  # Ablation study: random neighbors.
parser.add_argument("--no_diverse", type=_str2bool, default=False)  # Ablation study: random memory update.
parser.add_argument("--generate_steps", type=int, default=1)  # Meaning not evident from this file — TODO document.

# ===================================== Parameter settings =====================================
# Parse the command line once at import time, echo it, and re-export every
# option as a module-level name (e.g. ``batch_size``) for other modules to read.
arg = parser.parse_args()
print(arg)

# ---- General / model selection -------------------------------------------------------------
model = arg.model
persona = arg.persona
# NOTE: --texar_config_model is parsed but deliberately overridden with this
# hard-coded module path; the command-line value is ignored.
texar_config_model = 'My_texar.config_model'
ori_trs = arg.ori_trs
top_k = arg.top_k
softmax_temperature_for_topk_decoding = arg.softmax_temperature_for_topk_decoding
forget_test = arg.forget_test
use_beam_search = arg.use_beam_search

# ---- Core hyperparameters -------------------------------------------------------------------
hidden_dim = arg.hidden_dim
emb_dim = arg.emb_dim
batch_size = arg.batch_size
lr = arg.lr

max_enc_steps = arg.max_enc_steps
# Exported under both spellings (presumably because callers use either one).
max_dec_steps = arg.max_dec_steps
max_dec_step = max_dec_steps
min_dec_steps = arg.min_dec_steps
beam_size = arg.beam_size

test_print_lossppl = arg.test_print_lossppl
check_sentence = arg.check_sentence

# Fixed constants that are not exposed on the command line.
adagrad_init_acc = 0.1
rand_unif_init_mag = 0.02
trunc_norm_init_std = 1e-4
max_grad_norm = arg.max_grad_norm
max_grad_norm_inner = arg.max_grad_norm_inner

USE_CUDA = arg.cuda
pointer_gen = arg.pointer_gen
is_coverage = arg.is_coverage
use_oov_emb = arg.use_oov_emb
cov_loss_wt = 1.0
lr_coverage = 0.15
eps = 1e-12
epochs = 8000

# Special-token vocabulary indices: UNK=0, PAD=1, EOS=2, SOS=3, GUI=4.
UNK_idx, PAD_idx, EOS_idx, SOS_idx, GUI_idx = range(5)

emb_file = "vectors/glove.6B.{}d.txt".format(emb_dim)
# NOTE: the misspelled name "preptrained" is kept because other modules may import it.
preptrained = arg.pretrain_emb

# ---- Output / checkpoint paths --------------------------------------------------------------
save_path = arg.save_path
save_tb = arg.save_tb
save_pt_data = arg.save_pt_data
pt_model = arg.pt_model

test = arg.test
if not test:
    save_path_dataset = save_path
# NOTE(review): save_path_dataset stays undefined when --test is given; any
# reader of it in test mode would hit a NameError — confirm this is intended.

# ---- Transformer ----------------------------------------------------------------------------
hop = arg.hop
heads = arg.heads
depth = arg.depth
filter = arg.filter  # NOTE: shadows the builtin filter() in this module; name kept for callers.
ft_iters = arg.ft_iters

label_smoothing = arg.label_smoothing
weight_sharing = arg.weight_sharing
noam = arg.noam
universal = arg.universal
act = arg.act
act_loss_weight = arg.act_loss_weight

# ---- Meta-learning --------------------------------------------------------------------------
meta_lr = arg.meta_lr
meta_iteration = arg.meta_interation  # CLI dest keeps the historical spelling.
meta_batch_size = arg.meta_batch_size
base_optimizer = arg.base_optimizer
meta_optimizer = arg.meta_optimizer
fix_dialnum_train = arg.fix_dialnum_train
dialnum = arg.dialnum
load_frompretrain = arg.load_frompretrain
k_shot = arg.k_shot
use_sgd = arg.use_sgd

# ---- Memory ---------------------------------------------------------------------------------
# NOTE: --fix_decoder, --min_memory and --mm_save_per are accepted on the
# command line but are intentionally not re-exported here.
use_memory = arg.use_memory
memory_path = arg.memory_path
adapt_num = arg.adapt_num
print_params = arg.print_params
neighbor_num = arg.neighbor_num
meta_testing = arg.meta_testing
min_iter = arg.min_iter
use_l2 = arg.use_l2
store_ratio = arg.store_ratio
Tau = arg.Tau
all_tasks = arg.all_tasks
rnn_hidden = arg.rnn_hidden
adapt_query = arg.adapt_query
store_query = arg.store_query
adapt_support = arg.adapt_support
use_retrieval = arg.use_retrieval
step_apart = arg.step_apart
load_memory = arg.load_memory
print_time = arg.print_time
memory_refresh = arg.memory_refresh
add_noise = arg.add_noise
noise_sigma = arg.noise_sigma
only_train_query = arg.only_train_query
train_evaluate = arg.train_evaluate
patience = arg.patience
add_regularizer = arg.add_regularizer
use_gui_wordemb = arg.use_gui_wordemb
no_gui = arg.no_gui
no_overlap_tr_val = arg.no_overlap_tr_val
setting = arg.setting
fix_max_enc_len = arg.fix_max_enc_len
fix_max_dec_len = arg.fix_max_dec_len
mixup = arg.mixup
mix_ratio = arg.mix_ratio
weighted_value = arg.weighted_value
reconstruct = arg.reconstruct
inner_alter_update = arg.inner_alter_update
concat_mode = arg.concat_mode
other_lr = arg.other_lr
from_pretrained = arg.from_pretrained
only_store_cur_task = arg.only_store_cur_task
data_mode = arg.data_mode
only_binding = arg.only_binding
no_neighbors = arg.no_neighbors
no_diverse = arg.no_diverse
generate_steps = arg.generate_steps