此前,我們更多專註於大模型訓練方面的技術分享和介紹,然而在完成模型訓練之後,上線推理也是一項非常重要的工作。後續,我們將陸續撰寫更多關於大模型推理優化的技術文章,包括但不限於KV Cache、PageAttention、FlashAttention、MQA、GQA等。
在本文中,我們將詳細介紹KV Cache,這是一種大模型推理加速的方法。正如其名稱所示,該方法通過緩存Attention中的K和V來實現推理優化。
01 大模型推理的冗餘計算
我們先簡單觀察一下基於Decoder架構的大模型的生成過程。假設模型隻是一層Self Attention,用戶輸入“中國的首都”,模型續寫得到的輸出為“是北京”,模型的生成過程如下:
- 將“中國的首都”輸入模型,得到每個token的註意力表示(綠色部分)。使用“首都”的註意力表示,預測得到下一個token為“是”(實際還需要將該註意力表示映射成概率分佈logits,為了方便敘述,我們忽略該步驟)。
- 將“是”拼接到原來的輸入,得到“中國的首都是”,將其輸入模型,得到註意力表示,使用“是”的註意力表示,預測得到下一個token為“北”。
- 將“北”拼接到原來的輸入,依此類推,預測得到“京”,最終得到“中國的首都是北京”
![](https://news.xinpengboligang.com/upload/keji/21390ae9fa9130153cdf4dc19877c856.jpeg)
在每一步生成中,僅使用輸入序列中的最後一個token的註意力表示,即可預測出下一個token。但模型還是並行計算了所有token的註意力表示,其中產生了大量冗餘的計算(包含qkv映射,attention計算等),並且輸入的長度越長,產生的冗餘計算量越大。例如:
- 在第一步中,我們僅需使用“首都”的註意力表示,即可預測得到“是”,但模型仍然會並行計算出“中國”,“的”這兩個token的註意力表示。
- 在第二步中,我們僅需使用“是”的註意力表示,即可預測得到“北”,但模型仍然會並行計算“中國”,“的”,“首都”這三個token的註意力表示。
02 Self Attention
KV Cache正是通過某種緩存機制,避免上述的冗餘計算,從而提升推理速度。在介紹KV Cache之前,我們有必要簡單回顧self attention的計算機制,假設輸入序列長度為,第
個token對於整個輸入序列的註意力表示如下公式:
第個token對於整個輸入序列的註意力表示的計算步驟大致如下:
- 向量映射:將輸入序列中的每個token的詞向量分別映射為
三個向量。
- 註意力計算:使用
分別與每個
進行點乘,得到第
個token對每個token的註意力分數。
- 註意力分數歸一化:對註意力分數進行softmax,得到註意力權重。
- 加權求和:註意力權重與對應的向量
加權求和,最終得到第
個token的註意力表示。
下面將以圖像的方式幫助大傢更形象地理解Self Attention。假設輸入序列,
對於整個輸入序列
的註意力表示為
,它的計算過程如下圖所示,
。
![](https://news.xinpengboligang.com/upload/keji/e76896def48d4e59989a0562dbe7f72a.jpeg)
繼續觀察對於整個輸入序列
的註意力表示
,它的計算過程如下圖所示
。
![](https://news.xinpengboligang.com/upload/keji/061a380f91ff1a3fbe25d759b8a9129d.jpeg)
03 KV Cache
在推理階段,當輸入長度為 ,我們僅需使用
即可預測出下一個token,但模型卻會並行計算出
,這部分會產生大量的冗餘計算。而實際上
可直接通過公式
算出,即
的計算隻與
、所有
和
有關。
KV Cache的本質是以空間換時間,它將歷史輸入的token的 和
緩存下來,避免每步生成都重新計算歷史token的
和
以及註意力表示
,而是直接通過
的方式計算得到
,然後預測下一個token。
舉個例子,用戶輸入“中國的首都”,模型續寫得到的輸出為“是北京”,KV Cache每一步的計算過程如下。
第一步生成時,緩存 均為空,輸入為“中國的首都”,模型將按照常規方式並行計算:
- 並行計算得到每個token對應的
,以及註意力表示
。
- 使用
預測下一個token,得到“是”。
- 更新緩存,令
。
![](https://news.xinpengboligang.com/upload/keji/b7e97b6ea36a58dbf7ec603f8e82db6a.jpeg)
第二步生成時,計算流程如下:
- 僅將“是”輸入模型,對其詞向量進行映射,得到
。
- 更新緩存,令
。
- 計算
,預測下一個token,得到“北”
![](https://news.xinpengboligang.com/upload/keji/04dcae9bc01bae92dd637b9cbdd3ba6a.jpeg)
第三步生成時,計算流程如下:
- 僅將“北”輸入模型,對其詞向量進行映射,得到
。
- 更新緩存,令
。
- 計算
,預測下一個token,得到“京”。
![](https://news.xinpengboligang.com/upload/keji/ef0b5cc823994c5ea969c00926954a89.jpeg)
上述生成流程中,隻有在第一步生成時,模型需要計算所有token的,並且緩存下來。此後的每一步,僅需計算當前token的
、
、
,更新緩存
、
,然後使用
、
、
即可算出當前token的註意力表示,最後用來預測一下個token。
Hungging Face對於KV Cache的實現代碼如下,結合註釋可以更加清晰地理解其運算過程:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim) # 當前token對應的query
key = self._split_heads(key, self.num_heads, self.head_dim) # 當前token對應的key
value = self._split_heads(value, self.num_heads, self.head_dim) # 當前token對應的value
if layer_past is not None:
past_key, past_value = layer_past # KV Cache
key = torch.cat((past_key, key), dim=-2) # 將當前token的key與歷史的K拼接
value = torch.cat((past_value, value), dim=-2) # 將當前token的value與歷史的V拼接
if use_cache is True:
present = (key, value)
else:
present = None
# 使用當前token的query與K和V計算註意力表示
if self.reorder_and_upcast_attn:
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
KV Cache是以空間換時間,當輸入序列非常長的時候,需要緩存非常多k和v,顯存占用非常大。為了緩解該問題,可以使用MQA、GQA、Page Attention等技術,在後續的文章中,我們也將對這些技術進行介紹。