軟硬件環境
-
ubuntu 18.04 64bit -
anaconda with 3.7 -
nvidia gtx 1070Ti -
cuda 10.1 -
pytorch 1.5
問題
在使用pytorch
深度學習框架訓練出來的模型文件,在另外的工程中使用,常常會碰到以下的錯誤html
File "/home/xugaoxiang/anaconda3/envs/torchTest/lib/python3.7/site-packages/torch/serialization.py", line 593, in load return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args) File "/home/xugaoxiang/anaconda3/envs/torchTest/lib/python3.7/site-packages/torch/serialization.py", line 773, in _legacy_load result = unpickler.load()ModuleNotFoundError: No module named 'models'
解決方法
其實這個問題,在pytorch
的官方文檔中就有提到,以下python
上面這種方法呢是推薦的作法,在執行torch.save
和torch.load
時,操做的都是模型的參數,這樣移植起來很是的方便
git
而下面的方法則是針對整個模型,在訓練模型的時候,會將本地的class
和目錄結構都寫入到模型中。不少開源項目在模型訓練完成後也是採用這樣的方法來保存,所以,當在你本身的項目中去使用這樣的模型時,每每就會遇到上面出現的問題,解決的方法就是在你的項目中保持原有項目的必要結構,如相應的class
和模塊github
參考資料
-
https://pytorch.org/docs/stable/notes/serialization.html -
https://discuss.pytorch.org/t/modulenotfounderror-no-module-named-network/71721/3 -
https://zhuanlan.zhihu.com/p/38056115 -
https://github.com/ultralytics/yolov5/issues/22
本文分享自微信公衆號 - 迷途小書童的Note(Dev_Club)。
若有侵權,請聯繫 support@oschina.cn 刪除。
本文參與「OSC源創計劃」,歡迎正在閱讀的你也加入,一塊兒分享。web