亞洲資本網(wǎng) > 資訊 > 國(guó)內(nèi) > 正文
DeepSpeed ZeRO++:降低4倍網(wǎng)絡(luò)通信,顯著提高大模型及類ChatGPT模型訓(xùn)練效率_世界看點(diǎn)
2023-06-26 05:56:51來(lái)源: 機(jī)器之心


(相關(guān)資料圖)

機(jī)器之心轉(zhuǎn)載

來(lái)源:知乎作者:微軟DeepSpeed大型 AI 模型正在改變數(shù)字世界?;诖笮驼Z(yǔ)言模型 (LLM) 的 Turing-NLG、ChatGPT 和 GPT-4 等生成語(yǔ)言模型用途廣泛,能夠執(zhí)行摘要、代碼生成和翻譯等任務(wù)。同樣,DALL?E、Microsoft Designer 和 Bing Image Creator 等大型多模態(tài)生成模型可以生成藝術(shù)、建筑、視頻和其他數(shù)字資產(chǎn),使內(nèi)容創(chuàng)作者、建筑師和工程師能夠探索全新的創(chuàng)意生產(chǎn)力。 然而,訓(xùn)練這些大型模型需要在數(shù)百甚至數(shù)千個(gè) GPU 設(shè)備上使用大量?jī)?nèi)存和計(jì)算資源。例如,訓(xùn)練 Megatron-Turing NLG 530B 模型需要使用超過(guò) 4,000 個(gè) NVidia A100 GPU。有效地利用這些資源需要一個(gè)復(fù)雜的優(yōu)化系統(tǒng),以將模型合理分配到各個(gè)設(shè)備的內(nèi)存中,并有效地并行化這些設(shè)備上的計(jì)算。同時(shí),為了使深度學(xué)習(xí)社區(qū)能夠輕松進(jìn)行大型模型訓(xùn)練,這些優(yōu)化必須易于使用。 DeepSpeed 的 ZeRO 優(yōu)化系列為這些挑戰(zhàn)提供了強(qiáng)大的解決方案,并已廣泛用于大型深度學(xué)習(xí)模型例如 TNLG-17B、Bloom-176B、MPT-7B、Jurrasic-1 的訓(xùn)練中 。盡管它具有變革性的能力 ,在一些關(guān)鍵場(chǎng)景中,ZeRO 會(huì)在 GPU 之間產(chǎn)生大量數(shù)據(jù)傳輸開(kāi)銷,這降低了訓(xùn)練效率。這種情況特別發(fā)生在以下場(chǎng)景中:a) 全局 batch size 較小,而 GPU 數(shù)量多,這導(dǎo)致每個(gè) GPU 上 batch size 較小,需要頻繁通信;或者 b) 在低端集群上進(jìn)行訓(xùn)練,其中跨節(jié)點(diǎn)網(wǎng)絡(luò)帶寬有限,導(dǎo)致高通信延遲。在這些情況下,ZeRO 的訓(xùn)練效率會(huì)受到限制。 為了解決這些限制,我們發(fā)布了 ZeRO++ 。ZeRO++ 相比 ZeRO 將總通信量減少了 4 倍,而不會(huì)影響模型質(zhì)量。這有兩個(gè)關(guān)鍵意義: 1. ZeRO++ 加速大型模型預(yù)訓(xùn)練和微調(diào)每個(gè) GPU 上 batch size 較小時(shí):無(wú)論是在數(shù)千個(gè) GPU 上預(yù)訓(xùn)練大型模型,還是在數(shù)百個(gè)甚至數(shù)十個(gè) GPU 上對(duì)其進(jìn)行微調(diào),當(dāng)每個(gè) GPU 的 batch size 較小時(shí),ZeRO++ 提供比 ZeRO 高 2.2 倍的吞吐量,直接減少訓(xùn)練時(shí)間和成本。 低帶寬計(jì)算集群: ZeRO++ 使低帶寬集群能夠?qū)崿F(xiàn)與帶寬高 4 倍的高端集群類似的吞吐量。因此,ZeRO++ 可以跨更廣泛的集群進(jìn)行高效的大型模型訓(xùn)練。 2. ZeRO++ 加速 ChatGPT 類的 RLHF 訓(xùn)練雖然 ZeRO++ 主要是為訓(xùn)練而設(shè)計(jì)的,但它的優(yōu)化也自動(dòng)適用于 ZeRO-Inference,因?yàn)橥ㄐ砰_(kāi)銷對(duì)于 ZeRO 的訓(xùn)練和推理同樣適用。因此,ZeRO++ 可以提高人類反饋強(qiáng)化學(xué)習(xí) (RLHF) 等算法的效率,因?yàn)?RLHF 結(jié)合了訓(xùn)練和推理。 通過(guò)與 DeepSpeed-Chat 的集成,與原始 ZeRO 相比,ZeRO++ 可以將 RLHF 訓(xùn)練的生成階段效率提高多達(dá) 2 倍,強(qiáng)化學(xué)習(xí)訓(xùn)練階段效率提高多達(dá) 1.3 倍。 接下來(lái),我們將更深入地解釋 ZeRO 及其通信開(kāi)銷,并討論 ZeRO++ 中為解決這些問(wèn)題而進(jìn)行的關(guān)鍵優(yōu)化。然后我們將展示 ZeRO++ 對(duì)不同模型大小、批量大小和帶寬限制的訓(xùn)練吞吐量的影響。我們還將討論 ZeRO++ 如何應(yīng)用于 DeepSpeed-Chat,以加速使用 RLHF 的對(duì)話模型的訓(xùn)練。 ZeRO++ 詳解

圖2:ZeRO optimizer 工作流程圖(此為部分展示,完整流程請(qǐng)看知乎原文)

ZeRO 是數(shù)據(jù)并行 (Data Parallelism) 的一種內(nèi)存高效版本,其中模型狀態(tài)會(huì)被分割儲(chǔ)存在所有 GPU 上,而不需要在訓(xùn)練期間使用基于 gather/broadcas 的通信進(jìn)行復(fù)制和重建。這使 ZeRO 能夠有效地利用所有設(shè)備的聚合 GPU 內(nèi)存和計(jì)算力,同時(shí)提供簡(jiǎn)單易用的數(shù)據(jù)并行訓(xùn)練。 假設(shè)模型大小為 M。在前向傳播過(guò)程中,ZeRO 執(zhí)行全收集 / 廣播 (all-gather/broadcast) 操作以在需要之時(shí)為每個(gè)模型層收集參數(shù)(總共大小為 M)。在向后傳遞中,ZeRO 對(duì)每一層的參數(shù)采用類似的通信模式來(lái)計(jì)算其局部梯度(總大小為 M)。此外,ZeRO 在對(duì)每個(gè)局部梯度計(jì)算完畢后會(huì)立刻使用 reduce 或 reduce-scatter 通信進(jìn)行平均和分割儲(chǔ)存(總大小為 M)。因此,ZeRO 總共有 3M 的通信量,平均分布在兩個(gè)全收集 / 廣播 (all-gather/broadcast) 和一個(gè)減少分散 / 減少 (reduce-scatter/reduce) 操作中。 為了減少這些通信開(kāi)銷,ZeRO++ 進(jìn)行了三組通信優(yōu)化,分別針對(duì)上述三個(gè)通信集合: 圖 3:qwZ 的分區(qū)量化圖例ZeRO 通信過(guò)程中的權(quán)重量化 (qwZ)首先,為了減少 all-gather 期間的參數(shù)通信量,我們采用權(quán)重量化在通信前將每個(gè)模型參數(shù)從 FP16(兩個(gè)字節(jié))動(dòng)態(tài)縮小為 INT8(一個(gè)字節(jié))數(shù)據(jù)類型,并在通信后對(duì)權(quán)重進(jìn)行反量化。然而,簡(jiǎn)單地對(duì)權(quán)重進(jìn)行量化會(huì)降低模型訓(xùn)練的準(zhǔn)確性。為了保持良好的模型訓(xùn)練精度,我們采用分區(qū)量化,即對(duì)模型參數(shù)的每個(gè)子集進(jìn)行獨(dú)立量化。目前尚且沒(méi)有針對(duì)分區(qū)量化的高性能現(xiàn)有實(shí)現(xiàn)。因此,我們自行從頭開(kāi)始實(shí)現(xiàn)了一套高度優(yōu)化的量化 CUDA 內(nèi)核,與基本量化相比,精度提高 3 倍,速度提高 5 倍。 圖 4: 權(quán)重的分層分割存儲(chǔ) (hpZ)ZeRO 模型權(quán)重的分層分割存儲(chǔ) (hpZ)其次,為了減少向后傳遞期間全收集 (all-gather) 權(quán)重的通信開(kāi)銷,我們用 GPU 內(nèi)存進(jìn)行通信。更具體地說(shuō),我們不像在 ZeRO 中那樣將整個(gè)模型權(quán)重分布在所有機(jī)器上,而是在每臺(tái)機(jī)器中維護(hù)一個(gè)完整的模型副本。以更高的內(nèi)存開(kāi)銷為代價(jià),這允許我們用機(jī)器內(nèi)的模型權(quán)重全收集 / 廣播 (all-gather/broadcast) 代替昂貴的跨機(jī)器全收集 / 廣播 (all-gather/broadcast),由于機(jī)器內(nèi)通信帶寬更高,這使得通信速度大幅提升。

圖 5: qgZ 端到端的工作流程

ZeRO 通信過(guò)程中梯度量化 (qgZ)第三,要降低梯度的 reduce-scatter 通信成本更具挑戰(zhàn)性。因?yàn)橹苯討?yīng)用量化來(lái)減少通信量是不可行的。即使我們使用分區(qū)量化來(lái)降低量化誤差,梯度 reduce 也會(huì)累積并放大量化誤差。為了解決這個(gè)問(wèn)題,我們只在通信之前量化梯度,但在任何 reduce 操作之前將它們反量化到原有精度。為了有效地做到這一點(diǎn),我們發(fā)明了一種名為 qgZ 的基于 all-to-all 的新型量化梯度通信范式,它在功能上等同于壓縮的歸約 - 分散 (reduce-scatter) 操作。 qgZ 旨在解決兩個(gè)挑戰(zhàn):i) 如果我們簡(jiǎn)單地在 INT4/INT8 中實(shí)施 reduce-scatter 會(huì)導(dǎo)致顯著精度損失,以及 ii) 在傳統(tǒng) tree 或 ring-based reduce-scatter 中使用量化需要一長(zhǎng)串量化和反量化步驟,這直接導(dǎo)致誤差積累和顯著的延遲,即使我們?cè)谌壬线M(jìn)行 reduce。為了解決這兩個(gè)挑戰(zhàn),qgZ 不使用 tree 或 ring-based reduce-scatter 算法,而是基于一種新穎的分層 all-to-all 方法。 qgZ 中有三個(gè)主要步驟: 梯度切片重新排序; 節(jié)點(diǎn)內(nèi)通信和 reduce; 節(jié)點(diǎn)間通信和 reduce。 首先,在任何通信發(fā)生之前,我們對(duì)梯度進(jìn)行切片并對(duì)張量切片重新排序,以保證通信結(jié)束時(shí)每個(gè) GPU 上的最終梯度位置(即圖 5 中的綠色塊)是正確的。其次,我們量化重新排序的梯度切片,在每個(gè)節(jié)點(diǎn)內(nèi)進(jìn)行 all-to-all 通信,從 all-to-all 中對(duì)接收到的梯度切片進(jìn)行反量化,并進(jìn)行局部 reduce。第三,我們?cè)俅瘟炕植?reduce 后的梯度,進(jìn)行節(jié)點(diǎn)間的 all-to-all 通信,再次對(duì)接收到的梯度進(jìn)行反量化,并計(jì)算最終的高精度梯度 reduce,得到圖 5 中綠色塊的結(jié)果。 這種分層方法的原因是為了減少跨節(jié)點(diǎn)通信量。更準(zhǔn)確地說(shuō),給定每個(gè)節(jié)點(diǎn) N 個(gè) GPU、M 的模型大小和 Z 的量化比率,單跳 all-to-all 將生成 M*N/Z 跨節(jié)點(diǎn)流量。相比之下,通過(guò)這種分層方法,我們將每個(gè) GPU 的跨節(jié)點(diǎn)流量從 M/Z 減少到 M/(Z*N)。因此,總通信量從 M*N/Z 減少到 M*N/(Z*N) = M/Z。我們通過(guò)重疊節(jié)點(diǎn)內(nèi)和節(jié)點(diǎn)間通信以及融合 CUDA 內(nèi)核來(lái)進(jìn)一步優(yōu)化 qgZ 的端到端延遲(張量切片重新排序 (Tensor Slice Reordering)+ 節(jié)點(diǎn)內(nèi)量化 (Intra-node quantization))和(節(jié)點(diǎn)內(nèi)反量化 (Intra-node Dequantization) + 節(jié)點(diǎn)內(nèi)梯度整合 (Intra-node Reduction) + 節(jié)點(diǎn)間量化 (inter-node quantization))。 通信總量?jī)?yōu)化通過(guò)結(jié)合以上所有三個(gè)組件,我們將跨節(jié)點(diǎn)通信量從 3M 減少到 0.75M。更具體地說(shuō),我們使用 qwZ 將模型權(quán)重的前向全收集 / 廣播從 M 減少到 0.5M。我們使用 hpZ 消除了反向傳播期間的跨節(jié)點(diǎn) all-gather,將通信從 M 減少到 0。最后,我們使用 qgZ 將反 向傳播期間的跨節(jié)點(diǎn) reduce-scatter 通信從 M 減少到 0.25M。 ZeRO++ 加速大型語(yǔ)言模型訓(xùn)練在這里,我們展示了 ZeRO++ 在 384 個(gè) Nvidia V100 GPU 上的真實(shí) LLM 訓(xùn)練場(chǎng)景的測(cè)試結(jié)果。 圖 6: 在 384 個(gè) V100 GPU 上的各種模型大小下 ZeRO++ 與 ZeRO 的吞吐量,節(jié)點(diǎn)間使用 4 個(gè) Infiniband (IB) 進(jìn)行互連,每個(gè)以 100 Gbps 運(yùn)行。在 GPU 小 batch size 情況下 ZeRO++ 實(shí)現(xiàn)更高的訓(xùn)練效率高帶寬集群:如圖 6 所示,我們首先展示了 ZeRO++ 相對(duì)于 ZeRO 的吞吐量改進(jìn),針對(duì)不同的模型大小和微批量 (micro-batch size) 大小,測(cè)試使用 4x Infiniband (IB) 以實(shí)現(xiàn) 400Gbps 跨節(jié)點(diǎn)互連帶寬,每個(gè)以 100Gbps 運(yùn)行。在 micro-batch size 為每 GPU 1k tokens 時(shí),ZeRO++ 比 ZeRO-3 的吞吐量提高了 28% 到 36%。對(duì)于 2k tokens micro-batch size 大小,ZeRO++ 比 ZeRO-3 實(shí)現(xiàn)了 24% 到 29% 的吞吐量增益。 圖 7: 在 384 個(gè) V00 GPU 上 100Gbps 跨節(jié)點(diǎn)帶寬時(shí)各種 LLM 的吞吐量低帶寬集群:在 100Gbps 等低帶寬網(wǎng)絡(luò)環(huán)境中,ZeRO++ 的性能明顯優(yōu)于 ZeRO-3。如圖 7 所示,與 ZeRO-3 相比,ZeRO++ 在端到端吞吐量方面實(shí)現(xiàn)了高達(dá) 2.2 倍的加速。平均而言,ZeRO++ 比 ZeRO-3 基線實(shí)現(xiàn)了大約 2 倍的加速。 圖 8: ZeRO++ 以顯著降低的帶寬實(shí)現(xiàn)高帶寬集群性能實(shí)現(xiàn)高帶寬 ZeRO 和低帶寬 ZeRO++ 集群之間的模型訓(xùn)練效率等效此外,與 ZeRO 在高得多的帶寬環(huán)境下相比,ZeRO ++ 可以在低帶寬集群中實(shí)現(xiàn)相當(dāng)?shù)南到y(tǒng)吞吐量。如圖 8 所示,對(duì)于 18B 和 138B 模型大小,具有 200Gbps 跨節(jié)點(diǎn)帶寬的 ZeRO++ 可以達(dá)到與 800Gbps 跨節(jié)點(diǎn)帶寬的 ZeRO-3 相似的 TFLOP。 鑒于 ZeRO++ 出色的可擴(kuò)展性,我們將 ZeRO++ 視為用于訓(xùn)練大型 AI 模型的下一代 ZeRO。 DeepSpeed-Chat 與 ZeRO++ 結(jié)合用于 RLHF 訓(xùn)練RLHF 訓(xùn)練簡(jiǎn)介ChatGPT 類模型由 LLM 提供支持,并使用 RLHF 進(jìn)行微調(diào)。RLHF 由生成(推理)階段和訓(xùn)練階段組成。在生成階段,演員 (actor) 模型將部分對(duì)話作為輸入,并使用一系列前向傳遞生成響應(yīng)。然后在訓(xùn)練階段,評(píng)論 (critic) 模型根據(jù)質(zhì)量對(duì)生成的響應(yīng)進(jìn)行排名,為演員模型提供強(qiáng)化信號(hào)。使用這些排名對(duì)參與者模型進(jìn)行微調(diào),使其能夠在后續(xù)迭代中生成更準(zhǔn)確和適當(dāng)?shù)捻憫?yīng)。 RLHF 訓(xùn)練帶來(lái)了巨大的內(nèi)存壓力,因?yàn)樗褂昧怂姆N模型(演員、參考、評(píng)論、獎(jiǎng)勵(lì))。常見(jiàn)的解決方案是采用低秩自適應(yīng)訓(xùn)練 (LoRA) 來(lái)解決 RLHF 的內(nèi)存壓力。LoRA 凍結(jié)了預(yù)訓(xùn)練模型的權(quán)重,并將可訓(xùn)練的秩分解矩陣注入到 Transformer 架構(gòu)的每一層中,顯著減少了可訓(xùn)練參數(shù)的數(shù)量。LoRA 通過(guò)減少內(nèi)存使用來(lái)加速 RLHF,允許更大的批處理 (batch) 大小,從而大大提高吞吐量。 DeepSpeed-Chat with ZeRO++ 用于 RLHF 訓(xùn)練圖 9: ZeRO++ 加速了 RLHF 訓(xùn)練的生成和訓(xùn)練階段ZeRO++ 在 RLHF + LoRA 的場(chǎng)景下有著獨(dú)特的應(yīng)用,因?yàn)榇蠖鄶?shù)模型權(quán)重都被凍結(jié)了。這意味著 ZeRO++ 可以將這些凍結(jié)的權(quán)重量化保存到 INT4/8 中,而不是將它們存儲(chǔ)在 fp16 中并在每次通信操作之前對(duì)其進(jìn)行量化。通信后的反量化仍然是為了讓權(quán)重為計(jì)算做好準(zhǔn)備,但反量化后的權(quán)重在計(jì)算后被簡(jiǎn)單地丟棄。 以這種方式使用 ZeRO++ 進(jìn)行 RLHF 訓(xùn)練可以減少內(nèi)存使用和通信量。這意味著通過(guò)減少通信以及由于減少內(nèi)存使用而啟用更大的批處理大小來(lái)提高訓(xùn)練吞吐量。在生成階段,ZeRO++ 使用 hpZ 將所有權(quán)重通信保持在每個(gè)節(jié)點(diǎn)內(nèi),以利用更高的節(jié)點(diǎn)內(nèi)通信帶寬,減少通信量,進(jìn)一步提高生成吞吐量。 ZeRO++ 已集成到 DeepSpeed-Chat 中,以支持 ChatGPT 類模型的 RLHF 訓(xùn)練。在圖 9 中,我們比較了不同大小的 actor 模型的 RLHF 生成吞吐量。測(cè)試配置為 32 個(gè) V100 GPU ,actor 模型大小為 30B 和 66B 以測(cè)試 ZeRO 和 ZeRO++ 性能。結(jié)果表明,ZeRO++ 的 RLHF 生成吞吐量比 ZeRO 高出 2.25 倍。我們還展示了在 16 個(gè) V100 GPU 上訓(xùn)練階段的加速,其中 ZeRO++ 實(shí)現(xiàn)了比 ZeRO 高 1.26 倍的吞吐量,這是由于 ZeRO++ 支持的更低通信量和更大批量大小。 DeepSpeed ZeRO++ 現(xiàn)已發(fā)布!我們非常高興能夠發(fā)布 DeepSpeed ZeRO++ 并讓 AI 社區(qū)中的每個(gè)人都可以使用它。請(qǐng)?jiān)L問(wèn)我們的 GitHub 頁(yè)面以獲取 LLM 訓(xùn)練教程。用于 DeepSpeed-Chat 的 ZeRO++ 將在未來(lái)幾周內(nèi)發(fā)布。有關(guān) ZeRO++ 的更多技術(shù)細(xì)節(jié),請(qǐng)查看我們的 arxiv 論文。 DeepSpeed-ZeRO++ 是 DeepSpeed 生態(tài)系統(tǒng)的一部分。要了解更多信息,請(qǐng)?jiān)L問(wèn)我們的網(wǎng)站,在那里您可以找到詳細(xì)的博客文章、教程和有用的文檔。 您還可以在我們的英文 Twitter、日文 Twitter 和中文知乎 上獲取最新的 DeepSpeed 新聞。 DeepSpeed 歡迎您的貢獻(xiàn)!我們鼓勵(lì)您在 DeepSpeed GitHub 頁(yè)面上報(bào)告問(wèn)題、貢獻(xiàn) PR 并加入討論。有關(guān)更多詳細(xì)信息,請(qǐng)參閱我們的貢獻(xiàn)指南。我們對(duì)與大學(xué)、研究實(shí)驗(yàn)室和公司的合作持開(kāi)放態(tài)度。對(duì)于此類請(qǐng)求(以及其他不適合 GitHub 的請(qǐng)求),請(qǐng)直接發(fā)送電子郵件至 deepspeed-info@microsoft.com。 貢獻(xiàn)者:DeepSpeed 團(tuán)隊(duì)的以下人員的貢獻(xiàn)使該項(xiàng)目成為可能: Guanhua Wang, Heyang Qin, Sam Ade Jacobs, Connor Holmes, Samyam Rajbhandari, Olatunji Ruwase, Ammar Ahmad Awan, Jeff Rasley, Michael Wyatt, Yuxiong He (team lead) 本文轉(zhuǎn)載自微軟DeepSpeed組。 官方知乎賬號(hào):zhihu.com/people/deepspeed ?THE END 轉(zhuǎn)載請(qǐng)聯(lián)系本公眾號(hào)獲得授權(quán) 投稿或?qū)で髨?bào)道:content@jiqizhixin.com

關(guān)鍵詞:

專題新聞
  • 焦點(diǎn)快播:知名書(shū)畫(huà)家殺害女子被通緝?警方回應(yīng)
  • 推進(jìn)算力互聯(lián)互通戰(zhàn)略落地 聞庫(kù)提出三點(diǎn)建議|全球今頭條
  • 浙江省高考2023什么時(shí)候出成績(jī)(附查詢方式)
  • 天天速看:越南告急,中國(guó)出手!
  • 今亮點(diǎn)!AMDNavi32RDNA3GPU封裝如圖所示支持Navi31GPU芯片具有六個(gè)緊湊型MCD
  • 焦點(diǎn)簡(jiǎn)訊:銀川市發(fā)布雷電黃色預(yù)警!?有短時(shí)強(qiáng)降水、雷暴大風(fēng)
最近更新

京ICP備2021034106號(hào)-51

Copyright © 2011-2020  亞洲資本網(wǎng)   All Rights Reserved. 聯(lián)系網(wǎng)站:55 16 53 8 @qq.com