FuseLLM:大語言模型的知識融合!

2024年2月6日 24点热度 0人点赞

深度學習自然語言處理 原創
作者:wkk

論文:KNOWLEDGE FUSION OF LARGE LANGUAGE MODELS
地址:https://arxiv.org/pdf/2401.10491.pdf
git: https://github.com/fanqiwan/FuseLLM

小夥伴們好久沒見,今天為大傢介紹中山大學聯合騰訊人工智能實驗室的最新研究論文,關於整合LLM知識能力的框架。

引言

當進行LLM工作時,如果從頭開始訓練LLM可以生成具有不同功能和優勢的模型,但這會帶來巨大的成本,並可能導致冗餘功能。或者使用一種具有成本效益和說服力的方法是將現有的預先訓練的LLM合並到一個更有效的模型中。然而,由於已有LLM的架構各不相同,直接混合它們的權重是不切實際的。

在本文中,引入了LLM的知識融合概念,旨在將現有LLM的能力結合起來,並將其轉移到單個LLM中。通過利用源LLM的生成分佈,將其集體知識和獨特優勢外部化,從而有可能將目標模型的能力提升到任何單個源LLM之外。

動機

隨著GPT和LlaMA系列等大型語言模型在各種自然語言處理任務中的不斷成功,創建自己的LLM已成為企業的戰略當務之急。然而,LLM開發的相關成本是天文數字。除了需要大量的訓練數據、先進的技術、大量的計算資源和熟練的勞動力外,開發過程還對能源消耗和環境施加了巨大壓力。

相關工作

  • Model Fusing: 模型融合的常見方法通常采用加權平均多數投票來融合各種模型的預測。最近,有研究人員提出了一種集成框架,旨在利用多個開源LLM的不同優勢。該框架首先采用成對比較方法來檢測候選輸出之間的細微區別。然後,它將排名靠前的候選人結合起來,產生更高的產出,利用它們的優勢,同時減輕它們的弱點。雖然模型集成需要並行部署多個模型,但權重合並通常僅限於具有相同架構的模型。相反,本文提出的方法通過將多個LLM的知識和能力明確地轉移到目標LLM,支持將多個具有不同架構的LLM融合。
  • Knowledge Distillation:知識蒸餾最初被提出用於模型壓縮,包括在一個或多個教師模型的指導下訓練學生模型。在NLP中已有較為廣泛的應用。本文的方法與傳統的知識蒸餾有顯著的區別。首先,在傳統的知識蒸餾中,學生模型通常被限制為比教師更小的尺寸。然而,在本文的場景中,目標模型的大小沒有限制。其次,傳統的知識蒸餾通常會導致學生模型在蒸餾後落後於教師的表現。相比之下,本文預計在融合之後,目標模型將超過任何源模型的性能。

方法

模型架構

上圖展示了傳統模型融合技術和本文的LLM知識融合方法(FUSELLM)的對比。不同的動物圖標代表不同的LLM。FUSELLM能將多個LLM外部知識融合,並將它們的能力轉移到目標LLM。LLM融合的主要目標是將嵌入多個源LLM中的集體知識外部化,並將其能力集成到目標LLM中。

上表展示了FuseLLM的算法過程,其主要實現細節依賴於token對齊和具體的融合策略。

實驗

實驗設置

  • 數據集MiniPile包含22個域中的大約100萬個文檔和1.8億個token,占Llama-2的2萬億個訓練標記中的不到0.1%。
  • 融合函數:對於融合函數,本文使用最小化交叉熵。同時使用其他的替代融合函數進行消融實驗。
  • 訓練細節:使用128的批量大小和配備8個NVIDIA A100 GPU的單個節點上的最大長度為2048來訓練Llama-2 7B的目標LLM,每個節點有40GB的內存。訓練框架是基於Huggingface Transformers實現的,並使用FlashAttention加速。本文的模型使用AdamW優化器進行優化,β1=0.9,β2=0.95,梯度裁剪設置為1.0,權重衰減為0.1。采用餘弦學習率,最大學習率為1e-5,預熱率為0.008。
  • 評估:本文在三個基準上評估FuseLLM,這三個基準代表了LLM的不同核心功能,即跨越推理常識代碼生成

實驗結果

上表展示了與BBH上的基線方法相比,FuseLLM的總體結果。可以觀察到:

  • 三個源LLM在27個 BBH 任務中表現出不同的性能,Llama-2通常優於其他任務。
  • 在使用緊湊多樣的語料庫進行持續訓練後,與Llama-2相比,Llama-2 CLM顯示出1.86%的相對改進,盡管這種改進在任務之間相對溫和且不一致。
  • 平均而言,FuseLLM在所有27個任務中比原始Llama-2的平均相對性能增益為5.16%。在特定任務中,FuseLLM實現的增強是顯著的。
  • 在像Dick Languages這樣的任務中,簡單的連續預訓練會導致性能下降,FuseLLM利用單個源LLM的組合優勢來恢復性能改進。
  • FuseLLM偶爾會在幾何形狀和單詞排序等任務上表現出性能下降,這可以歸因於兩個原因。首先,除了Llama-2之外,其他源LLM在這些任務上表現不佳,影響了融合結果。其次,連續訓練數據集和下遊任務之間的相關性導致性能下降。

上表展示了FuseLLM和Common Sense (CS)基準上基線方法的零樣本性能。結果表明:

  • FuseLLM在所有五個任務中始終優於基線,與Llama-2相比,相對性能提1.25%。相比之下,Llama-2 CLM 表現出邊際改進,與Llama-2相比,相對增強隻有0.16%。
  • 在具有挑戰性的ARC-challenge(2.40%)和OpenBookQA(2.71%)任務中觀察到Llama-2到FuseLLM的實質性改進,突出了FuseLLM在利用集體知識來解決復雜問題方面的有效性。

對於代碼生成評估,FuseLLM在 MultiPL-E(ME)基準上的零樣本性能如上表所示。觀察到:

  • FuseLLM在10個任務中的9個上優於Llama-2,在特定編程語言(如R)的分數顯著提高,從4.97增加到5.84。
  • 與Llama-2相比,OpenLLaMA和MPT在代碼生成任務中都表現出顯著的性能,FuseLLM的融合結果平均性能增益為6.36%,遠高於Llama-2 CLM中觀察到的1.37%的改進。

融合概率分佈

本文研究了從多個LLM獲得的融合概率分佈的有效性,並跟蹤了訓練過程中性能改進的趨勢。

上圖顯示了Llama-2 CLM和FuseLLM在BBH上不同規模的訓練數據下的few-shot CoT性能的比較。結果表明:

  • 與Llama-2 CLM相比,FuseLLM將精確匹配精度提高了2.5%,並在0.52億個token內實現了Llama-2 CLM的最佳性能。
  • 與Llama-2 CLM所需的15.7億token相比,這意味著token需求減少了3.9倍。

這些結果表明,從LLM導出的概率分佈包含比原始文本序列更容易學習的知識,加速了優化過程。

實現過程分析

本文還對FuseLLM的關鍵元素進行分析包括:源LLM的數量、token對齊標準以及融合函數的選擇。

  • 源LLM的數量:上表給出了融合不同數量LLM的結果。隨著模型數量從1增加到3,FuseLLM的性能有了明顯的改進。在BBH中觀察到持續的性能改進。而在CS或ME中,當融合兩個模型時,優勢更加突出。這一現象可能歸因於三個模型在BBH中的各種任務上的性能差異很大,而CS或ME在任務上的表現差異相對較小。

  • token對齊標準:在LLM的融合過程中,確保來自多個模型的令牌和詞匯表的正確對齊至關重要。上表對兩種對齊標準進行了比較。很明顯,所提出的基於最小編輯距離的MinED方法始終優於EM方法。後者依賴於精確匹配。導致的性能增強是由於MinED能夠放松EM的約束,因為在同一序列中由不同的標記器分離的標記通常表現出微小的差異。

  • 融合函數:最小交叉熵的分佈矩陣和基於交叉熵的分配矩陣的加權平均。兩種融合函數的比較如上表所示。在所有基準測試中,帶有MinCE的FuseLLM始終優於AvgCE。這可歸因於AvgCE中使用的直接加權求和所引入的失真,這可能會削弱單個LLM的獨特優勢。

知識融合vs.知識蒸餾

知識蒸餾技術也可以用來增強LLM的能力,但FuseLLM由於兩個不同的方面而脫穎而出,本文從Llama-2 13B 中提取概率分佈,並應用傳統的知識蒸餾方法將其能力轉移到Llama-2 7B中。如上表所示:

  • 蒸餾模型在所有基準測試中都優於原始的Llama2 7B,證明了知識蒸餾的有效性。
  • 與FuseLLM相比,Llama-2 KD實現的改進相對適中。這表明FuseLLM 通過通過連續訓練集成三個具有不同架構的7B模型來實現的卓越性能超過了簡單地從單個13B模型中提取知識的好處。

知識融合vs.集成/合並

本文進行了實驗,模擬多個LLM來自同一個基本模型,但在不同的語料庫上進行訓練的場景。

上表結果中觀察到,在使用10億個token進行訓練後,原始LLM的能力會轉移到每個特定領域的LLM,導致其他領域的性能下降。雖然所有的融合技術都可以集成不同模型的優勢,但FuseLLM在三個領域中始終實現最低的困惑程度。這突出了它比集合和權重合並方法更有效地利用集體知識的潛力。

總結

在這項研究中,探索了LLM的知識融合領域,以創建一個統一的模型,將多個結構不同的LLM的能力和獨特優勢相結合。並介紹了一種新的方法:FuseLLM,它利用這些源LLM的生成分佈來外部化它們的知識,並將它們用於目標LLM的持續訓練。一系列實驗證明了FuseLLM相對於單個源LLM的優越性,並建立了基線。LLM融合領域成為一種更有前景的探索途徑,特別是考慮到了LLM的不同結構和大量模型大小。