A Deep Reinforced Sequence-to-Set Model for Multi-Label Classification

Pengcheng Yang, Fuli Luo, Shuming Ma, Junyang Lin, Xu Sun


Abstract
Multi-label classification (MLC) aims to predict a set of labels for a given instance. Given a pre-defined label order, the sequence-to-sequence (Seq2Seq) model trained via maximum likelihood estimation has been successfully applied to the MLC task and shows a powerful ability to capture high-order correlations between labels. However, the output labels are essentially an unordered set rather than an ordered sequence. This inconsistency tends to cause problems that are hard to resolve, e.g., sensitivity to the label order. To remedy this, we propose a simple but effective sequence-to-set model. The proposed model is trained via reinforcement learning, where the reward feedback is designed to be independent of the label order. In this way, we reduce the model's dependence on the label order while still capturing high-order correlations between labels. Extensive experiments show that our approach substantially outperforms competitive baselines and effectively reduces the sensitivity to the label order.
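To illustrate what an order-independent reward could look like, here is a minimal sketch that scores a generated label sequence by the F1 overlap between its labels, treated as a set, and the gold label set. The function name `set_f1_reward` and the choice of F1 are assumptions made for illustration; the abstract only states that the reward does not depend on the label order, and this is not the authors' implementation.

```python
def set_f1_reward(predicted, gold):
    """Hypothetical reward: F1 between predicted and gold labels, both treated as sets.

    Converting both inputs to sets discards ordering (and duplicates), so any
    permutation of the same predicted labels receives the same reward.
    """
    pred_set, gold_set = set(predicted), set(gold)
    overlap = len(pred_set & gold_set)
    if overlap == 0:
        # Also covers the case where either set is empty.
        return 0.0
    precision = overlap / len(pred_set)
    recall = overlap / len(gold_set)
    return 2 * precision * recall / (precision + recall)


if __name__ == "__main__":
    gold = ["sports", "politics", "economy"]
    # Two orderings of the same predicted labels yield an identical reward.
    print(set_f1_reward(["politics", "sports"], gold))
    print(set_f1_reward(["sports", "politics"], gold))
```

Under a reward of this kind, a policy-gradient learner would receive the same signal for every ordering of a correct label set, which is one way the sensitivity to label order described above could be reduced.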
Anthology ID:
P19-1518
Volume:
Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics
Month:
July
Year:
2019
Address:
Florence, Italy
Venue:
ACL
Publisher:
Association for Computational Linguistics
Pages:
5252–5258
URL:
https://aclanthology.org/P19-1518
DOI:
10.18653/v1/P19-1518
Cite (ACL):
Pengcheng Yang, Fuli Luo, Shuming Ma, Junyang Lin, and Xu Sun. 2019. A Deep Reinforced Sequence-to-Set Model for Multi-Label Classification. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 5252–5258, Florence, Italy. Association for Computational Linguistics.
Cite (Informal):
A Deep Reinforced Sequence-to-Set Model for Multi-Label Classification (Yang et al., ACL 2019)
PDF:
https://preview.aclanthology.org/author-url/P19-1518.pdf
Code:
lancopku/Seq2Set
Data:
RCV1