Leila Khalatbari


2024

Flatness-Aware Gradient Descent for Safe Conversational AI
Leila Khalatbari | Saeid Hosseini | Hossein Sameti | Pascale Fung
Proceedings of the 4th Workshop on Trustworthy Natural Language Processing (TrustNLP 2024)

As generative dialog models become ubiquitous in real-world applications, it is paramount to ensure that they generate harmless responses. There are two major challenges when enforcing safety in open-domain chatbots. First, it is impractical to provide training data reflecting the desired response to every emerging form of toxicity (the generalization challenge). Second, implementing safety features may compromise the quality of the conversation (the trade-off challenge). To tackle these challenges, this paper introduces a regularized fine-tuning approach called FlatGD. By employing a safety-tailored loss, we translate better optimization into greater safety. To achieve better optimization, FlatGD penalizes sharp trajectories of the loss curve, encouraging convergence to flat local minima. Experimental results on the “BAD” and “prosocial dialog” datasets demonstrate that our model outperforms the current baselines in reducing toxicity while preserving conversation quality. Moreover, FlatGD generalizes better to unseen toxic data than the other baselines.
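The abstract does not spell out FlatGD's regularizer or its safety-tailored loss, so the following is only an illustration of the general idea of penalizing sharp regions of the loss surface. It is a minimal plain-Python sketch of a sharpness-aware minimization (SAM) step, a well-known flatness-seeking update swapped in here for illustration, not the paper's method. The function names (`sam_step`, `toy_loss`, `toy_grad`), the toy quadratic loss, and the hyperparameters `lr` and `rho` are hypothetical choices, not taken from the paper.

```python
import math
from typing import Callable, List

Vector = List[float]


def l2_norm(v: Vector) -> float:
    """Euclidean norm of a plain-Python vector."""
    return math.sqrt(sum(x * x for x in v))


def sam_step(w: Vector, grad_fn: Callable[[Vector], Vector],
             lr: float, rho: float) -> Vector:
    """One sharpness-aware (SAM-style) update; NOT the FlatGD algorithm itself.

    Instead of descending along the gradient at w, it descends along the
    gradient at the approximate worst-case neighbour w + rho * g / ||g||,
    which penalises minima whose loss rises sharply within radius rho.
    """
    g = grad_fn(w)
    g_norm = l2_norm(g)
    if g_norm == 0.0:
        return list(w)  # stationary point: no ascent direction to probe
    # First-order worst-case perturbation inside the L2 ball of radius rho.
    w_adv = [wi + rho * gi / g_norm for wi, gi in zip(w, g)]
    g_adv = grad_fn(w_adv)
    # Apply the perturbed-point gradient to the ORIGINAL weights.
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]


# Hypothetical stand-in for a training loss: an anisotropic quadratic that is
# sharp along w[0] (curvature 10) and flat along w[1] (curvature 0.1).
def toy_loss(w: Vector) -> float:
    return 0.5 * (10.0 * w[0] ** 2 + 0.1 * w[1] ** 2)


def toy_grad(w: Vector) -> Vector:
    return [10.0 * w[0], 0.1 * w[1]]


w = [1.0, 1.0]
initial_loss = toy_loss(w)
for _ in range(2000):
    w = sam_step(w, toy_grad, lr=0.05, rho=0.05)
final_loss = toy_loss(w)
print(f"loss: {initial_loss:.3f} -> {final_loss:.5f}")
```

In a real fine-tuning setup the same two-gradient pattern would be applied to model parameters with an autodiff framework; each step costs roughly twice a plain gradient step, because the gradient is evaluated at two points.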