編輯:alan
【新智元導讀】新的一年,PyTorch也迎來了重大更新,PyTorch 2.2集成了FlashAttention-2和AOTInductor等新特性,計算性能翻倍。
新的一年,PyTorch也迎來了重大更新!
繼去年十月份的PyTorch大會發佈了2.1版本之後,全世界各地的521位開發者貢獻了3628個提交,由此形成了最新的PyTorch 2.2版本。
![](https://news.xinpengboligang.com/upload/keji/9ef9f49cb7c7da98dd6a5f70aae39105.jpeg)
新的版本集成了FlashAttention-2,使得
scaled_dot_product_attention (SDPA)相較於之前的版本有了約2倍的性能提升。
PyTorch 2.2還引入了一個新的TorchInductor提前擴展,稱為 AOTInductor,旨在為非python服務器端編譯和部署PyTorch程序。
PyTorch中的torch.distributed支持了一個叫做device_mesh的新抽象,用於初始化和表示ProcessGroups。
![](https://news.xinpengboligang.com/upload/keji/ffcc2d5a77211dc1a0529d0977bb2eda.jpeg)
另外,PyTorch 2.2提供了一個標準化的、可配置的日志記錄機制,——TORCH_LOGS。
PyTorch 2.2還對torch.compile做了許多改進,包括改進了對編譯優化器的支持,以及TorchInductor融合和佈局優化。
![](https://news.xinpengboligang.com/upload/keji/6a39c01f8c3b4e3f80945b3317898ea0.jpeg)
最後值得註意的是,PyTorch將放棄對macOS x86的支持,PyTorch 2.2.x是支持macOS x64的最後一個版本。
PyTorch 2.2新特性
首先請註意,如果從源代碼構建PyTorch 2.2,需要GCC 9.4或更高版本,PyTorch 代碼庫已從C 14遷移到C 17。
![](https://news.xinpengboligang.com/upload/keji/eddbe1531150ad1711d23cb8d3bc7e01.jpeg)
FlashAttention-2
FlashAttention-2通過優化GPU上不同線程塊和warps之間的工作分區,來解決占用率低或不必要的共享內存讀寫。
![](https://news.xinpengboligang.com/upload/keji/fa75ecbac1920c5a6a94066f23bb03a1.jpeg)
FlashAttention-2調整了算法以減少非matmul的計算量,同時提升了Attention計算的並行性(即使是單個頭,也可以跨不同的線程塊,以增加占用率),在每個線程塊中,優化warps之間的工作分配,以減少通過共享內存的通信。
PyTorch 2.2將FlashAttention內核更新到了v2版本,不過需要註意的是,之前的Flash Attention內核具有Windows實現,Windows用戶可以強制使用sdp_kernel,僅啟用Flash Attention的上下文管理器。
而在2.2中,如果必須使用 sdp_kernel 上下文管理器,請使用memory efficient或math內核(在Windows上)。
在FlashAttention-2的加持之下,torch.nn.functional.scaled_dot_product_attention的速度提升了大約2倍,在A100 GPU上達到了理論計算峰值的50%-73%。
AOTInductor
AOTInductor是TorchInductor的擴展,用於處理導出的PyTorch模型,對其進行優化,並生成共享庫以及其他相關工件。
這些編譯的工件可以部署在非Python環境中,經常用於服務器端的推理。
下面的示例演示了如何調用 aot_compile 將模型轉換為共享庫。
![](https://news.xinpengboligang.com/upload/keji/69bcbae27dc67bc42f1a5a50379b0a58.jpeg)
AOTInductor支持與Inductor相同的後端,包括CUDA、ROCm和CPU。
TORCH_LOGS
PyTorch 2.2提供了一個標準化的、可配置的日志記錄機制,可用於分析各種子系統的狀態,例如編譯和分佈式操作
可以通過TORCH_LOGS環境變量啟用日志。比如通過在命令行中修改環境變量:
將TorchDynamo的日志級別設置為logging.ERROR,將TorchInductor的日志級別設置為logging.DEBUG。
當然也可以在代碼中以API的形式使用:
![](https://news.xinpengboligang.com/upload/keji/9564a258daa6c9f402de686beb932039.jpeg)
torch.distributed.device_mesh
PyTorch 2.2引入了一個新的抽象,用於表示分佈式並行中涉及的 ProcessGroups,稱為
torch.distributed.device_mesh。
為分佈式訓練設置分佈式通信器(NCCL)是一件麻煩的事情。用戶需要編寫不同並行度的工作負載,並為每個並行度手動設置和管理NCCL通信器(ProcessGroup )。
這個過程可能很復雜,容易出錯。而DeviceMesh 可以簡化此過程,使其更易於管理。
DeviceMesh 是管理 ProcessGroup 的更高級別的抽象。它允許用戶毫不費力地創建節點間和節點內進程組,而不必擔心如何為不同的子進程組正確設置等級。
例如,數組的其中一個維度可以表示FSDP中的數據並行(data parallelism),而另一個維度可以表示FSDP中的張量並行(tensor parallelism)。
用戶還可以通過 DeviceMesh 輕松管理底層process_groups,以實現多維並行。
![](https://news.xinpengboligang.com/upload/keji/04c4812f4e18c057bcc377c24bfcada9.jpeg)
DeviceMesh在處理多維並行性(如3D並行)時很有用。如上圖所示,當你的並行解決方案需要跨主機和每個主機內部進行通信時,可以創建一個2D網格,用於連接每個主機中的設備,並以同構設置將每個設備與其他主機上的對應設備連接起來。
借助 init_device_mesh() ,我們可以在短短兩行內完成上面這個2D設置:
而如果不使用DeviceMesh,我們大概需要自己寫下面這一堆代碼:
![](https://news.xinpengboligang.com/upload/keji/7909ff5e77b141b162f62fc45e3257f1.jpeg)
當然,如果需要,我們仍然可以訪問底層 ProcessGroup:
優化器的改進
大概有以下幾點:
編譯優化器在所有基準測試中都提高了性能:HuggingFace 18%、TorchBench 19%、TIMM 8% E2E;
編譯的優化器增加對cudagraphs的支持;
對測試套件中所有模型進行平均,每個測試套件的基準測試平均編譯時間增加約40秒;正在進行的優化可能會將其降低到30秒以下。
用於多張量優化器編譯的inductor中缺少的主要功能是foreach算子的高效編碼生成。
在調度器內部,將所有在下放過程中註冊的緩沖區列表凝聚到
ForeachKernelSchedulerNodes中(FusedSchedulerNode的子類)。
為了檢查融合是否合法,每個內部 SchedulerNode 執行的寫操作必須與消費SchedulerNode在同一列表索引處的讀操作相匹配。
![](https://news.xinpengboligang.com/upload/keji/af8a36a2ca305b6d64d9e5f8ee5a5602.jpeg)
此外,正常的垂直融合規則必須允許在消費者和生產者SchedulerNode列表的每個索引處進行融合。
如果滿足了這些條件,
ForeachKernelSchedulerNode將垂直融合成一個
ForeachKernelSchedulerNode,其中每個列表上的相應點操作都將被融合。
通過實現這種融合,可以將一系列 foreach 運算融合到單個內核中,從而實現多張量優化器的完全融合。
性能改進
TorchInductor中添加了許多性能優化,包括對torch.concat的水平融合支持、改進的卷積佈局優化、以及改進
scaled_dot_product_attention模式匹配。
![](https://news.xinpengboligang.com/upload/keji/754b03124ce2b909ae5d11944cea22f2.jpeg)
PyTorch 2.2還包括aarch64的許多性能增強,包括對mkldnn權重預打包的支持、改進的ideep基元緩存,以及通過對OneDNN的固定格式內核改進,來提高推理速度。
參考資料:
https://pytorch.org/blog/pytorch2-2/