# Pyramid-BERT Code to Reproduce the Experimental Results in the ACL Submission
This code is built on the TensorFlow official models repository: https://github.com/tensorflow/models/tree/master/official

The sections below walk through the steps required to reproduce the experimental results in the paper submission.

## Contents
  * [Contents](#contents)
  * [Set Up Computing Environment](#set-up-computing-environment)
  * [Process Datasets](#process-datasets)
  * [Fine-tuning with Pyramid-BERT](#fine-tuning-with-pyramid-bert)



## Set Up Computing Environment

```shell
export PYTHONPATH="$PYTHONPATH:/path/to/models"
```

Install `tf-nightly-gpu` to get the latest updates with GPU support:

```shell
pip install tf-nightly-gpu
```

With a TPU, GPU support is not necessary. First, create a `tf-nightly`
TPU with the [ctpu tool](https://github.com/tensorflow/tpu/tree/master/tools/ctpu):

```shell
ctpu up --name <instance name> --tf-version="nightly"
```

Second, you need to install TF 2 `tf-nightly` on your VM:

```shell
pip install tf-nightly
```
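Before moving on, it can help to confirm that the packages above are visible to Python. The following is a minimal sketch, not part of the repository: `missing_modules` is a hypothetical helper name, and it only checks that a module can be located, not that it works with your GPU or TPU.

```python
import importlib.util


def missing_modules(names):
    """Return the subset of `names` that Python cannot locate on the current path."""
    missing = []
    for name in names:
        try:
            spec = importlib.util.find_spec(name)
        except (ImportError, ValueError):
            # Raised when a parent package of a dotted name is itself missing.
            spec = None
        if spec is None:
            missing.append(name)
    return missing


if __name__ == "__main__":
    # `official.nlp.bert` is only found if PYTHONPATH includes /path/to/models.
    print(missing_modules(["tensorflow", "official.nlp.bert"]))
```

An empty list means both `tensorflow` and the `official` package from the models directory were found.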

## Process Datasets
### Fine-tuning

Pyramid-BERT only targets fine-tuning, so only fine-tuning data needs to be prepared. Use the
[`../data/create_finetuning_data.py`](../data/create_finetuning_data.py) script to do so.
The resulting datasets in `tf_record` format and the training metadata are later
passed to the training or evaluation scripts. The task-specific arguments are
described below:

* GLUE

Download the
[GLUE data](https://gluebenchmark.com/tasks) by running
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
and unpack it to a directory `$GLUE_DIR`.
You can also download a [Pretrained Checkpoint](#access-to-pretrained-checkpoints) and place it in a directory `$BERT_DIR` instead of using the checkpoints on Google Cloud Storage.

```shell
export GLUE_DIR=~/glue
export BERT_DIR=gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16

export TASK_NAME=MNLI
export OUTPUT_DIR=gs://some_bucket/datasets
python ../data/create_finetuning_data.py \
 --input_data_dir=${GLUE_DIR}/${TASK_NAME}/ \
 --vocab_file=${BERT_DIR}/vocab.txt \
 --train_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_train.tf_record \
 --eval_data_output_path=${OUTPUT_DIR}/${TASK_NAME}_eval.tf_record \
 --meta_data_file_path=${OUTPUT_DIR}/${TASK_NAME}_meta_data \
 --fine_tuning_task_type=classification --max_seq_length=128 \
 --classification_task_name=${TASK_NAME}
```
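When several GLUE tasks need to be processed, the same command can be generated programmatically. The sketch below only reuses the flags shown in the command above; `build_create_data_command` is a hypothetical helper, and any task name other than `MNLI` (here `MRPC`) is an assumption that should be checked against the values the script accepts.

```python
import os
import shlex


def build_create_data_command(task_name, glue_dir, bert_dir, output_dir,
                              max_seq_length=128):
    """Build the argument list for create_finetuning_data.py for one task."""
    return [
        "python",
        "../data/create_finetuning_data.py",
        f"--input_data_dir={glue_dir}/{task_name}/",
        f"--vocab_file={bert_dir}/vocab.txt",
        f"--train_data_output_path={output_dir}/{task_name}_train.tf_record",
        f"--eval_data_output_path={output_dir}/{task_name}_eval.tf_record",
        f"--meta_data_file_path={output_dir}/{task_name}_meta_data",
        "--fine_tuning_task_type=classification",
        f"--max_seq_length={max_seq_length}",
        f"--classification_task_name={task_name}",
    ]


if __name__ == "__main__":
    glue_dir = os.path.expanduser("~/glue")  # expand "~" before quoting
    bert_dir = "gs://cloud-tpu-checkpoints/bert/keras_bert/uncased_L-24_H-1024_A-16"
    output_dir = "gs://some_bucket/datasets"
    for task in ["MNLI", "MRPC"]:  # "MRPC" is an assumed example task name
        print(shlex.join(build_create_data_command(task, glue_dir, bert_dir, output_dir)))
```

Each printed line can be pasted into a shell, or the list can be passed to `subprocess.run` directly.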


## Fine-tuning with Pyramid-BERT

### Sentence and Sentence-pair Classification Tasks

The main entry script is `models.official.nlp.bert.run_classifier_pyramid_bert.py`.
Set `/models` as the root directory, then run `python -m official.nlp.bert.run_classifier_pyramid_bert`.

Within `run_classifier_pyramid_bert.py`, the function `custom_main` is the entry point. Several key arguments are listed below:

1. `mode`: `'train_and_eval'` for fine-tuning and `'predict'` for inference.
2. `reduction_method`: the token selection method. The names used in the code differ from those in the paper:
    * `'FIRSTX_NO_PRUNE'`: 'Input-first-k-select' in the paper.
    * `'FIRST_X'`: 'First-k-select' in the paper.
    * `'RANDOM'`: 'Random-select' in the paper.
    * `'ATT_ONLY'`: 'Attention-select' in the paper.
    * `'CLS_NO_DUMMY'`: 'Coreset-select-k-1' in the paper.
    * `'KMEANSPLUS'`: 'Coreset-select-x' in the paper.
3. `retention_config`: corresponds to the 'sequence length configuration' in the paper.
   To use the exponential decay sequence-length generation function from the paper, try
   `get_retention_config_multiple_reduction(task=task, method='expo_decay', last_num_token=0.2, final_layer_idx=2)`.
4. `kcenters_param_s`: corresponds to the 'number of centers to add per iteration'.
   A float value, say 0.5, means that $\lceil 0.5 \cdot k \rceil$ centers are added per iteration, where $k$ is the number of tokens to retain.
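To give a feel for what an exponentially decaying sequence-length configuration might look like, here is a minimal sketch. It is not the repository's `get_retention_config_multiple_reduction`: the function name `expo_decay_lengths`, its parameters, and the geometric interpolation itself are assumptions made for illustration, as is reading `last_num_token=0.2` as the fraction of tokens kept at the last layer.

```python
def expo_decay_lengths(seq_len, num_layers, last_fraction):
    """Per-layer sequence lengths that decay geometrically.

    Assumed form (not taken from the repository): layer 0 keeps all
    `seq_len` tokens and the last layer keeps about `last_fraction * seq_len`.
    """
    if num_layers < 2:
        raise ValueError("num_layers must be at least 2")
    if not 0 < last_fraction <= 1:
        raise ValueError("last_fraction must be in (0, 1]")
    lengths = []
    for layer in range(num_layers):
        progress = layer / (num_layers - 1)  # 0.0 at the first layer, 1.0 at the last
        lengths.append(max(1, round(seq_len * last_fraction ** progress)))
    return lengths


if __name__ == "__main__":
    # 12 encoder layers, 128 input tokens, 20% of tokens kept at the last layer.
    print(expo_decay_lengths(seq_len=128, num_layers=12, last_fraction=0.2))
```

The result is a non-increasing list with one length per layer, which is the general shape a sequence-length configuration takes; the exact values produced by the repository's function may differ.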


For the implementation of 'Coreset-select', see `official.nlp.keras_nlp.layer.kmeans_cosine_similarity_para_s.py`.

For where this method is called, see `official.nlp.keras_nlp.layer.transformer_encoder_block.py`.
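As a rough illustration of the idea behind core-set token selection and the role of `kcenters_param_s`, the sketch below runs a greedy k-center selection over token vectors in plain Python. It is not the repository's TensorFlow implementation. The function names, the use of cosine distance (suggested only by the file name), and the choice to always keep token 0 (assumed to be the `[CLS]` token) are assumptions.

```python
import math


def cosine_distance(u, v):
    """1 - cosine similarity; zero vectors are treated as maximally uninformative."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        return 1.0
    return 1.0 - dot / (norm_u * norm_v)


def select_centers(tokens, k, s):
    """Greedily pick k token indices, adding ceil(s * k) centers per iteration."""
    n = len(tokens)
    k = min(k, n)
    if k <= 0:
        return []
    per_iteration = max(1, math.ceil(s * k))
    selected = [0]  # assumption: token 0 ([CLS]) is always retained
    # Distance from every token to its nearest selected center so far.
    min_dist = [cosine_distance(t, tokens[0]) for t in tokens]
    while len(selected) < k:
        budget = min(per_iteration, k - len(selected))
        chosen = set(selected)
        remaining = [i for i in range(n) if i not in chosen]
        # Take the tokens that are currently farthest from all centers.
        new_centers = sorted(remaining, key=lambda i: min_dist[i], reverse=True)[:budget]
        selected.extend(new_centers)
        for c in new_centers:
            for i in range(n):
                min_dist[i] = min(min_dist[i], cosine_distance(tokens[i], tokens[c]))
    return sorted(selected)


if __name__ == "__main__":
    tokens = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]
    print(select_centers(tokens, k=3, s=0.5))  # prints [0, 2, 3]
```

In this toy input, token 1 is nearly identical to token 0, so it is the one dropped. A larger `s` adds more centers per iteration and therefore needs fewer iterations, at the cost of a less careful greedy choice.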
