论文阅读:MHA、MQA、GQA的差异与共性

MHA、MQA、GQA差异

MHA:Multi-Head Attention,QKV 三部分有相同数量的头,且一一对应。每次做 Attention,head_i 的 QKV 做好自己的运算就可以,输出时各个头加起来就行。

MQA:Multi-Query Attention,让 Q 仍然保持原来的头数,但 K 和 V 只有一个头,相当于所有的 Q 头共享一组 K 和 V 头,所以叫做 Multi-Query 了。

  • 实现改变了会不会影响效果呢?确实会影响,但相对它能带来的收益,性能微降是可以接受的。能带来多大的收益呢?实验发现一般能提高 30%-40%左右的吞吐。收益主要来源于降低了 KV cache。实际上 MQA 运算量和 MHA 是差不多的,可理解为读取一组 KV 头之后,给所有 Q 头用,但因为之前提到的内存和计算的不对称,所以是有利的。

GQA:Grouped-Query Attention,是 MHA 和 MQA 的折衷方案,既不想损失性能太多,又想获得 MQA 带来的推理加速好处。具体思想:不是所有 Q 头共享一组 KV,而是进行分组,一定头数 Q 共享一组 KV,比如上面图片就是两组 Q 共享一组 KV。

MHA、MQA、GQA共性

实际上,MHA、MQA可以看做是GQA两个特例版本:

  • MQA对应GQA-1,即只有一个分组,对应一个K和V;
  • MHA对应GQA-H,对应H个head,对应H个K和V;

从MHA模型得到MQA和GQA

  • 从MHA得到MQA:将MHA中H个head的的K和V,分别做mean pooling后得到一个K和V,用得到的K和V继续训练;
  • 从MHA得到GQA:将MHA中H个head的的K和V,分别做mean pooling得到H个K和V,用得到的K和V继续训练;

MHA、MQA、GQA效果

在LLAMA2中,在不同的数据数据集上对比的效果(注意:为了维持参数量一致,对于MQA、GQA的FFN layer的维度,会有一定的拓宽):

参考

请他喝一杯

取消

感谢您的支持,如有任何问题,您可以在打赏留言中留下您的微信

微信扫码
支付宝扫码

打开微信扫一扫,即可进行在打赏中留言哦