[發明專利]基于PyTorch至ONNX的神經網絡格式轉換方法在審
| 申請號: | 202210715812.9 | 申請日: | 2022-06-22 |
| 公開(公告)號: | CN115310584A | 公開(公告)日: | 2022-11-08 |
| 發明(設計)人: | 田聰;王甜;劉濤;賀凱;張海瑞;劉昊庭;趙耕;楊佳;張梓輝;虞小龍;張鵬 | 申請(專利權)人: | 西安電子科技大學 |
| 主分類號: | G06N3/04 | 分類號: | G06N3/04;G06N3/08 |
| 代理公司: | 西安嘉思特知識產權代理事務所(普通合伙) 61230 | 代理人: | 王海棟 |
| 地址: | 710071*** | 國省代碼: | 陜西;61 |
| 權利要求書: | 查看更多 | 說明書: | 查看更多 |
| 摘要: | |||
| 搜索關鍵詞: | 基于 pytorch onnx 神經網絡 格式 轉換 方法 | ||
1.一種基于PyTorch至ONNX的神經網絡格式轉換方法,其特征在于,包括:
獲取PyTorch神經網絡模型,解析所述PyTorch神經網絡模型,獲取所述PyTorch神經網絡模型的各操作層的參數信息和前向傳播過程;
將所述PyTorch神經網絡模型的各操作層的參數信息和前向傳播過程轉換至ONNX模型;
序列化保存所述ONNX模型至ONNX文件,完成轉換過程。
2.根據權利要求1所述的基于PyTorch至ONNX的神經網絡格式轉換方法,其特征在于,所述PyTorch神經網絡模型的各操作層的參數信息和前向傳播過程的獲取方法包括:
獲取所述PyTorch神經網絡模型的訓練圖;
獲取所述訓練圖中各操作層的定義圖信息;
基于所述訓練圖中各操作層的所述定義圖信息,獲取所述訓練圖中各操作層的參數圖信息和輸入圖信息;
解析所述參數圖信息,獲取各操作層的參數信息;
解析所述定義圖信息和所述輸入圖信息,獲取所述PyTorch神經網絡模型的前向傳播過程。
3.根據權利要求2所述的基于PyTorch至ONNX的神經網絡格式轉換方法,其特征在于,所述獲取所述訓練圖中各操作層的定義圖信息的具體過程包括:
遍歷所述訓練圖中所有的圖信息;
判斷所述圖信息的類型是否以aten開頭;
如果是,則判斷所述圖信息為所述定義圖信息。
4.根據權利要求2所述的基于PyTorch至ONNX的神經網絡格式轉換方法,其特征在于,所述基于所述訓練圖中各操作層的所述定義圖信息,獲取所述訓練圖中各操作層的參數圖信息和輸入圖信息的具體過程包括:
基于所述訓練圖中各操作層的所述定義圖信息,通過inputs()函數讀取各操作層的所述參數圖信息和所述輸入圖信息,構建列表。
5.根據權利要求2所述的基于PyTorch至ONNX的神經網絡格式轉換方法,其特征在于,所述解析所述參數圖信息,獲取各操作層的參數信息的具體過程包括:
對每一類型的操作層設置對應的函數,提取每一類型的操作層的參數信息;
分析當前類型的操作層所包括的所述參數信息數量以及名稱,使用get_attr()函數提取并記錄。
6.根據權利要求2所述的基于PyTorch至ONNX的神經網絡格式轉換方法,其特征在于,所述解析所述定義圖信息和所述輸入圖信息,獲取所述PyTorch神經網絡模型的前向傳播過程的具體過程包括:
遍歷所述訓練圖中所有操作層的所述輸入圖信息,獲取所述輸入圖信息中保存的前驅操作層的定義圖信息;
將所述訓練圖中所有操作層的名稱以及節點對應的前驅操作層的定義圖信息以鍵值對的形式保存在同一字典中,構建所述PyTorch神經網絡模型的前向傳播字典。
7.根據權利要求1所述的基于PyTorch至ONNX的神經網絡格式轉換方法,其特征在于,所述將所述PyTorch神經網絡模型的各操作層的參數信息和前向傳播過程轉換至ONNX模型的具體過程包括:
獲取ONNX格式的輸入張量;
獲取ONNX格式的操作層列表;
獲取ONNX格式的輸出層形狀;
基于ONNX格式的所述輸出層形狀,獲取ONNX格式的輸出張量;
基于ONNX格式的所述輸入張量、ONNX格式的所述操作層列表和ONNX格式的所述輸出張量,生成ONNX格式的模型圖;
基于ONNX版本信息與ONNX格式的所述模型圖,生成ONNX格式的模型。
該專利技術資料僅供研究查看技術是否侵權等信息,商用須獲得專利權人授權。該專利全部權利屬于西安電子科技大學,未經西安電子科技大學許可,擅自商用是侵權行為。如果您想購買此專利、獲得商業授權和技術合作,請聯系【客服】
本文鏈接:http://www.szxzyx.cn/pat/books/202210715812.9/1.html,轉載請聲明來源鉆瓜專利網。





