# QASE

Note: Pre-trained model weights will be released upon paper acceptance.


### Environment Setup
1. Use Python 3.9.
2. Run `pip install -r requirements.txt`.
3. Create a `.env` file containing `HF_TOKEN=YOUR_HUGGINGFACE_TOKEN` and `OPENAI_API_KEY=YOUR_OPENAI_API_TOKEN`, each on its own line.
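
Before launching a long run, it can help to check that both tokens from step 3 are actually present in `.env`. The sketch below is a minimal, stdlib-only check under the assumption that `.env` holds plain `KEY=VALUE` lines; `parse_env_text`, `read_env_file` and `missing_keys` are hypothetical helpers, not part of this repository, and how the repository itself loads `.env` is not described here.

```python
from pathlib import Path

# Keys named in step 3 of the setup instructions.
REQUIRED_KEYS = ("HF_TOKEN", "OPENAI_API_KEY")


def parse_env_text(text):
    """Parse plain KEY=VALUE lines into a dict.

    Blank lines and '#' comments are skipped, and one layer of matching
    surrounding quotes is removed. This is deliberately minimal: it does not
    handle `export`, multi-line values, or variable interpolation.
    """
    values = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        value = value.strip()
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        values[key.strip()] = value
    return values


def read_env_file(path=".env"):
    """Read and parse a .env file from disk."""
    return parse_env_text(Path(path).read_text())


def missing_keys(env, required=REQUIRED_KEYS):
    """Return the required keys that are absent or have an empty value."""
    return [key for key in required if not env.get(key)]
```

For example, `missing_keys(read_env_file(".env"))` returns an empty list when both tokens are set.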


### Datasets
1. [MultiSpanQA](https://aclanthology.org/2022.naacl-main.90/): [leaderboard](https://multi-span.github.io/)
2. [SQuAD](https://arxiv.org/pdf/1606.05250): [leaderboard](https://rajpurkar.github.io/SQuAD-explorer/)
3. [Quoref](https://aclanthology.org/D19-1606/): [leaderboard](https://leaderboard.allenai.org/quoref/submissions/about)


### Train and Fine-Tune

##### To train Flan-T5-Large<sub>QASE</sub> on, e.g., MultiSpanQA
1. Go to `configs/train_qase_config.py` and adjust any config settings if needed.
2. Go to the `train` directory.
3. Run `python train_quase.py --dataset multispanqa --base_model flan-t5-large`.
The trained model weights will be stored in the specified `ckpt_file` in the config settings.

##### To fine-tune Flan-T5-Large on, e.g., MultiSpanQA
1. Go to `configs/finetune_flan_t5_config.py` and adjust any config settings if needed.
2. Go to the `train` directory.
3. Run `python finetune_flan_t5.py --dataset MultiSpanQA --base_model flan-t5-large`.
The trained LoRA weights will be stored in the specified `ckpt_dir` in the config settings.

##### To fine-tune Llama 2/Alpaca with LoRA on, e.g., MultiSpanQA
1. Go to `configs/finetune_llama_config.py` and adjust any config settings if needed.
2. Go to the `train` directory.
3. Run `python finetune_llama.py --dataset MultiSpanQA --base_model llama2`.
The trained LoRA weights will be stored in the specified `ckpt_dir` in the config settings.
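
The three training commands above differ only in the script name, so they are easy to script. The sketch below is a hypothetical wrapper (`build_train_command` and `run_training` are not part of the repository) that assumes you launch it from the repository root so that `train/` resolves, and that `python` on your `PATH` is the Python 3.9 environment from the setup step. The dataset and model strings are only the examples used in this README; check the training scripts for the values they accept for SQuAD and Quoref.

```python
import subprocess

# Script names exactly as given in the training steps above.
TRAIN_SCRIPTS = {
    "qase": "train_quase.py",
    "flan_t5": "finetune_flan_t5.py",
    "llama": "finetune_llama.py",
}


def build_train_command(method, dataset, base_model):
    """Return the argv list for one training run; it must run inside `train/`."""
    return ["python", TRAIN_SCRIPTS[method],
            "--dataset", dataset, "--base_model", base_model]


def run_training(method, dataset, base_model, dry_run=True):
    """Print the command (dry run) or execute it from the `train` directory."""
    cmd = build_train_command(method, dataset, base_model)
    if dry_run:
        print("cd train &&", " ".join(cmd))
        return None
    # check=True raises CalledProcessError if the script exits with an error.
    return subprocess.run(cmd, cwd="train", check=True)
```

For example, `run_training("qase", "multispanqa", "flan-t5-large")` prints `cd train && python train_quase.py --dataset multispanqa --base_model flan-t5-large`; pass `dry_run=False` to actually launch it. Defaulting to a dry run makes it cheap to review the command before a long job starts.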


### Inference and Evaluation

##### To perform inference on the test/val batch of, e.g., MultiSpanQA, with trained Flan-T5-Large<sub>QASE</sub> model weights
1. Go to `configs/inference_qase_config.py` and adjust any config settings if needed.
2. Go to the `inference/qase` directory.
3. Run `python inference_qase.py --dataset multispanqa --batch val --base_model flan-t5-large --ckpt_file {ckpt_file}`.
The output prediction file will be stored in the specified `predictions_output_file` in the config settings.

##### To perform inference on the test/val batch of, e.g., MultiSpanQA, with zero-shot or fine-tuned Flan-T5-Large
1. Go to `configs/inference_flan_t5_config.py` and adjust any config settings if needed.
2. Go to the `inference/flan-t5` directory.
3. For zero-shot inference, run `python inference_flan_t5.py --dataset multispanqa --batch test --base_model flan-t5-large`.
4. To run a fine-tuned model instead, append `--ckpt_dir {ckpt_dir}` to the same command.
The output prediction file will be stored in the specified `predictions_output_file` in the config settings.

##### To perform inference on the test/val batch of, e.g., MultiSpanQA, with zero-shot or fine-tuned Llama 2
1. Go to `configs/inference_llama_config.py` and adjust any config settings if needed.
2. Go to the `inference/llama` directory.
3. For zero-shot inference, run `python inference_llama.py --dataset multispanqa --batch test --base_model llama2`.
4. To run a fine-tuned model instead, append `--ckpt_dir {ckpt_dir}` to the same command.
The output prediction file will be stored in the specified `predictions_output_file` in the config settings.
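
The three inference setups share the same flags and differ in their directory, script, and checkpoint flag (`--ckpt_file` for QASE, optional `--ckpt_dir` for Flan-T5 and Llama 2). The sketch below captures that mapping in a hypothetical helper, `build_inference_command`, which is not part of the repository. It only assembles the command from the flags shown above; `split` stands for the value passed to `--batch`.

```python
# method -> (directory to run from, script), as listed in the steps above.
INFERENCE_SCRIPTS = {
    "qase": ("inference/qase", "inference_qase.py"),
    "flan_t5": ("inference/flan-t5", "inference_flan_t5.py"),
    "llama": ("inference/llama", "inference_llama.py"),
}


def build_inference_command(method, dataset, split, base_model, checkpoint=None):
    """Return (working directory, argv) for one inference run.

    `split` is the value passed to --batch ("test" or "val"). For QASE the
    checkpoint is passed as --ckpt_file, as in the QASE step; for Flan-T5 and
    Llama 2 it is passed as --ckpt_dir, and omitting it means zero-shot.
    """
    workdir, script = INFERENCE_SCRIPTS[method]
    cmd = ["python", script, "--dataset", dataset, "--batch", split,
           "--base_model", base_model]
    if method == "qase":
        if checkpoint is None:
            raise ValueError("the QASE step above always passes --ckpt_file")
        cmd += ["--ckpt_file", checkpoint]
    elif checkpoint is not None:
        cmd += ["--ckpt_dir", checkpoint]
    return workdir, cmd
```

For example, `workdir, cmd = build_inference_command("flan_t5", "multispanqa", "test", "flan-t5-large")` yields the zero-shot Flan-T5 command, which can then be executed with `subprocess.run(cmd, cwd=workdir, check=True)`.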

##### To run the evaluation scripts
1. Go to the `inference` directory.
2. For MultiSpanQA, run `python multispanqa_eval_script.py --pred_file {pred_file} --gold_file ../datasets/MultiSpanQA/valid.json`.
3. For SQuAD, run `python squad_eval_script.py --pred_file {pred_file} --gold_file ../datasets/SQuAD/test_readable.json`.
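
To help interpret the numbers the eval scripts report, the sketch below shows the common SQuAD v1.1-style answer normalization with exact match and token-level F1, plus a simplified exact-match precision/recall/F1 over sets of answer spans for the multi-span case. It is an illustration under stated assumptions, not the repository's code: it assumes micro-averaging across questions and reuses the SQuAD normalization for spans, it omits the overlap (partial-match) metrics entirely, and `multispanqa_eval_script.py` and `squad_eval_script.py` remain authoritative (they may differ, for example in normalization or in taking the maximum over several gold answers). Scores here are fractions; the tables below report them multiplied by 100.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(text):
    """Lowercase, strip punctuation, drop articles (a/an/the), squash whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def squad_exact_match(prediction, gold):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(gold))


def squad_f1(prediction, gold):
    """Token-level F1 between one prediction and one gold answer."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(gold).split()
    # Multiset intersection counts each shared token at most min(count) times.
    num_same = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def multispan_exact_match(predictions, golds):
    """Micro-averaged exact-match precision, recall and F1 over span sets.

    Both arguments map a question id to a list of answer span strings.
    A predicted span counts as correct only if it equals a gold span after
    normalization. Overlap (partial-match) scores are not computed here.
    """
    n_pred = n_gold = n_match = 0
    for qid, gold_spans in golds.items():
        gold_set = {normalize_answer(s) for s in gold_spans}
        pred_set = {normalize_answer(s) for s in predictions.get(qid, [])}
        n_pred += len(pred_set)
        n_gold += len(gold_set)
        n_match += len(pred_set & gold_set)
    precision = n_match / n_pred if n_pred else 0.0
    recall = n_match / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1
```

Micro-averaging pools every predicted and gold span before dividing, so questions with many answer spans weigh more than questions with one.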


### Experiment Results
#### MultiSpanQA Val Set
|                          | em_precision |  em_recall |    em_f1   | overlap_precision | overlap_recall | overlap_f1 |
|:------------------------:|:------------:|:----------:|:----------:|:-----------------:|:--------------:|:----------:|
|          Llama2          |    12.232    |    4.186   |    6.237   |       40.766      |     35.383     |   37.884   |
|     Llama2 Fine-Tuned    |    48.238    |   50.131   |   49.166   |       65.881      |     68.865     |   67.339   |
|      **Llama2+QASE**     |  **51.906**  | **50.602** | **51.245** |     **67.964**    |   **71.428**   | **69.653** |
|          Alpaca          |    31.546    |   10.779   |   16.068   |       59.755      |     32.535     |   42.131   |
|     Alpaca Fine-Tuned    |    53.046    |   52.851   |   52.948   |       70.333      |     68.470     |   69.389   |
|      **Alpaca+QASE**     |  **50.659**  | **54.317** | **52.424** |     **68.877**    |   **72.941**   | **70.851** |
|       Flan-T5-Small      |     1.531    |    0.523   |    0.780   |       54.071      |     14.342     |   22.671   |
| Flan-T5-Small Fine-Tuned |    62.059    |   58.974   |   60.477   |       81.776      |     74.075     |   77.735   |
|  **Flan-T5-Small+QASE**  |  **63.636**  | **61.172** | **62.380** |     **83.475**    |   **76.205**   | **79.675** |
|       Flan-T5-Base       |     7.843    |    2.721   |    4.040   |       56.382      |     28.525     |   37.884   |
|  Flan-T5-Base Fine-Tuned |    67.416    |   64.311   |   65.827   |       86.711      |     78.645     |   82.481   |
|   **Flan-T5-Base+QASE**  |  **68.695**  | **64.992** | **66.792** |     **87.902**    |   **79.805**   | **83.658** |
|       Flan-T5-Large      |    24.962    |    8.634   |   12.830   |       63.746      |     41.451     |   50.236   |
| Flan-T5-Large Fine-Tuned |    69.326    |   66.823   |   68.052   |       88.773      |     78.993     |   83.598   |
|  **Flan-T5-Large+QASE**  |  **69.403**  | **66.353** | **67.844** |     **87.653**    |   **81.034**   | **84.214** |
|          GPT-3.5         |    62.698    |   55.677   |   58.980   |       86.505      |     76.892     |   81.416   |
|           GPT-4          |    64.876    |   64.469   |   64.672   |       83.415      |     82.422     |   82.916   |
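
The reported `*_f1` values are consistent with the harmonic mean of the matching precision and recall columns, F1 = 2·P·R / (P + R). The sketch below recomputes `em_f1` for a few rows copied from the val-set table above; `f1_from_pr` is a hypothetical helper name. Differences in the third decimal place are expected, because the listed precision and recall are themselves rounded.

```python
def f1_from_pr(precision, recall):
    """Harmonic mean of precision and recall, in the same units as the inputs."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (em_precision, em_recall, em_f1) copied from the MultiSpanQA val-set table.
rows = {
    "Llama2+QASE": (51.906, 50.602, 51.245),
    "Flan-T5-Large+QASE": (69.403, 66.353, 67.844),
    "GPT-4": (64.876, 64.469, 64.672),
}

for name, (p, r, reported) in rows.items():
    print(f"{name}: computed {f1_from_pr(p, r):.3f}, reported {reported:.3f}")
```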

#### MultiSpanQA Test Set
|                          | em_precision |  em_recall |    em_f1   | overlap_precision | overlap_recall | overlap_f1 |
|:------------------------:|:------------:|:----------:|:----------:|:-----------------:|:--------------:|:----------:|
|          Llama2          |    14.155    |    4.967   |    7.354   |       44.903      |     27.398     |   34.031   |
|     Llama2 Fine-Tuned    |    50.233    |   51.655   |   50.934   |       66.229      |     70.165     |   68.140   |
|      **Llama2+QASE**     |  **52.113**  | **51.389** | **51.748** |     **69.006**    |   **71.828**   | **70.389** |
|          Alpaca          |    29.357    |   10.256   |   15.201   |       58.924      |     33.554     |   42.759   |
|     Alpaca Fine-Tuned    |    52.843    |   52.617   |   52.730   |       69.395      |     68.806     |   69.099   |
|      **Alpaca+QASE**     |  **50.272**  | **54.274** | **52.196** |     **67.556**    |   **72.643**   | **70.008** |
|       Flan-T5-Small      |     0.918    |    0.320   |    0.475   |       54.017      |     14.240     |   22.539   |
| Flan-T5-Small Fine-Tuned |    60.344    |   57.959   |   59.128   |       79.769      |     73.477     |   76.494   |
|  **Flan-T5-Small+QASE**  |  **60.244**  | **57.959** | **59.080** |     **80.892**    |   **73.654**   | **77.103** |
|       Flan-T5-Base       |     7.926    |    2.777   |    4.113   |       53.674      |     29.046     |   37.694   |
|  Flan-T5-Base Fine-Tuned |    65.674    |   63.675   |   64.659   |       84.510      |     78.526     |   81.408   |
|   **Flan-T5-Base+QASE**  |  **66.003**  | **63.782** | **64.874** |     **85.199**    |   **78.105**   | **81.498** |
|       Flan-T5-Large      |    26.707    |    9.401   |   13.907   |       65.055      |     42.620     |   51.501   |
| Flan-T5-Large Fine-Tuned |    68.505    |   66.346   |   67.408   |       87.725      |     78.928     |   83.094   |
|  **Flan-T5-Large+QASE**  |  **67.556**  | **66.293** | **66.918** |     **86.262**    |   **82.274**   | **84.221** |
|          GPT-3.5         |    62.369    |   57.371   |   59.766   |       85.411      |     78.604     |   81.866   |
|           GPT-4          |    64.113    |   63.942   |   64.027   |       82.791      |     82.673     |   82.731   |
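
To compare QASE with plain fine-tuning, subtract each fine-tuned row from the corresponding +QASE row. The sketch below does this for `overlap_f1` on the test set, using values copied from the table above; `qase_gains` is a hypothetical helper name.

```python
# overlap_f1 on the MultiSpanQA test set, copied from the table above:
# model -> (fine-tuned baseline, +QASE)
overlap_f1 = {
    "Llama2": (68.140, 70.389),
    "Alpaca": (69.099, 70.008),
    "Flan-T5-Small": (76.494, 77.103),
    "Flan-T5-Base": (81.408, 81.498),
    "Flan-T5-Large": (83.094, 84.221),
}


def qase_gains(scores):
    """Absolute point change of +QASE over the fine-tuned baseline."""
    return {model: round(qase - base, 3) for model, (base, qase) in scores.items()}


for model, gain in qase_gains(overlap_f1).items():
    print(f"{model}: {gain:+.3f} overlap_f1 points")
```

On this metric every listed model gains from QASE, from +0.090 (Flan-T5-Base) to +2.249 (Llama2). The same subtraction on `em_f1` shows QASE slightly below the fine-tuned baseline on the test set for Alpaca, Flan-T5-Small, and Flan-T5-Large, so the improvement is not uniform across metrics.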

#### SQuAD
|                          | exact_match |     f1     |
|:------------------------:|:-----------:|:----------:|
|          Llama2          |    13.443   |   28.931   |
|     Llama2 Fine-Tuned    |    36.679   |   47.055   |
|      **Llama2+QASE**     |  **37.219** | **47.686** |
|          Alpaca          |    18.259   |   33.871   |
|     Alpaca Fine-Tuned    |    27.881   |   43.950   |
|      **Alpaca+QASE**     |  **37.313** | **47.622** |
|       Flan-T5-Small      |    13.878   |   28.710   |
| Flan-T5-Small Fine-Tuned |    77.332   |   85.513   |
|  **Flan-T5-Small+QASE**  |  **77.663** | **85.901** |
|       Flan-T5-Base       |    37.596   |   51.747   |
|  Flan-T5-Base Fine-Tuned |    82.090   |   89.558   |
|   **Flan-T5-Base+QASE**  |  **82.204** | **90.240** |
|       Flan-T5-Large      |    16.149   |   37.691   |
| Flan-T5-Large Fine-Tuned |    83.159   |   90.712   |
|  **Flan-T5-Large+QASE**  |  **84.125** | **91.701** |
|          GPT-3.5         |    36.944   |   65.637   |
|    GPT-3.5 Fine-Tuned    |    75.697   |   86.565   |
|           GPT-4          |    39.347   |   69.158   |

Note: GPT-3.5 was fine-tuned on 10% of the training data.

#### Quoref
|                          | exact_match |     f1    |
|:------------------------:|:-----------:|:---------:|
|          Llama2          |     5.02    |   28.91   |
|     Llama2 Fine-Tuned    |    45.52    |   52.09   |
|      **Llama2+QASE**     |  **54.28**  | **60.44** |
|          Alpaca          |      -      |     -     |
|     Alpaca Fine-Tuned    |      -      |     -     |
|      **Alpaca+QASE**     |      -      |     -     |
|       Flan-T5-Small      |     1.58    |    5.96   |
| Flan-T5-Small Fine-Tuned |    58.21    |   63.30   |
|  **Flan-T5-Small+QASE**  |  **60.70**  | **66.88** |
|       Flan-T5-Base       |    27.08    |   34.38   |
|  Flan-T5-Base Fine-Tuned |    72.77    |   80.90   |
|   **Flan-T5-Base+QASE**  |  **75.17**  | **81.18** |
|       Flan-T5-Large      |    15.96    |   24.10   |
| Flan-T5-Large Fine-Tuned |    75.17    |   80.49   |
|  **Flan-T5-Large+QASE**  |  **76.19**  | **82.13** |
|          GPT-3.5         |    50.22    |   59.51   |
|           GPT-4          |    68.07    |   78.34   |
