[發明專利]一種神經網絡模型訓練方法、裝置、設備及存儲介質在審
| 申請號: | 202110181147.5 | 申請日: | 2021-02-09 |
| 公開(公告)號: | CN114943331A | 公開(公告)日: | 2022-08-26 |
| 發明(設計)人: | 熊凱 | 申請(專利權)人: | 廣州視源電子科技股份有限公司;廣州視源人工智能創新研究院有限公司 |
| 主分類號: | G06N3/08 | 分類號: | G06N3/08;G06K9/62 |
| 代理公司: | 北京品源專利代理有限公司 11332 | 代理人: | 孟金喆 |
| 地址: | 510530 廣*** | 國省代碼: | 廣東;44 |
| 權利要求書: | 查看更多 | 說明書: | 查看更多 |
| 摘要: | |||
| 搜索關鍵詞: | 一種 神經網絡 模型 訓練 方法 裝置 設備 存儲 介質 | ||
本發明公開了一種神經網絡模型訓練方法、裝置、設備及存儲介質。將第一訓練樣本和第二訓練樣本輸入待訓練的神經網絡模型中進行處理,得到用于預測第一訓練樣本屬于各類別的概率向量,以及由第二訓練樣本的特征向量組成的特征矩陣,基于概率向量計算一批第一訓練樣本的分類損失,以及基于特征矩陣的低秩約束計算一批第二訓練樣本的正則化損失,通過在分類損失中加入基于低秩約束的正則化損失,能夠將人對這個神經網絡模型的先驗知識融入到模型的學習當中,引導模型學習出更加緊湊和更具判別性的特征,同時降低了神經網絡模型的復雜度,防止神經網絡模型過擬合,從而提升模型的泛化性能,即提高神經網絡模型在實際應用中的預測準確度。
技術領域
本發明實施例涉及機器學習技術,尤其涉及一種神經網絡模型訓練方法、裝置、設備及存儲介質。
背景技術
隨著人工智能的發展,深度神經網絡以其強大的擬合能力在眾多領域取得了驚人的成績。為了得到更好的神經網絡泛化能力,現有的網絡結構變得越來越復雜,網絡的參數量也呈現爆炸式增長的趨勢,復雜的神經網絡將導致模型出現過擬合現象。即模型在訓練集上表現的很好,但是在測試集或實際應用過程中上表現一般,導致模型在實際應用中預測準確度較低。
發明內容
本發明提供一種神經網絡模型訓練方法、裝置、設備及存儲介質,能夠有效防止神經網絡模型出現過擬合的現象,提高神經網絡模型在實際應用中的預測準確度。
第一方面,本發明實施例提供了一種神經網絡模型訓練方法,包括:
獲取用于訓練神經網絡模型的一批第一訓練樣本和一批第二訓練樣本;
將所述第一訓練樣本輸入待訓練的神經網絡模型中進行處理,得到用于預測所述第一訓練樣本屬于各類別的概率向量;
將所述第二訓練樣本輸入待訓練的神經網絡模型中進行處理,得到由所述第二訓練樣本的特征向量組成的特征矩陣,所述第二訓練樣本的特征向量用于表征所述第二訓練樣本的屬性;
基于所述概率向量計算一批所述第一訓練樣本的分類損失;
基于所述特征矩陣的低秩約束計算一批所述第二訓練樣本的正則化損失;
基于所述分類損失和所述正則化損失更新所述神經網絡模型的參數。
第二方面,本發明實施例還提供了一種神經網絡模型訓練裝置,包括:
訓練樣本獲取模塊,用于獲取用于訓練神經網絡模型的一批第一訓練樣本和一批第二訓練樣本;
概率向量獲取模塊,用于將所述第一訓練樣本輸入待訓練的神經網絡模型中進行處理,得到用于預測所述第一訓練樣本屬于各類別的概率向量;
特征矩陣獲取模塊,用于將所述第二訓練樣本輸入待訓練的神經網絡模型中進行處理,得到由所述第二訓練樣本的特征向量組成的特征矩陣,所述第二訓練樣本的特征向量用于表征所述第二訓練樣本的屬性;
分類損失計算模塊,用于基于所述概率向量計算一批所述第一訓練樣本的分類損失;
正則化損失計算模塊,用于基于所述特征矩陣的低秩約束計算一批所述第二訓練樣本的正則化損失;
參數更新模塊,用于基于所述分類損失和所述正則化損失更新所述神經網絡模型的參數。
第三方面,本發明實施例還提供了一種計算機設備,包括:
一個或多個處理器;
存儲裝置,用于存儲一個或多個程序;
當所述一個或多個程序被所述一個或多個處理器執行,使得所述一個或多個處理器實現如本發明第一方面提供的神經網絡模型訓練方法。
該專利技術資料僅供研究查看技術是否侵權等信息,商用須獲得專利權人授權。該專利全部權利屬于廣州視源電子科技股份有限公司;廣州視源人工智能創新研究院有限公司,未經廣州視源電子科技股份有限公司;廣州視源人工智能創新研究院有限公司許可,擅自商用是侵權行為。如果您想購買此專利、獲得商業授權和技術合作,請聯系【客服】
本文鏈接:http://www.szxzyx.cn/pat/books/202110181147.5/2.html,轉載請聲明來源鉆瓜專利網。





