# Striking a Balance: Alleviating Inconsistency in Pre-trained Models for Symmetric Classification Tasks

<p align="center">
  <img align="center" src="images/consistentbert.pdf" alt="Image" height="420" >
</p>

- Overview: BERT with a consistency loss. We add an extra classification token, `[CLSPara]`, to the input, and the consistency objective is applied on top of it.
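
The sketch below shows what such a consistency objective can look like in PyTorch. It assumes that the model produces class logits for both orderings of a sentence pair, (a, b) and (b, a), and that the `-divergence js` option used below compares the two predicted distributions with the Jensen-Shannon divergence. The function names and the weight `lam` are hypothetical; the actual objective in `src/training.py` may differ.

```python
import torch
import torch.nn.functional as F


def js_divergence(logits_ab: torch.Tensor, logits_ba: torch.Tensor) -> torch.Tensor:
    """Batch-averaged Jensen-Shannon divergence between the class distributions
    predicted for the pair orderings (a, b) and (b, a)."""
    log_p = F.log_softmax(logits_ab, dim=-1)
    log_q = F.log_softmax(logits_ba, dim=-1)
    p, q = log_p.exp(), log_q.exp()
    log_m = torch.log(0.5 * (p + q))  # log of the mixture distribution M
    kl_pm = (p * (log_p - log_m)).sum(dim=-1)  # KL(P || M) per example
    kl_qm = (q * (log_q - log_m)).sum(dim=-1)  # KL(Q || M) per example
    return (0.5 * (kl_pm + kl_qm)).mean()


def symmetric_consistency_loss(
    logits_ab: torch.Tensor,
    logits_ba: torch.Tensor,
    labels: torch.Tensor,
    lam: float = 1.0,  # hypothetical weight of the consistency term
) -> torch.Tensor:
    """Cross-entropy on both orderings plus a penalty when their predictions disagree."""
    ce = F.cross_entropy(logits_ab, labels) + F.cross_entropy(logits_ba, labels)
    return ce + lam * js_divergence(logits_ab, logits_ba)


if __name__ == "__main__":
    torch.manual_seed(0)
    # Random stand-ins for the [CLSPara] head outputs on (a, b) and (b, a).
    logits_ab = torch.randn(4, 2, requires_grad=True)
    logits_ba = torch.randn(4, 2)
    labels = torch.tensor([0, 1, 1, 0])
    loss = symmetric_consistency_loss(logits_ab, logits_ba, labels)
    loss.backward()
    print(loss.item())
```

Unlike KL divergence, the Jensen-Shannon divergence is symmetric and bounded by ln 2, so the penalty does not depend on which ordering is treated as the reference.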

## Dependencies

- Compatible with PyTorch 1.8.0 and Python 3.x

## Setup

Install virtualenv (optional):

```shell
$ [sudo] pip install virtualenv
```

Create and activate your virtual environment (optional):

```shell
$ virtualenv -p python3 venv
$ source venv/bin/activate
```

Install all the required packages:

```shell
$ pip install pytorch-lightning transformers datasets scipy scikit-learn wandb attrdict prettytable ipdb
```

## Resources

#### Dataset

The datasets are included in the zip file, which contains the data as well as the `src` folder.

Set the `-dataset` argument in the commands below accordingly. The available datasets are:
1. sst2-eq
2. qqp-new
3. mrpc-new
4. qnli-eq
5. rte-eq

## Finetuning the model on the symmetric datasets

```shell
TOKENIZERS_PARALLELISM=True CUDA_VISIBLE_DEVICES=0 python src/training.py -dataset qqp-new -model roberta -n_gpus 1 -additional_cls -add_ds mrpc-new paws -divergence js -model_type dual -consistency -s_off -tbs 12 -seed 1023
```

## Finetuning the model obtained above on sst2, qnli and rte

Locate the model checkpoint created in the previous step (it is saved in the `Models/` folder); we refer to it as `<ckpt>`. Then run the command below, shown for `sst2-eq`; use `qnli-eq` or `rte-eq` for the other tasks:

```shell
TOKENIZERS_PARALLELISM=True CUDA_VISIBLE_DEVICES=0 python src/finetuning.py -dataset sst2-eq -ckpt <ckpt> -lr 2e-5 -tbs 12 -n_gpus 1
```
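
Because the checkpoint path has to be looked up by hand, a small helper can pick the most recent one. This is a sketch assuming that checkpoints are written as `*.ckpt` files (the default extension used by PyTorch Lightning's `ModelCheckpoint`) somewhere under `Models/`; `latest_checkpoint` is a hypothetical helper, not part of `src/`.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(root: str = "Models") -> Optional[Path]:
    """Return the most recently modified *.ckpt file under `root`, or None if there is none."""
    root_path = Path(root)
    if not root_path.is_dir():
        return None
    # Search recursively, since each run may write into its own subfolder.
    ckpts = [p for p in root_path.rglob("*.ckpt") if p.is_file()]
    if not ckpts:
        return None
    return max(ckpts, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    ckpt = latest_checkpoint()
    print(ckpt if ckpt is not None else "No checkpoint found")
```

The printed path can be substituted for `<ckpt>`, and later for `<ckpt_cls>`. If several runs share the folder, "most recent" may not be the run you want, so check the path before using it.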

## Classification

Locate the checkpoint created by the fine-tuning step above (saved in the `Models/` folder); we refer to it as `<ckpt_cls>`.

```shell
CUDA_VISIBLE_DEVICES=0 python src/classification.py -dataset sst2-eq -erev -ebs 256 -ckpt <ckpt_cls> -n_gpus 1
```

## Evaluation

First, create the references:

```shell
python src/create_references.py -ckpt <ckpt_cls> -dataset sst2-eq -ebs 512
```

Then run the evaluation:

```shell
python src/evaluation.py -pretrain_path Models/
```
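
One simple way to quantify the inconsistency this work targets is the fraction of sentence pairs whose predicted label is unchanged when the two inputs are swapped. The sketch below assumes that predicted labels are available for every pair in both orders, (a, b) and (b, a); `consistency_rate` is a hypothetical helper and need not match the metrics reported by `src/evaluation.py`.

```python
from typing import Sequence


def consistency_rate(preds_ab: Sequence[int], preds_ba: Sequence[int]) -> float:
    """Fraction of pairs whose prediction is the same for (a, b) and (b, a)."""
    if len(preds_ab) != len(preds_ba):
        raise ValueError("prediction lists must have the same length")
    if len(preds_ab) == 0:
        raise ValueError("need at least one prediction")
    agree = sum(a == b for a, b in zip(preds_ab, preds_ba))
    return agree / len(preds_ab)


# Three of the four pairs keep their label after swapping the inputs.
print(consistency_rate([1, 0, 1, 1], [1, 1, 1, 1]))  # 0.75
```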

