MetaWeighting: Learning to Weight Tasks in Multi-Task Text Classification

Abstract

Task weighting, which assigns weights to the constituent tasks during training, significantly affects the performance of Multi-task Learning (MTL); consequently, interest in it has grown rapidly in recent years. However, existing task weighting methods assign weights based only on the training loss and ignore the gap between the training loss and the generalization loss, which degrades MTL performance. To address this issue, this paper proposes a novel task weighting algorithm, referred to as MetaWeighting, which automatically weights the tasks via a learning-to-learn paradigm.
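To make the idea of weighting tasks by generalization loss rather than training loss concrete, here is a minimal toy sketch. It is not the algorithm from the paper or the code in this repository: it assumes a generic first-order, one-step look-ahead scheme on a scalar linear model, and every name in it (grad_mse, meta_weight_step, lr) is hypothetical. The alpha and split_scale arguments are only loosely modeled on the options of the same name described later in this README.

```python
def grad_mse(theta, batch):
    """Gradient of the mean squared error of y ~ theta * x with respect to theta."""
    return sum(2 * (theta * x - y) * x for x, y in batch) / len(batch)


def meta_weight_step(theta, weights, task_batches, lr=0.05, alpha=0.1, split_scale=0.1):
    """One toy update of the shared parameter and the task weights.

    Assumption: each task batch holds at least two (x, y) pairs, so that both
    the query part and the support part are non-empty.
    """
    support_grads, query_batches = [], []
    for batch in task_batches:
        n_query = max(1, int(len(batch) * split_scale))
        query_batches.append(batch[:n_query])           # held-out query part
        support_grads.append(grad_mse(theta, batch[n_query:]))  # support part

    # Look-ahead step using the current task weights.
    theta_look = theta - lr * sum(w * g for w, g in zip(weights, support_grads))

    # Query loss gradient at the look-ahead point (proxy for generalization loss).
    query_grad = sum(grad_mse(theta_look, q) for q in query_batches)

    # Chain rule: d L_query / d w_i = -lr * g_i * query_grad, so a descent
    # step on w_i adds alpha * lr * g_i * query_grad.
    raw = [max(w + alpha * lr * g * query_grad, 1e-8)
           for w, g in zip(weights, support_grads)]
    total = sum(raw)
    new_weights = [r / total for r in raw]  # renormalize so the weights sum to 1

    new_theta = theta - lr * sum(w * g for w, g in zip(new_weights, support_grads))
    return new_theta, new_weights


# Two toy tasks that share the same underlying slope of 2.
xs = [i / 10 for i in range(1, 11)]
tasks = [[(x, 2 * x) for x in xs], [(x, 2 * x) for x in xs]]

theta, weights = 0.0, [0.5, 0.5]
for _ in range(200):
    theta, weights = meta_weight_step(theta, weights, tasks)
print(theta, weights)
```

The point of the sketch is the split: weights are adjusted using the loss on held-out query samples after a trial update, not the loss on the samples used for that update.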

Software implementation

Start training by running train.py. To change the model or the multi-task learning method's configuration, modify config.py. models provides multiple model implementations to choose from. Results are saved in runs, which is generated automatically after training.

Dependencies

You'll need a working Python environment to run the code. The code is based on Python 3.6 and PyTorch 1.6. The recommended way to set up your environment is through the Anaconda Python distribution, which provides the conda package manager.

The required dependencies are specified in the file requirements.txt.

Run the following command in the repository folder (where requirements.txt is located) to create a separate environment and install all required dependencies in it:

Reproducing the results

Before running any code you must activate the conda environment:

or, if you're on Windows:

This will enable the environment for your current terminal session.

Optional arguments:

| Parameter | Default [other choices] | Description |
| --- | --- | --- |
| -d | sentiment [20news] | Dataset: Sentiment Analysis or Topic Classification |
| -m | textcnn [lstm] | Feature extractor model |
| -t | meta [single/uniform/fudan/mgda/gradnorm/uncertain/tchebycheff_adv/bandit] | Multi-task learning method |
| -g | 0 | GPU to use when there are multiple GPUs |
| -c | None | Remark attached to the result |
| -alpha | 0.1 | Step size for updating weights |
| -split_scale | 0.1 | Query-split ratio |
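As a reading aid, the options above can be expressed as an argparse parser. This is a sketch under assumptions, not the parser actually used in train.py: the flag names, defaults, and choices come from the table, while the dest names, the types, and the build_parser function are hypothetical.

```python
import argparse

# Method names copied from the options table; everything else is illustrative.
METHODS = ["meta", "single", "uniform", "fudan", "mgda", "gradnorm",
           "uncertain", "tchebycheff_adv", "bandit"]


def build_parser():
    """Build a hypothetical parser that mirrors the documented options."""
    parser = argparse.ArgumentParser(description="Sketch of the documented training options")
    parser.add_argument("-d", dest="dataset", default="sentiment",
                        choices=["sentiment", "20news"],
                        help="dataset: Sentiment Analysis or Topic Classification")
    parser.add_argument("-m", dest="model", default="textcnn",
                        choices=["textcnn", "lstm"],
                        help="feature extractor model")
    parser.add_argument("-t", dest="method", default="meta", choices=METHODS,
                        help="multi-task learning method")
    parser.add_argument("-g", dest="gpu", type=int, default=0,
                        help="GPU to use when there are multiple GPUs")
    parser.add_argument("-c", dest="comment", default=None,
                        help="remark attached to the result")
    parser.add_argument("-alpha", dest="alpha", type=float, default=0.1,
                        help="step size for updating weights")
    parser.add_argument("-split_scale", dest="split_scale", type=float, default=0.1,
                        help="query-split ratio")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(["-d", "20news", "-g", "2", "-alpha", "0.5"])
print(args)
```

Options left out of the argument list fall back to the defaults shown in the table.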

Example:

Run meta on GPU 0 with dataset Sentiment Analysis, alpha=0.1, rho=0.1

Run meta on GPU 2 with dataset Topic Classification, alpha=0.5, rho=0.1
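The two examples above can be turned into command lines with a small helper. This is a sketch under assumptions: it supposes that train.py accepts the flags from the options table in this form and that rho in the examples corresponds to -split_scale; neither is confirmed here, and build_command is a hypothetical name.

```python
def build_command(dataset, method, gpu, alpha, rho):
    """Assemble a hypothetical train.py invocation from the documented flags.

    Assumption: rho from the examples maps to the -split_scale option.
    """
    return ["python", "train.py",
            "-d", dataset,
            "-t", method,
            "-g", str(gpu),
            "-alpha", str(alpha),
            "-split_scale", str(rho)]


# meta on GPU 0 with Sentiment Analysis, alpha=0.1, rho=0.1
print(" ".join(build_command("sentiment", "meta", 0, 0.1, 0.1)))
# meta on GPU 2 with Topic Classification, alpha=0.5, rho=0.1
print(" ".join(build_command("20news", "meta", 2, 0.5, 0.1)))
```

The returned list can also be passed directly to subprocess.run if you prefer launching runs from a script.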

The output will be placed in runs/20news/meta. Use TensorBoard to check the results:

This will start the server and open your default web browser to the TensorBoard interface. You can view and download the results on that page.