[發明專利]一種模型訓練方法和相關裝置在審
| 申請號: | 202210427939.0 | 申請日: | 2022-04-22 |
| 公開(公告)號: | CN115114927A | 公開(公告)日: | 2022-09-27 |
| 發明(設計)人: | 弓靜 | 申請(專利權)人: | 騰訊科技(深圳)有限公司 |
| 主分類號: | G06F40/30 | 分類號: | G06F40/30;G06F40/289;G06N3/04;G06N3/08 |
| 代理公司: | 深圳市深佳知識產權代理事務所(普通合伙) 44285 | 代理人: | 林志鵬 |
| 地址: | 518057 廣東省深圳*** | 國省代碼: | 廣東;44 |
| 權利要求書: | 查看更多 | 說明書: | 查看更多 |
| 摘要: | |||
| 搜索關鍵詞: | 一種 模型 訓練 方法 相關 裝置 | ||
本申請實施例公開了一種模型訓練方法和相關裝置,至少涉及人工智能模型中的機器學習,確定待訓練模型包括的m個張量與n個并行進程之間的對應關系,m個張量包括在n個張量集合中,每個張量集合包括m個張量中的部分張量,n個張量集合與n個并行進程的對應關系為一一對應關系,使得每個并行進程只維護部分張量。目標并行進程與目標張量具有對應關系,在進行迭代的過程中,目標并行進程僅基于目標張量更新待訓練模型的參數,根據更新后的參數訓練待訓練模型。不僅降低了創建臨時緩存的數量,還降低了臨時緩存的頻繁創建和釋放產生的內存碎片。由此,通過每個并行進程至維護部分張量,降低了激活層內存、臨時緩存等,進而降低了模型的顯存占用。
技術領域
本申請涉及計算機技術領域,特別是涉及一種模型訓練方法和相關裝置。
背景技術
隨著人工智能的發展,模型逐漸朝著更大量級發展,如量級越大的自然語言模型的準確率更高,例如,生成型已訓練變換模型3(Generative Pre-trained Transformer 3,GPT-3)的模型參數已達到175B。
在預訓練階段,較大的模型需要占用的顯存較多。
發明內容
為了解決上述技術問題,本申請提供了一種模型訓練方法和相關裝置,用于降低訓練模型的顯存占用。
本申請實施例公開了如下技術方案:
一方面,本申請實施例提供一種模型訓練方法,所述方法包括:
確定待訓練模型包括的m個張量與n個并行進程之間的對應關系;其中,所述m個張量包括在n個張量集合中,每個張量集合包括所述m個張量中的部分張量,所述n個張量集合與所述n個并行進程的對應關系為一一對應關系,所述張量為所述待訓練模型包括的多層網絡的輸入和輸出,m和n為大于1的整數;
針對所述n個并行進程中的目標并行進程,基于與所述目標并行進程具有對應關系的目標張量集合更新所述待訓練模型的參數;
根據更新后的參數訓練所述待訓練模型。
另一方面,本申請實施例提供一種模型訓練裝置,所述裝置包括:確定單元、更新單元和訓練單元;
所述確定單元,用于待訓練模型包括的m個張量與n個并行進程之間的對應關系;其中,所述m個張量包括在n個張量集合中,每個張量集合包括所述多個張量中的部分張量,所述n個張量集合與所述n個并行進程的對應關系為一一對應關系,所述張量為所述待訓練模型包括的多層網絡的輸入和輸出,m和n為大于1的整數;
所述更新單元,用于針對所述n個并行進程中的目標并行進程,基于與所述目標并行進程具有對應關系的目標張量集合更新所述待訓練模型的參數;
所述訓練單元,用于根據更新后的參數訓練所述待訓練模型。
另一方面,本申請實施例提供一種計算機設備,所述設備包括處理器以及存儲器:
所述存儲器用于存儲程序代碼,并將所述程序代碼傳輸給所述處理器;
所述處理器用于根據所述程序代碼中的指令執行上述方面所述的方法。
另一方面,本申請實施例提供了一種計算機可讀存儲介質,所述計算機可讀存儲介質用于存儲計算機程序,所述計算機程序用于執行上述方面所述的方法。
另一方面,本申請實施例提供了一種計算機程序產品或計算機程序,該計算機程序產品或計算機程序包括計算機指令,該計算機指令存儲在計算機可讀存儲介質中。計算機設備的處理器從計算機可讀存儲介質讀取該計算機指令,處理器執行該計算機指令,使得該計算機設備執行上述方面所述的方法。
該專利技術資料僅供研究查看技術是否侵權等信息,商用須獲得專利權人授權。該專利全部權利屬于騰訊科技(深圳)有限公司,未經騰訊科技(深圳)有限公司許可,擅自商用是侵權行為。如果您想購買此專利、獲得商業授權和技術合作,請聯系【客服】
本文鏈接:http://www.szxzyx.cn/pat/books/202210427939.0/2.html,轉載請聲明來源鉆瓜專利網。
- 上一篇:一種電動汽車雙主體協同優化方法及裝置
- 下一篇:一種藥材粉碎研磨一體裝置





