[發明專利]一種基于K鄰近結點算法和對比學習的文本分類方法在審
| 申請號: | 202110960433.1 | 申請日: | 2021-08-20 |
| 公開(公告)號: | CN113673242A | 公開(公告)日: | 2021-11-19 |
| 發明(設計)人: | 邱錫鵬;宋德敏;李林陽;傅家慶;楊非 | 申請(專利權)人: | 之江實驗室;復旦大學 |
| 主分類號: | G06F40/289 | 分類號: | G06F40/289;G06F40/211;G06K9/62;G06N3/08 |
| 代理公司: | 杭州浙科專利事務所(普通合伙) 33213 | 代理人: | 楊小凡 |
| 地址: | 310023 浙江省杭州市余*** | 國省代碼: | 浙江;33 |
| 權利要求書: | 查看更多 | 說明書: | 查看更多 |
| 摘要: | |||
| 搜索關鍵詞: | 一種 基于 鄰近 結點 算法 對比 學習 文本 分類 方法 | ||
本發明公開了一種基于K鄰近結點算法和對比學習的文本分類方法,該方法在訓練階段使用對比學習拉進類內距離,拉遠類間距離,并且結合交叉熵損失,輔助對比學習進行聯合訓練,在推理過程中,通過聯合訓練好的模型,結合最鄰近結點算法,進行聯合預測,計算待推斷文本的分類;本發明不僅能夠在文本分類的準確率上取得比目前業內使用的文本分類方式更高的結果,而且在模型的魯棒性上也取得了極大的提升。
技術領域
本發明涉及深度學習和自然語言處理,尤其是涉及一種基于K鄰近結點算法和對比學習的文本分類方法。
背景技術
文本分類任務是自然語言處理中的一類基礎任務,目前主流的文本分類方法是在大規模預訓練模型(如BERT)的基礎上,使用一個線性分類器進行分類。但是線性分類器往往不具備很好的魯棒性,容易被TextFooler或BertAttack這類對抗攻擊的方式所愚弄。
發明內容
為解決現有技術的不足,實現提高魯棒性的同時,提升模型分類準確率的目的,本發明采用如下的技術方案:
一種基于K鄰近結點算法和對比學習的文本分類方法,包括如下步驟:
S1,訓練過程中,通過構建句子向量表示k的正負樣本,進行對比學習,拉近類內間距,拉遠類間間距,對比學習的損失函數如下:
其中,M表示正樣本的數量,N表示負樣本的數量,q表示預訓練編碼器encoder_q輸出的句子的向量表示,k表示預訓練編碼器encoder_k輸出的句子向量表示,encoder_q與encoder_k相同,kj表示第j個正樣本k+,ki表示遍歷負樣本k-和kj的集合,exp(·)表示指數函數,τ為超參數;
結合交叉熵損失函數,進行聯合訓練,聯合損失函數如下:
L=λLec+(1-λ)Lsc
其中,λ表示調節交叉熵損失函數Lec和所述對比學習的損失函數Lsc之間的權重參數,yc表示q的類別,C表示文本分類的分類數,F(·)表示線性分類器;
反向傳播損失函數,更新encoder_q和線性分類器的參數;
聯合損失函數為交叉熵損失函數和有監督對比學習損失函數的加權和,通過對比學習的損失函數Lsc來輔助交叉熵損失函數訓練模型,使用對比學習訓練模型,使得模型在訓練過程中,能夠自動對樣本的embedding表示進行聚類,從而能夠達到更好的分類效果;
S2,通過訓練好的encoder_q和線性分類器,對文本進行分類。
進一步地,所述S2中,通過訓練好的encoder_q獲得待預測文本的句子向量表示q,使用聯合預測函數預測文本分類,聯合預測函數如下:
其中,S表示最終分類的概率值,表示超參數,Softmax(·)表示激活函數,F(q)表示訓練好的線性分類器,KNN(q)表示從隊列Q中選取在樣本空間中離q最近的K個訓練樣本,根據訓練樣本的分類標簽,用投票的方式給出KNN模型的概率值,通過概率值得到分類結果,在推斷樣本類別時,使用KNN和線性分類器聯合預測待預測樣本的分類,通過K鄰近結點算法,顯著提高了模型的魯棒性。
該專利技術資料僅供研究查看技術是否侵權等信息,商用須獲得專利權人授權。該專利全部權利屬于之江實驗室;復旦大學,未經之江實驗室;復旦大學許可,擅自商用是侵權行為。如果您想購買此專利、獲得商業授權和技術合作,請聯系【客服】
本文鏈接:http://www.szxzyx.cn/pat/books/202110960433.1/2.html,轉載請聲明來源鉆瓜專利網。
- 上一篇:驅動芯片、顯示模組、顯示面板與顯示面板的測試方法
- 下一篇:一種軟基處理設備





