PyTorch 2.2大更新!集成FlashAttention-2,性能提升2倍

2024年2月6日 28点热度 0人点赞

編輯:alan

【新智元導讀】新的一年,PyTorch也迎來了重大更新,PyTorch 2.2集成了FlashAttention-2和AOTInductor等新特性,計算性能翻倍。

新的一年,PyTorch也迎來了重大更新!

繼去年十月份的PyTorch大會發佈了2.1版本之後,全世界各地的521位開發者貢獻了3628個提交,由此形成了最新的PyTorch 2.2版本。

新的版本集成了FlashAttention-2,使得
scaled_dot_product_attention (SDPA)相較於之前的版本有了約2倍的性能提升。

PyTorch 2.2還引入了一個新的TorchInductor提前擴展,稱為 AOTInductor,旨在為非python服務器端編譯和部署PyTorch程序。

PyTorch中的torch.distributed支持了一個叫做device_mesh的新抽象,用於初始化和表示ProcessGroups。

另外,PyTorch 2.2提供了一個標準化的、可配置的日志記錄機制,——TORCH_LOGS。

PyTorch 2.2還對torch.compile做了許多改進,包括改進了對編譯優化器的支持,以及TorchInductor融合和佈局優化。

最後值得註意的是,PyTorch將放棄對macOS x86的支持,PyTorch 2.2.x是支持macOS x64的最後一個版本。

PyTorch 2.2新特性

首先請註意,如果從源代碼構建PyTorch 2.2,需要GCC 9.4或更高版本,PyTorch 代碼庫已從C 14遷移到C 17。

FlashAttention-2

FlashAttention-2通過優化GPU上不同線程塊和warps之間的工作分區,來解決占用率低或不必要的共享內存讀寫。

FlashAttention-2調整了算法以減少非matmul的計算量,同時提升了Attention計算的並行性(即使是單個頭,也可以跨不同的線程塊,以增加占用率),在每個線程塊中,優化warps之間的工作分配,以減少通過共享內存的通信。

PyTorch 2.2將FlashAttention內核更新到了v2版本,不過需要註意的是,之前的Flash Attention內核具有Windows實現,Windows用戶可以強制使用sdp_kernel,僅啟用Flash Attention的上下文管理器。

而在2.2中,如果必須使用 sdp_kernel 上下文管理器,請使用memory efficient或math內核(在Windows上)。

在FlashAttention-2的加持之下,torch.nn.functional.scaled_dot_product_attention的速度提升了大約2倍,在A100 GPU上達到了理論計算峰值的50%-73%。

AOTInductor

AOTInductor是TorchInductor的擴展,用於處理導出的PyTorch模型,對其進行優化,並生成共享庫以及其他相關工件。

這些編譯的工件可以部署在非Python環境中,經常用於服務器端的推理。

下面的示例演示了如何調用 aot_compile 將模型轉換為共享庫。

AOTInductor支持與Inductor相同的後端,包括CUDA、ROCm和CPU。

TORCH_LOGS

PyTorch 2.2提供了一個標準化的、可配置的日志記錄機制,可用於分析各種子系統的狀態,例如編譯和分佈式操作

可以通過TORCH_LOGS環境變量啟用日志。比如通過在命令行中修改環境變量:

將TorchDynamo的日志級別設置為logging.ERROR,將TorchInductor的日志級別設置為logging.DEBUG。

當然也可以在代碼中以API的形式使用:

torch.distributed.device_mesh

PyTorch 2.2引入了一個新的抽象,用於表示分佈式並行中涉及的 ProcessGroups,稱為
torch.distributed.device_mesh。

為分佈式訓練設置分佈式通信器(NCCL)是一件麻煩的事情。用戶需要編寫不同並行度的工作負載,並為每個並行度手動設置和管理NCCL通信器(ProcessGroup )。

這個過程可能很復雜,容易出錯。而DeviceMesh 可以簡化此過程,使其更易於管理。

DeviceMesh 是管理 ProcessGroup 的更高級別的抽象。它允許用戶毫不費力地創建節點間和節點內進程組,而不必擔心如何為不同的子進程組正確設置等級。

例如,數組的其中一個維度可以表示FSDP中的數據並行(data parallelism),而另一個維度可以表示FSDP中的張量並行(tensor parallelism)。

用戶還可以通過 DeviceMesh 輕松管理底層process_groups,以實現多維並行。

DeviceMesh在處理多維並行性(如3D並行)時很有用。如上圖所示,當你的並行解決方案需要跨主機和每個主機內部進行通信時,可以創建一個2D網格,用於連接每個主機中的設備,並以同構設置將每個設備與其他主機上的對應設備連接起來。

借助 init_device_mesh() ,我們可以在短短兩行內完成上面這個2D設置:

而如果不使用DeviceMesh,我們大概需要自己寫下面這一堆代碼:

當然,如果需要,我們仍然可以訪問底層 ProcessGroup:

優化器的改進

大概有以下幾點:

編譯優化器在所有基準測試中都提高了性能:HuggingFace 18%、TorchBench 19%、TIMM 8% E2E;

編譯的優化器增加對cudagraphs的支持;

對測試套件中所有模型進行平均,每個測試套件的基準測試平均編譯時間增加約40秒;正在進行的優化可能會將其降低到30秒以下。

用於多張量優化器編譯的inductor中缺少的主要功能是foreach算子的高效編碼生成。

在調度器內部,將所有在下放過程中註冊的緩沖區列表凝聚到
ForeachKernelSchedulerNodes中(FusedSchedulerNode的子類)。

為了檢查融合是否合法,每個內部 SchedulerNode 執行的寫操作必須與消費SchedulerNode在同一列表索引處的讀操作相匹配。

此外,正常的垂直融合規則必須允許在消費者和生產者SchedulerNode列表的每個索引處進行融合。

如果滿足了這些條件,
ForeachKernelSchedulerNode將垂直融合成一個
ForeachKernelSchedulerNode,其中每個列表上的相應點操作都將被融合。

通過實現這種融合,可以將一系列 foreach 運算融合到單個內核中,從而實現多張量優化器的完全融合。

性能改進

TorchInductor中添加了許多性能優化,包括對torch.concat的水平融合支持、改進的卷積佈局優化、以及改進
scaled_dot_product_attention模式匹配。

PyTorch 2.2還包括aarch64的許多性能增強,包括對mkldnn權重預打包的支持、改進的ideep基元緩存,以及通過對OneDNN的固定格式內核改進,來提高推理速度。

參考資料:

https://pytorch.org/blog/pytorch2-2/