# BanditMTL: Bandit-based Multi-task Learning for Text Classification




## Abstract

> Task variance regularization, which can be used to improve the generalization of Multi-task Learning (MTL) models, remains unexplored in multi-task text classification. 
> 
> Accordingly, to fill this gap, this paper investigates how the task might be effectively regularized, and consequently proposes a multi-task learning method based on adversarial multi-armed bandit. The proposed method, named BanditMTL, regularizes the task variance by means of a mirror gradient ascent-descent algorithm. Adopting BanditMTL in the multi-task text classification context is found to achieve state-of-the-art performance. The results of extensive experiments back up our theoretical analysis and validate the superiority of our proposals.

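The abstract mentions a mirror gradient ascent-descent scheme. As a rough illustration of the ascent half only, the sketch below applies one exponentiated-gradient update to a vector of task weights. An exponentiated-gradient update is mirror ascent under an entropic mirror map, which is an assumption made here. The function names, step size, and loss values are hypothetical. This is not the code in this repository, and the actual method may differ in details such as constraints on the weights.

```python
import math


def mirror_ascent_step(weights, losses, step_size=0.5):
    """One exponentiated-gradient (entropic mirror ascent) step on the simplex.

    Hypothetical helper for illustration; not part of this repository.
    """
    # The gradient of sum_i w_i * loss_i with respect to w_i is loss_i.
    # Subtracting the maximum loss keeps the exponentials numerically stable
    # and does not change the normalized result.
    max_loss = max(losses)
    scaled = [w * math.exp(step_size * (l - max_loss))
              for w, l in zip(weights, losses)]
    total = sum(scaled)
    return [s / total for s in scaled]


def weighted_loss(weights, losses):
    """Weighted sum of per-task losses."""
    return sum(w * l for w, l in zip(weights, losses))


# Illustrative values only: three tasks, uniform starting weights.
weights = [1.0 / 3, 1.0 / 3, 1.0 / 3]
losses = [0.2, 0.9, 0.4]
new_weights = mirror_ascent_step(weights, losses)
```

In this sketch the update shifts weight toward the tasks with the larger losses. In an ascent-descent scheme, the model parameters would then take a descent step on the reweighted loss.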

## Software implementation

Start training by running `train.py`.
To change the configuration of the model or the multi-task learning method, modify `config.py`.
`models` provides several model implementations to choose from.
Results are saved in `runs`, which is created automatically after training.


## Dependencies

You'll need a working Python environment to run the code. The code is based on `python 3.6` and `pytorch 1.4`.
The recommended way to set up your environment is through the [Anaconda Python distribution](https://www.anaconda.com/download/), which provides the `conda` package manager.

The required dependencies are specified in the file `requirements.txt`.

Run the following command in the repository folder (where `requirements.txt` is located) to create a separate environment and install all required dependencies in it:

```shell
conda create -n env_name python=3.6   # create new environment
source activate env_name
pip install -r requirements.txt
```


## Reproducing the results

Before running any code you must activate the conda environment:

    source activate env_name

or, if you're on Windows:

    activate env_name

This will enable the environment for your current terminal session. 

#### Optional arguments: 

| Parameter | Default [other choices] | Description |
| :----------------: | :-----: | :---------|
| -d | sentiment [20news] | Dataset: `Sentiment Analysis` (`sentiment`) or `Topic Classification` (`20news`) |
| -m | textcnn [lstm] | Feature extractor model |
| -t | bandit [single/uniform/gradnorm/mgda/fudan<br/>/uncertain/tchebycheff/tchebycheff_adv] | Multi-task learning method |
| -g | 0 | Index of the GPU to use when several are available |
| -c | None | Remark attached to the result |

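For readers who want to see how these options fit together, here is a minimal `argparse` sketch that mirrors the table. It is not taken from `train.py`: the `dest` names (`dataset`, `model`, `method`, `gpu`, `comment`) and the help strings are assumptions made for illustration, and the real script may parse its arguments differently.

```python
import argparse

# Choices copied from the table above; everything else is illustrative.
METHODS = ["bandit", "single", "uniform", "gradnorm", "mgda", "fudan",
           "uncertain", "tchebycheff", "tchebycheff_adv"]


def build_parser():
    """Build a parser that mirrors the documented options (hypothetical)."""
    parser = argparse.ArgumentParser(description="Multi-task text classification")
    parser.add_argument("-d", dest="dataset", default="sentiment",
                        choices=["sentiment", "20news"], help="dataset")
    parser.add_argument("-m", dest="model", default="textcnn",
                        choices=["textcnn", "lstm"], help="feature extractor")
    parser.add_argument("-t", dest="method", default="bandit",
                        choices=METHODS, help="multi-task learning method")
    parser.add_argument("-g", dest="gpu", type=int, default=0,
                        help="index of the GPU to use")
    parser.add_argument("-c", dest="comment", default=None,
                        help="remark attached to the result")
    return parser


# Equivalent to: python train.py -d 20news -g 2
args = build_parser().parse_args(["-d", "20news", "-g", "2"])
print(args.dataset, args.model, args.method, args.gpu, args.comment)
# prints: 20news textcnn bandit 2 None
```

Options that are not passed fall back to the defaults listed in the table.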
#### Example:

Run `bandit` on GPU 2 with the `Topic Classification` dataset:

    python train.py -d 20news -g 2

The output will be placed in `runs/20news/bandit`. Use TensorBoard to check the results:

    tensorboard --logdir=RESULT_PATH --port SERVER_PORT --bind_all

This starts the TensorBoard server. Open the URL it prints in your web browser to view the results, which you can also download from the page.

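As a small convenience, the sketch below builds the TensorBoard command for a given run. It assumes the output layout `runs/<dataset>/<method>`, inferred from the single example path `runs/20news/bandit`. The helper names are hypothetical, and the port defaults to 6006, TensorBoard's own default.

```python
import shlex


def run_dir(dataset, method="bandit", root="runs"):
    """Guess the output directory, assuming a runs/<dataset>/<method> layout."""
    return "/".join([root, dataset, method])


def tensorboard_command(logdir, port=6006):
    """Return the shell command that serves `logdir` with TensorBoard."""
    return "tensorboard --logdir={} --port {} --bind_all".format(
        shlex.quote(logdir), port)


print(tensorboard_command(run_dir("20news")))
# prints: tensorboard --logdir=runs/20news/bandit --port 6006 --bind_all
```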