[發明專利]模型訓練方法、裝置、設備、存儲介質和程序產品在審
| 申請號: | 202011563834.5 | 申請日: | 2020-12-25 |
| 公開(公告)號: | CN112580732A | 公開(公告)日: | 2021-03-30 |
| 發明(設計)人: | 王龍飛 | 申請(專利權)人: | 北京百度網訊科技有限公司 |
| 主分類號: | G06K9/62 | 分類號: | G06K9/62;G06N3/04;G06N3/08 |
| 代理公司: | 北京品源專利代理有限公司 11332 | 代理人: | 孟金喆 |
| 地址: | 100085 北京市*** | 國省代碼: | 北京;11 |
| 權利要求書: | 查看更多 | 說明書: | 查看更多 |
| 摘要: | |||
| 搜索關鍵詞: | 模型 訓練 方法 裝置 設備 存儲 介質 程序 產品 | ||
1.一種模型訓練方法,包括:
在樣本集合中獲取本輪樣本集輸入至待訓練模型中,并根據待訓練模型的輸出結果,計算本輪訓練損失;
根據所述本輪訓練損失,確定本輪擾動項,并將所述本輪擾動項加入至所述本輪樣本集中,得到本輪對抗樣本集;
使用所述本輪樣本集和所述本輪對抗樣本集共同對待訓練模型進行訓練,得到本輪訓練模型;
將所述本輪訓練模型確定為新的待訓練模型后,返回在樣本集合中獲取本輪樣本集輸入至待訓練模型中的操作,響應于滿足結束訓練條件,獲取目標訓練模型。
2.根據權利要求1所述的方法,其中,所述樣本集合中的樣本包括輸入特征,以及與所述輸入特征對應的標注數據;
根據所述本輪訓練損失,確定本輪擾動項,包括:
根據所述本輪訓練損失,對所述輸入特征求梯度,得到梯度值;
對所述梯度值進行歸一化處理,得到所述本輪擾動項。
3.根據權利要求1所述的方法,其中,使用所述本輪樣本集和所述本輪對抗樣本集共同對待訓練模型進行訓練,得到本輪訓練模型,包括:
將所述本輪樣本集和所述本輪對抗樣本集共同輸入至所述待訓練模型中;
根據所述待訓練模型的輸出結果,計算總損失,所述總損失包括與所述本輪樣本集對應的原始訓練損失,以及與所述本輪對抗樣本集對應的對抗訓練損失;
基于所述總損失,通過梯度下降法調節所述待訓練模型的參數,以得到所述本輪訓練模型。
4.根據權利要求2所述的方法,其中,所述本輪擾動項的計算方式如下:
其中,η表示擾動項,x表示樣本的輸入特征,y表示與所述輸入特征對應的標注數據,θ表示模型參數,∈表示添加擾動的最大強度,f(x;θ)表示待訓練模型針對輸入特征x的輸出結果,L(f(x;θ),y)表示所述待訓練模型的本輪訓練損失,g表示基于所述本輪訓練損失,對所述輸入特征x的梯度值,||g||2表示梯度值的二范數。
5.根據權利要求3所述的方法,其中,所述總損失的計算方式如下:
其中,Losstotal表示總損失,表示與所述本輪對抗樣本集對應的對抗訓練損失,L(f(x;θ),y)表示與所述本輪樣本集對應的原始訓練損失,α表示所述本輪對抗樣本集貢獻占比參數。
6.根據權利要求2所述的方法,其中,所述樣本中的輸入特征為:與用戶行為數據匹配的多項詞向量,所述樣本中的標注數據為用戶畫像。
7.根據權利要求6所述的方法,在獲取目標訓練模型之后,還包括:
獲取待識別用戶的目標用戶行為數據,并提取與所述目標用戶的行為數據匹配的多項目標詞向量;
將各所述目標詞向量輸入至所述目標訓練模型中,并獲取所述目標訓練模型輸出的,所述待識別用戶的目標用戶畫像。
8.一種模型訓練裝置,包括:
損失計算模塊,用于在樣本集合中獲取本輪樣本集輸入至待訓練模型中,并根據待訓練模型的輸出結果,計算本輪訓練損失;
對抗樣本獲取模塊,用于根據所述本輪訓練損失,確定本輪擾動項,并將所述本輪擾動項加入至所述本輪樣本集中,得到本輪對抗樣本集;
模型訓練模塊,用于使用所述本輪樣本集和所述本輪對抗樣本集共同對待訓練模型進行訓練,得到本輪訓練模型;
目標模型獲取模塊,用于將所述本輪訓練模型確定為新的待訓練模型后,返回在樣本集合中獲取本輪樣本集輸入至待訓練模型中的操作,響應于滿足結束訓練條件,獲取目標訓練模型。
9.根據權利要求8所述的裝置,其中,所述樣本集合中的樣本包括輸入特征,以及與所述輸入特征對應的標注數據;
所述對抗樣本獲取模塊,包括:
梯度值計算單元,用于根據所述本輪訓練損失,對所述輸入特征求梯度,得到梯度值;
擾動項計算單元,用于對所述梯度值進行歸一化處理,得到所述本輪擾動項。
該專利技術資料僅供研究查看技術是否侵權等信息,商用須獲得專利權人授權。該專利全部權利屬于北京百度網訊科技有限公司,未經北京百度網訊科技有限公司許可,擅自商用是侵權行為。如果您想購買此專利、獲得商業授權和技術合作,請聯系【客服】
本文鏈接:http://www.szxzyx.cn/pat/books/202011563834.5/1.html,轉載請聲明來源鉆瓜專利網。





