# CUDA_VISIBLE_DEVICES=1 python3 -m torch.distributed.launch --master_port 6677 \
#     --nproc_per_node=1 train_dense_encoder.py \
#     --max_grad_norm 2.0 \
#     --encoder_model_type hf_bert \
#     --pretrained_model_cfg bert-base-multilingual-uncased \
#     --seed 12345 --sequence_length 256 \
#     --do_lower_case \
#     --warmup_steps 1237 \
#     --batch_size 64 \
#     --gradient_accumulation_steps 1 \
#     --gradient_checkpointing \
#     --train_file 'data/xorqa_dpr_data_query=L_hard_negative=1/dpr_train_data.json' \
#     --dev_file 'data/xorqa_dpr_data_query=L_hard_negative=1/dpr_dev_data.json' \
#     --output_dir ./checkpoints/mdpr_xorqa_bs32_ga1 \
#     --learning_rate 2e-05 --num_train_epochs 40 \
#     --val_av_rank_start_epoch 0 \
#     --dev_batch_size 6 \
#     --save_epochs 5 \
#     --fp16

# Fail fast: abort on the first failing command (-e) and on any reference to an
# unset variable (-u). Without -u a mistyped variable name expands to nothing and
# a path such as "${CHECKPOINT_DIR}/${EMBEDDING_DIR}/..." silently becomes "/...".
# Only the portable flags are used so the script still runs under plain `sh`.
set -eu

# Directory passed to the training stage as --output_dir; it is also the parent
# of the checkpoint and embedding paths used by the later stages.
CHECKPOINT_DIR="checkpoints/mcontriever_xor_retrieve_mt5_c16n100_bs16_ga1_inbatch_cache_avg_templabel10"
# Checkpoint file name inside CHECKPOINT_DIR that is loaded via --model_file.
# NOTE(review): presumably one of the files written by the training run --
# confirm the name matches what training actually produces.
CHECKPOINT_FILE="dpr_biencoder.9.954"
# Sub-directory of CHECKPOINT_DIR that receives the embeddings and the
# retrieval output file.
EMBEDDING_DIR="embeddings_enwiki_e9"
# Value assigned to CUDA_VISIBLE_DEVICES for every stage.
DEVICES="3"

# Stage 1: run train_dense_encoder_with_llm.py with CUDA_VISIBLE_DEVICES set to
# DEVICES, passing CHECKPOINT_DIR as --output_dir.
#
# The later stages load "${CHECKPOINT_DIR}/${CHECKPOINT_FILE}", so a failed
# training run must stop the script instead of letting them run against a
# missing or stale checkpoint; the guard at the end re-raises the exit status.
# Expansions are quoted so a path containing spaces or glob characters is passed
# as a single argument.
CUDA_VISIBLE_DEVICES="${DEVICES}" python3 train_dense_encoder_with_llm.py \
    --max_grad_norm 2.0 \
    --encoder_model_type hf_bert \
    --pretrained_model_cfg facebook/mcontriever \
    --seed 12345 \
    --sequence_length 256 \
    --warmup_steps 1237 \
    --num_contexts 16 \
    --batch_size 16 \
    --gradient_accumulation_steps 1 \
    --inbatch_negative \
    --temperature 10 \
    --train_file "../unsupervised-passage-reranking/outputs/xor_retrieve_train_mcontriever-mt5-100/rank0.json" \
    --dev_file "data/xorqa_dpr_data_query=L_hard_negative=1/dpr_dev_data.json" \
    --output_dir "${CHECKPOINT_DIR}" \
    --learning_rate 2e-05 \
    --num_train_epochs 10 \
    --dev_batch_size 12 \
    --val_av_rank_start_epoch 0 \
    --global_loss_buf_sz 2000000 \
    --eval_per_epoch 4 \
    --grad_cache \
    --q_chunk_size 16 \
    --ctx_chunk_size 8 \
    --restart \
    --fp16 \
    --wandb_project "UMR" \
    --wandb_name "mcontriever-xor_retrieve_mt5-c16n100_bs16_ga1_inbatch_cache_avg_templabel10" \
    || {
        status=$?
        printf 'error: training stage failed with exit status %s\n' "${status}" >&2
        exit "${status}"
    }

    # python3 -m torch.distributed.launch --nproc_per_node 1 --master_port 6696
    # --pretrained_model_cfg bert-base-multilingual-uncased \
    # --do_lower_case \
    # --model_file "checkpoints/mcontriever_mt5_c8_ga16_iter2/dpr_biencoder.9.15250" \
    # --model_file "checkpoints/mdpr_nq/dpr_biencoder_best.pt" \
    # --model_file "checkpoints/mdpr_nq_mt5_c8_ga16/dpr_biencoder.9.15250" \

# Stage 2: run generate_dense_embeddings.py over data/enwiki_20190201_w100.tsv
# using the checkpoint "${CHECKPOINT_DIR}/${CHECKPOINT_FILE}", with
# "${CHECKPOINT_DIR}/${EMBEDDING_DIR}/enwiki_emb" as the --out_file prefix.
#
# Uses python3 like the other stages (a bare `python` may be missing or resolve
# to a different interpreter). Path expansions are quoted so they stay single
# arguments. The retrieval stage reads the files produced here, so a failure
# stops the script with the same exit status.
CUDA_VISIBLE_DEVICES="${DEVICES}" python3 generate_dense_embeddings.py \
    --model_file "${CHECKPOINT_DIR}/${CHECKPOINT_FILE}" \
    --encoder_model_type hf_bert \
    --sequence_length 256 \
    --batch_size 256 \
    --ctx_file data/enwiki_20190201_w100.tsv \
    --shard_id 0 --num_shards 1 \
    --out_file "${CHECKPOINT_DIR}/${EMBEDDING_DIR}/enwiki_emb" \
    --fp16 \
    || {
        status=$?
        printf 'error: embedding generation failed with exit status %s\n' "${status}" >&2
        exit "${status}"
    }
    
    # --ctx_file data/psgs_w100.tsv \
    # --pretrained_model_cfg facebook/contriever \
    # --model_file checkpoints/mdpr_nq/dpr_biencoder_best.pt \

# Stage 3: run dense_retriever.py for the questions in
# data/xor_dev_retrieve_eng_span_v1_1.jsonl, using the checkpoint
# "${CHECKPOINT_DIR}/${CHECKPOINT_FILE}" and writing to
# "${CHECKPOINT_DIR}/${EMBEDDING_DIR}/xor_dev_retrieve_eng.json".
#
# All path expansions are quoted so they are passed as single arguments.
# The --encoded_ctx_file pattern is quoted on purpose: the shell hands the
# literal "enwiki_emb_*" string to the program unexpanded.
# NOTE(review): presumably dense_retriever.py expands that glob itself -- confirm.
CUDA_VISIBLE_DEVICES="${DEVICES}" python3 dense_retriever.py \
    --model_file "${CHECKPOINT_DIR}/${CHECKPOINT_FILE}" \
    --encoder_model_type hf_bert \
    --sequence_length 256 \
    --ctx_file data/enwiki_20190201_w100.tsv \
    --qa_file data/xor_dev_retrieve_eng_span_v1_1.jsonl \
    --encoded_ctx_file "${CHECKPOINT_DIR}/${EMBEDDING_DIR}/enwiki_emb_*" \
    --out_file "${CHECKPOINT_DIR}/${EMBEDDING_DIR}/xor_dev_retrieve_eng.json" \
    --n-docs 100 \
    --validation_workers 1 --batch_size 128 --search_batch_size 512
