【LLM推理】Lookahead:一種無損推理加速機制

2024年2月6日 22点热度 0人点赞

引言

現有LLMs的實際應用面臨著推理速度慢的問題,現有優化推理方法如:量化(int8、GPTQ、KV Cache INT8等)、稀疏化、剪枝、知識蒸餾和張量分解等操作來減少LLMs的大小和降低推理速度。但這些技術往往會犧牲模型的準確性,既有損優化。而無損優化,常見的優化手段主要集中在推理框架和推理引擎端,如:vLLM、TGI等推理框架,集成PagedAttention、FlashAttention等優化算法降低推理速度。理論分析發現IO帶寬是主要瓶頸:LLMs推理延遲的主要瓶頸是輸入輸出(IO)帶寬,而不是與硬件計算能力相關的浮點運算(FLOPs)。這意味著,盡管LLMs在計算能力上可能很強大,但由於IO限制,它們的推理速度仍然受到限制。

本文介紹了Lookahead框架,這是一個通用的推理加速框架,主要針對RAG場景,旨在通過多分支策略和Trie樹結構來提高推理速度,同時保持生成結果的準確性。

一、RAG

概述:介紹Lookahead之前,先說下RAG的思想,RAG通過結合檢索(Retrieval)和生成(Generation)來增強模型的輸出質量。通過檢索最準確和最新的信息來增強LLMs的生成能力。從生成策略上來講,RAG通常依賴於檢索到的文檔或信息片段來輔助生成過程。在生成策略中,假如在采樣時也能猜測Token序列,那麼便可以避免生成待驗證的Token的過程,基於此,設計了Lookahead方法。

二、Lookahead

2.1 METHODS

  1. 多token策略
  2. Lookahead框架允許模型同時生成多個可能的token序列(分支),而不是傳統的單步生成。這種方法可以並行處理多個token,從而在每個推理步驟中生成更多的token,提高整體的推理速度。
  3. Trie樹數據結構
  4. Trie樹用於高效地存儲和檢索與輸入上下文相關的多個token。每個節點代表一個token,從根節點到葉節點的路徑代表一個完整的token序列。Trie樹的結構使得模型能夠快速找到與當前上下文匹配的token序列。
  5. token序列的插入、消除和修剪
  6. 為了維護Trie樹的效率,Lookahead框架實現了分支插入、分支消除和節點修剪策略。這些策略有助於保持Trie樹的合理大小,避免內存消耗過大,並提高檢索性能。
  7. 驗證和接受(VA)過程
  8. 在每個推理步驟中,Lookahead框架會從Trie樹中檢索到的草案進行驗證。驗證過程會確定每個草案中最長的正確子序列,並將這些子序列作為最終輸出的一部分。

核心思想就是驗證token的來源,與單token序列相比,多token序列可以提升接受率,token前綴樹可進一步降低成本。如圖:

在圖中,使用並行的多分支token序列,驗證6個token隻接受了3個token,但使用前綴樹建模的分層多分支token序列,接受了4個token,表明了有效性。

下圖描述了Mask策略實現一次驗證多個token序列或token前綴樹。下節將詳細介紹前綴樹的構建過程。

2.2 Trie樹

  1. Trie樹的定義:Lookahead框架中,Trie樹的每個節點代表一個標記ID,從根節點到葉節點的路徑代表一個分支token序列。這種結構使得模型能夠快速檢索到與給定上下文相關的多個token序列。
  2. Trie樹的更新:為了維護Trie樹的效率和大小,Lookahead框架實現了分支插入、分支消除和節點修剪等更新策略。這些策略有助於保持Trie樹的適度大小,避免內存消耗過大和檢索性能下降。
  3. 分支插入:在處理輸入提示(prompt)或輸出時,Lookahead框架會將提示或輸出轉換為分支token序列,並將其插入到Trie樹中。這有助於利用上下文信息來生成相關的token序列。
  4. 分支消除:當對某個提示的回答生成完成後,與該提示相關的分支token序列會被從Trie樹中移除,因為這些分支可能不再適用於其他提示的生成。
  5. 節點修剪:為了控制Trie樹的大小,當樹的大小超過預設閾值時,會動態移除最不頻繁的節點。這樣可以優化內存消耗並提高檢索性能。
  6. Trie樹的檢索:Lookahead框架通過提供前綴(一系列Token)來從Trie樹中檢索多個分支token序列。Token前綴的長度會影響檢索到的分支數量和相關性。較短的Token前綴會檢索到更多的分支,而較長的前綴則更具體,檢索到的分支與輸入上下文的相關性更高。

在Lookahead的工作流程中,Trie樹在每個推理步驟前後都會被更新。在token序列檢索階段,Trie樹用於提供候選分支;在驗證和接受(VA)階段,這些分支會被驗證,以確定最終的輸出。

算法流程:

三、插拔實踐

  • qwen
import os
import sys
import time
import torch
from transformers import AutoTokenizer
from transformers.generation import GenerationConfig
from pia.lookahead.models.qwen.modeling_qwen import QWenLMHeadModel
from pia.lookahead.models.qwen.tokenization_qwen import QWenTokenizer
from pia.lookahead.examples import local_path_dict
model_dir = local_path_dict.get('qwen', 'your/model/path')
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = QWenLMHeadModel.from_pretrained(model_dir
                                       , cache_dir='../'
                                       , torch_dtype=torch.float32
                                       , fp32=True
                                       , low_cpu_mem_usage=True
                                       , device_map={"": device}
                                       ).float().cuda().eval()
model.generation_config = GenerationConfig.from_pretrained(model_dir, trust_remote_code=True)
tokenizer = QWenTokenizer.from_pretrained(model_dir)
stop_words = [tokenizer.encode(x)[0] for x in [',', '.', ' ', ',','。']]
prompt = "杭州在哪裡?"
# prompt = "編一個200字左右的兒童故事"
for use_lookahead in [False, False, True, True]:
    decoding_length = 64
    branch_length = 12
    debug_lookahead = False
    max_new_tokens = 256
    decoding_kwargs = {"use_lookahead": use_lookahead,
                       "debug_lookahead": debug_lookahead,
                       "decoding_length": decoding_length,
                       "branch_length": branch_length,
                       "stop_words": stop_words,
                       "tokenizer": tokenizer}
    model.generation_config.decoding_kwargs=decoding_kwargs
    model.generation_config.do_sample=False  # default is True for qwen, result in different responses in every generation
    ts = time.time()
    response, history = model.chat(tokenizer, prompt, history=None, eos_token_id=151645)
    te = time.time()
    token_count = len(tokenizer.encode(response))
    print(f'lookahead:{use_lookahead} time:{te - ts:.3f}s speed:{token_count/(te-ts):.1f}token/s response:\n{response}\n')
  • chatglm3
import sys
import time
import torch
from pia.lookahead.models.chatglm.tokenization_chatglm_3 import ChatGLMTokenizer
from pia.lookahead.models.chatglm.modeling_chatglm import ChatGLMForConditionalGeneration
from pia.lookahead.examples import local_path_dict
model_dir = local_path_dict.get('chatglm3', 'your/model/path') 
tokenizer = ChatGLMTokenizer.from_pretrained(model_dir)
model = ChatGLMForConditionalGeneration.from_pretrained(model_dir
                                                                , cache_dir='./'
                                                                , torch_dtype=torch.float16
                                                                , low_cpu_mem_usage=True
                                                                , device_map={"":"cuda:0"}
                                                                )
stop_words = set(tokenizer.convert_tokens_to_ids([',', '.', ' ']))

# prompt = "Hello, I'm am conscious and"
prompt = "杭州在哪裡?"
inputs = tokenizer.build_chat_input(prompt, history=[])
input_ids = inputs.input_ids.cuda()
attention_mask = inputs.attention_mask.cuda()
position_ids = None
eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
device = model.device
debug_lookahead = False
decoding_length = 64
branch_length = 12
max_new_tokens = 128
for use_lookahead in [False,False,True,True]:
    ts = time.time()
    decoding_kwargs = {"use_lookahead": use_lookahead,
                       "debug_lookahead": debug_lookahead,
                       "decoding_mode": 'hier',
                       "decoding_length": decoding_length,
                       "branch_length": branch_length,
                       "stop_words": stop_words}
    outputs = model.generate(input_ids=input_ids,
                             attention_mask=attention_mask,
                             position_ids=position_ids,
                             pad_token_id=tokenizer.eos_token_id,
                             eos_token_id=eos_token_id,
                             use_cache=True,
                             max_new_tokens=max_new_tokens,
                             repetition_penalty=1.0,
                             do_sample=False,
                             decoding_kwargs=decoding_kwargs
                             )
    output_ids = outputs
    input_length = input_ids.size(-1)
    output_ids = output_ids[:, input_length:].tolist()
    # output_ids = output_ids.tolist()
    output_texts = []
    output_id_list = []
    for token_ids in output_ids:
        output_id_list.append(token_ids)
        text = tokenizer.decode(token_ids)
        output_texts.append(text)
    input_id_list = input_ids.tolist()
    te = time.time()
    print(f'use_lookahead:{use_lookahead} time:{te - ts:.3f} output:{output_texts}')

總結

Lookahead框架的核心思想是利用多分支策略和Trie樹結構來加速推理過程:

多分支策略:傳統的自回歸模型逐個生成下一個詞,而Lookahead框架通過並行生成多個分支(即多個可能的詞序列),然後通過驗證和接受(Verification and Accept, VA)過程來確定最終的輸出。這種方法允許模型在每個推理步驟中生成更多的詞,從而提高整體的推理速度。

Trie樹:在Lookahead框架中,Trie樹用於記錄輸入和輸出的詞列表,使得模型能夠基於上下文預測多條路徑。通過優化Trie樹的更新和檢索過程,框架能夠在保持內存和計算效率的同時,實現快速的推理。

參考文獻

  1. Lookahead: An Inference Acceleration Framework for Large Language Model with Lossless Generation Accuracy,https://arxiv.org/abs/2312.12728
  2. https://github.com/alipay/PainlessInferenceAcceleration