字节跳动 | 提出大模型遗忘方法,只需2% 的RLHF计算时间即可实现对齐!

引言

随着大型语言模型(LLM)的推广和应用,人们越来越关心大模型输出内容的有害性,这对于客户服务、医疗资讯等领域来说是难以容忍的。那么如何避免 LLM 产生有害回复?

基于以上问题,字节跳动(ByteDance),提出让 LLM 进行遗忘学习的方法来进行对齐,实验结果表明,与RLHF相比,作者只使用2%的计算时间下,遗忘学习仍可以获得更好的对齐性能。

论文:https://arxiv.org/abs/2310.10683

代码:https://github.com/kevinyaobytedance/llm_unlearn

背景介绍

目前业界的主流解决方案为 LLM 对齐 (alignment),即通过建立对比数据(正样本和负样本)用强化学习的方式来对 LLM 进行微调 (Finetuning),也就是 RLHF (Reinforcement Learning from Human Feedback) ,从而保证 LLM 输出符合人类预期和价值观。但对齐过程往往受到 (1) 数据收集;(2) 计算资源的限制。

字节跳动提出让 LLM 进行遗忘学习的方法来进行对齐。本文研究如何在 LLM 上进行 “遗忘” 操作,即忘记有害行为或遗忘学习(Machine Unlearning),作者展示了遗忘学习在三种 LLM 对齐场景上取得的明显效果:

  • (1) 删除有害输出;
  • (2) 移除侵权保护内容;
  • (3) 消除大语言 LLM 幻觉。

遗忘学习有三个优势:

  • (1) 只需负样本(有害样本),负样本比 RLHF 所需的正样本(高质量的人工手写输出)的收集简单的多(比如红队测试或用户报告);
  • (2) 计算成本低;
  • (3) 如果知道哪些训练样本导致 LLM 有害行为时,遗忘学习尤为有效。

作者证明,如果从业者只有较少的资源,因此优先考虑的是停止产生有害输出,而不是试图产生过于理想化的输出,遗忘学习尤为便利。尽管只有负样本,研究表明,和 RLHF 相比,只使用 2% 的计算时间下,遗忘学习仍可以获得更好的对齐性能。

方法介绍

本方法可以在资源有限的情况下,最大程度发挥优势。当没预算请人员写优质样本,或计算资源不足时,应当优先停止 LLM 产生有害输出,而不是试图让其产生有益输出。

有害输出造成的损害远不是有益输出能弥补的。如果一个用户问 LLM100 个问题,他得到一个有害答案,就会失去信任,不管后来 LLM 能给多少有益答案。有害问题的预期输出可以是空格、特殊字符、无意义字符串等,总之,一定要是无害文本。

文中展示了 LLM 遗忘学习的三个成功案例:

  • (1) 停止生成有害回复,如上图所示;这与 RLHF 情境相似,区别是本方法目标是生成无害回复,而不是有益回复。当只有负样本时,这是能期望的最好结果;
  • (2) LLM 使用侵权数据训练后,在作者要求下,成功删除数据,且考虑到成本因素不能重训 LLM;
  • (3) LLM 成功忘记 “幻觉”;

实验结果

本文用 PKU-SafeRLHF 数据作为遗忘数据,TruthfulQA 作为正常数据,下图显示了遗忘学习后 LLM 在忘却的有害提示上输出的有害率。

文中使用的方法为 GA(梯度上升和 GA+Mismatch:梯度上升 + 随机误配)。遗忘学习后的有害率接近于零。

下图显示了未见过的有害提示(未被忘却过)上的输出。即使在没有忘却过的有害提示上,LLM 的有害率也接近于零,证明 LLM 忘记的不仅仅是具体见过的样本,而是泛化到了包含有害这个概念的内容。

同时 LLM 在正常样本上的性能和忘却前保持类似。下表展示了生成的样本。可以看到在有害提示下,LLM 生成的样本都是无意义字符串,即无害输出。

下表显示了该方法和 RLHF 的比较,这里 RLHF 已经用了正例,而遗忘学习的方法只有负例,所以比较一开始本方法就占劣势。但即便如此,遗忘学习也能取得和 RLHF 相似的对齐性能。

下表显示了计算时间的比较,本方法只需RLHF 2%的计算时间。

尽管只有负样本,遗忘学习的方法仍能达到和 RLHF 相似的无害率,而且只使用 2% 的算力。因此如果目标是停止输出有害输出,遗忘学习比 RLHF 更高效。