圖解大模型推理優化之KV Cache

2024年2月6日 8点热度 0人点赞

此前,我們更多專註於大模型訓練方面的技術分享和介紹,然而在完成模型訓練之後,上線推理也是一項非常重要的工作。後續,我們將陸續撰寫更多關於大模型推理優化的技術文章,包括但不限於KV Cache、PageAttention、FlashAttention、MQA、GQA等。

在本文中,我們將詳細介紹KV Cache,這是一種大模型推理加速的方法。正如其名稱所示,該方法通過緩存Attention中的K和V來實現推理優化。

01 大模型推理的冗餘計算

我們先簡單觀察一下基於Decoder架構的大模型的生成過程。假設模型隻是一層Self Attention,用戶輸入“中國的首都”,模型續寫得到的輸出為“是北京”,模型的生成過程如下:

  1. 將“中國的首都”輸入模型,得到每個token的註意力表示(綠色部分)。使用“首都”的註意力表示,預測得到下一個token為“是”(實際還需要將該註意力表示映射成概率分佈logits,為了方便敘述,我們忽略該步驟)。
  2. 將“是”拼接到原來的輸入,得到“中國的首都是”,將其輸入模型,得到註意力表示,使用“是”的註意力表示,預測得到下一個token為“北”。
  3. 將“北”拼接到原來的輸入,依此類推,預測得到“京”,最終得到“中國的首都是北京”

在每一步生成中,僅使用輸入序列中的最後一個token的註意力表示,即可預測出下一個token。但模型還是並行計算了所有token的註意力表示,其中產生了大量冗餘的計算(包含qkv映射,attention計算等),並且輸入的長度越長,產生的冗餘計算量越大。例如:

  1. 在第一步中,我們僅需使用“首都”的註意力表示,即可預測得到“是”,但模型仍然會並行計算出“中國”,“的”這兩個token的註意力表示。
  2. 在第二步中,我們僅需使用“是”的註意力表示,即可預測得到“北”,但模型仍然會並行計算“中國”,“的”,“首都”這三個token的註意力表示。

02 Self Attention

KV Cache正是通過某種緩存機制,避免上述的冗餘計算,從而提升推理速度。在介紹KV Cache之前,我們有必要簡單回顧self attention的計算機制,假設輸入序列長度為,第個token對於整個輸入序列的註意力表示如下公式:

個token對於整個輸入序列的註意力表示的計算步驟大致如下:

  1. 向量映射:將輸入序列中的每個token的詞向量分別映射為三個向量。
  2. 註意力計算:使用分別與每個進行點乘,得到第個token對每個token的註意力分數。
  3. 註意力分數歸一化:對註意力分數進行softmax,得到註意力權重。
  4. 加權求和:註意力權重與對應的向量加權求和,最終得到第個token的註意力表示。

下面將以圖像的方式幫助大傢更形象地理解Self Attention。假設輸入序列對於整個輸入序列的註意力表示為,它的計算過程如下圖所示,

繼續觀察對於整個輸入序列的註意力表示,它的計算過程如下圖所示

03 KV Cache

在推理階段,當輸入長度為 ,我們僅需使用 即可預測出下一個token,但模型卻會並行計算出 ,這部分會產生大量的冗餘計算。而實際上 可直接通過公式 算出,即 的計算隻與、所有 有關。

KV Cache的本質是以空間換時間,它將歷史輸入的token的 緩存下來,避免每步生成都重新計算歷史token的 以及註意力表示 而是直接通過 的方式計算得到 ,然後預測下一個token。

舉個例子,用戶輸入“中國的首都”,模型續寫得到的輸出為“是北京”,KV Cache每一步的計算過程如下。

第一步生成時,緩存 均為空,輸入為“中國的首都”,模型將按照常規方式並行計算:

  1. 並行計算得到每個token對應的 ,以及註意力表示
  2. 使用 預測下一個token,得到“是”。
  3. 更新緩存,令

第二步生成時,計算流程如下:

  1. 僅將“是”輸入模型,對其詞向量進行映射,得到
  2. 更新緩存,令
  3. 計算 ,預測下一個token,得到“北”

第三步生成時,計算流程如下:

  1. 僅將“北”輸入模型,對其詞向量進行映射,得到
  2. 更新緩存,令
  3. 計算 ,預測下一個token,得到“京”。

上述生成流程中,隻有在第一步生成時,模型需要計算所有token的,並且緩存下來。此後的每一步,僅需計算當前token的 ,更新緩存,然後使用 即可算出當前token的註意力表示,最後用來預測一下個token。

Hungging Face對於KV Cache的實現代碼如下,結合註釋可以更加清晰地理解其運算過程:

query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim)  # 當前token對應的query
key = self._split_heads(key, self.num_heads, self.head_dim)  # 當前token對應的key
value = self._split_heads(value, self.num_heads, self.head_dim)  # 當前token對應的value
if layer_past is not None:
    past_key, past_value = layer_past  # KV Cache
    key = torch.cat((past_key, key), dim=-2)  # 將當前token的key與歷史的K拼接
    value = torch.cat((past_value, value), dim=-2)  # 將當前token的value與歷史的V拼接
if use_cache is True:
    present = (key, value)
else:
    present = None
# 使用當前token的query與K和V計算註意力表示
if self.reorder_and_upcast_attn:
    attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
else:
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

KV Cache是以空間換時間,當輸入序列非常長的時候,需要緩存非常多k和v,顯存占用非常大。為了緩解該問題,可以使用MQA、GQA、Page Attention等技術,在後續的文章中,我們也將對這些技術進行介紹。