[發明專利]一種神經網絡模型的訓練方法、裝置、存儲介質及設備在審
申請號: | 202010798134.8 | 申請日: | 2020-08-11 |
公開(公告)號: | CN111898754A | 公開(公告)日: | 2020-11-06 |
發明(設計)人: | 李鎮;張敏清 | 申請(專利權)人: | 香港中文大學(深圳) |
主分類號: | G06N3/08 | 分類號: | G06N3/08 |
代理公司: | 深圳尚業知識產權代理事務所(普通合伙) 44503 | 代理人: | 王利彬 |
地址: | 518000 廣東*** | 國省代碼: | 廣東;44 |
權利要求書: | 查看更多 | 說明書: | 查看更多 |
摘要: | |||
搜索關鍵詞: | 一種 神經網絡 模型 訓練 方法 裝置 存儲 介質 設備 | ||
本發明適用于模型訓練技術領域,提供了一種神經網絡模型的訓練方法、裝置及系統,所述方法包括:獲取原始數據集,并根據所述原始數據集訓練原始神經網絡模型;從所述原始神經網絡模型中識別出噪聲標簽;對所述噪聲標簽進行修改,并根據修改后的數據集訓練新神經網絡模型。本發明通過先以原始數據集訓練出原始神經網絡模型,并在原始神經網絡模型中識別出噪聲標簽,從而確定原始數據集中的錯誤標簽,在對錯誤標簽糾正之后,最終根據修改后的數據集訓練新神經網絡模型,由于直接從網絡模型中確定出錯誤標簽并對其進行糾正,準確性高,同樣具備了很好的可解釋性,使得最終訓練得到的新神經網絡模型具有較好的抗干擾效果。
技術領域
本發明屬于模型訓練技術領域,尤其涉及一種神經網絡模型的訓練方法、裝置、存儲介質及設備。
背景技術
深度學習技術已經在圖像處理領域中取得了巨大的成功,它們的成功都離不開神經網絡模型的訓練。在訓練神經網絡模型的過程中,數據及對應的標簽(金標準)是除網絡模型之外的最關鍵因素。
如果數據的標簽中存在一些噪聲,即錯誤的標簽,則會對網絡的訓練造成極大的負面影響,進而導致神經網絡模型的表現變差,即模型在標簽被噪聲污染時易受干擾。因此,在標簽中存在噪聲的情況下如何保證網絡模型的性能,使得模型的訓練具備一定抗干擾能力,是一個非常值得研究的技術。
現有技術當中,目前大多通過自監督的方式找出有噪聲的標簽,并在訓練模型計算的損失函數的時候,降低它們的權重,來達到抗干擾的效果,但這種方式存在不精確、缺乏解釋性等缺點,最終對模型的抗干擾效果提升有限。
發明內容
本發明實施例提供一種神經網絡模型的訓練方法、裝置、存儲介質及設備,旨在解決現有對噪聲標簽的處理方式不精確、導致對模型的抗干擾效果提升有限的技術問題。
本發明實施例是這樣實現的,一種神經網絡模型的訓練方法,所述方法包括:
獲取原始數據集,并根據所述原始數據集訓練原始神經網絡模型;
從所述原始神經網絡模型中識別出噪聲標簽;
對所述噪聲標簽進行修改,并根據修改后的數據集訓練新神經網絡模型。
進一步地,所述從所述原始神經網絡模型中識別出噪聲標簽的步驟包括。
利用置信度學習技術從所述原始神經網絡模型中識別出噪聲標簽。
進一步地,所述利用置信度學習技術從所述原始神經網絡模型中識別出噪聲標簽的步驟包括:
計算所述原始神經網絡模型對于每個類別的預測概率;
基于所述預測概率參數計算出噪聲標簽與真實標簽的混淆矩陣;
將所述混淆矩陣正則化,得到噪聲標簽與真實標簽的聯合分布;
基于所述混淆矩陣和/或所述聯合分布,計算出所述噪聲標簽。
進一步地,基于所述混淆矩陣和所述聯合分布,計算出所述噪聲標簽的步驟包括:
基于所述混淆矩陣和所述聯合分布的交集或并集,計算出所述噪聲標簽。
進一步地,所述預測概率包括平均預測概率和預測概率的中位數。
進一步地,根據修改后的數據集訓練新神經網絡模型的步驟包括:
根據所述修改后的數據集重新進行神經網絡模型訓練,以訓練得到所述新神經網絡模型;或者
根據所述修改后的數據集對所述原始神經網絡模型進行調整,以調整得到所述新神經網絡模型。
進一步地,所述對所述噪聲標簽進行修改的步驟包括:
該專利技術資料僅供研究查看技術是否侵權等信息,商用須獲得專利權人授權。該專利全部權利屬于香港中文大學(深圳),未經香港中文大學(深圳)許可,擅自商用是侵權行為。如果您想購買此專利、獲得商業授權和技術合作,請聯系【客服】
本文鏈接:http://www.szxzyx.cn/pat/books/202010798134.8/2.html,轉載請聲明來源鉆瓜專利網。