從人工智能鑑黃模型,嘗試TensorRT優化

隨着互聯網的快速發展,愈來愈多的圖片和視頻出如今網絡,特別是UCG產品,激發人們上傳圖片和視頻的熱情,好比微信天天上傳的圖片就高達10億多張。每一個人均可以上傳,這就帶來監管問題,若是沒有內容審覈,色情圖片和視頻就會氾濫。前不久,一貫以開放著稱的tumblr,就迫於壓力,開始限制人們分享色情圖片。更別提國內,內容審覈是UCG繞不過去的坎。還記得前幾年出現的職業鑑黃師這一職業麼?傳說百萬年薪,天天看黃片看得想吐,但最近又不多有人說起這一職業,這個應監管而生的職業,因人工智能的出現又快速消亡。(固然也不是徹底消亡,畢竟判斷是否色情圖片是一個主觀的事情,有些藝術和色情之間的邊界比較模糊,須要人工加以判斷)node

以前寫過一篇文章利用人工智能檢測色情圖片,也曾經嘗試過在瀏覽器中加入色情圖片過濾功能,但實驗下來,推理速度太慢(當時使用的Google Nexus 4作的測試,檢測一張圖片須要幾秒鐘),無法作實時過濾。最近在研究nvidia的Jetson Nano以及推理加速框架TensorRT,所以想嘗試一下,看可否應用一些加速方法,加速推理。python

雖然個人最終目標是應用到Jetson Nano,可是TensorRT其實適用於幾乎全部的Nvidia顯卡,爲了方便起見,我仍是先在PC端進行嘗試。沒有Nvidia顯卡?也沒有關係,能夠看看我前面發佈的兩篇文章:git

  1. 谷歌GPU雲計算平臺,免費又好用
  2. Google Colab上安裝TensorRT

open_nsfw

本文采用的深度學習模型是雅虎開源的深度學習色情圖片檢測模型open_nsfw,這裏的NSFW表明Not Suitable for Work,該項目基於caffe框架。因爲我主要研究的是Tensorflow,因此在網上找到該模型的Tensorflow實現版本,fork了一份,並添加了TensorRT框架的處理腳本,你可使用以下命令得到相關代碼:github

git clone https://github.com/mogoweb/tensorflow-open_nsfw.git
複製代碼

model.py 中,咱們能夠看到open_nsfw的模型定義,data/open_nsfw-weights.npy 是採用工具從yahoo open_nsfw的cafee權重轉換獲得的Tensorflow權重,這樣咱們無需訓練模型,直接用於推理過程。classify_nsfw.py 腳本可用於單張圖片的推理:web

python classify_nsfw.py -m data/open_nsfw-weights.npy test.jpg
複製代碼

注意:腳本提供了兩種解碼圖片文件的方式,一種是採用PIL.image、skimage進行圖片處理,也就是所謂的yahoo_image_loader,一種是採用tensorflow中的圖片處理函數進行處理。由於原始的open_nsfw模型是採用PIL.image、skimage進行預處理而訓練的,而不一樣的庫解碼出來的結果存在細微的差別,會影響最終結果,通常優選選擇yahoo_image_loader。固然,若是你打算本身訓練模型,那選擇哪一種圖片處理庫均可以。瀏覽器

tools 目錄下有一些腳本,能夠將模型導出爲frozen graphsaved model以及tflite等格式,這樣咱們能夠方便的在服務器端部署,還能夠應用到手機端。bash

opt是我編寫的採用TensorRT框架加速的代碼,在下面我將詳細說明。服務器

導出爲TensorRT模型

目前TensorRT做爲Tensorflow的一部分獲得Google官方支持,其包位於tensorflow.contrib.tensorrt,在代碼中加入:微信

import tensorflow.contrib.tensorrt as trt
複製代碼

就可使用TensorRT,由於有Google的支持,導出到TensorRT也就至關簡單:網絡

trt_graph = trt.create_inference_graph(
                input_graph_def=frozen_graph_def,
                outputs=[output_node_name],
                max_batch_size=1,
                max_workspace_size_bytes=1 << 25,
                precision_mode='FP16',
                minimum_segment_size=50
        )

        graph_io.write_graph(trt_graph, export_base_path, 'trt_' + graph_name, as_text=False)
複製代碼

其中:

  • input_graph_def 爲須要導出的Tensorflow模型圖定義
  • outputs 爲輸出節點名稱
  • max_batch_size 爲最大的batch size限制,由於GPU存在顯存限制,須要根據GPU memory大小決定,通常狀況能夠給8或者16
  • precision_mode 爲模型精度,有FP3二、FP16和INT8可選,精度越高,推理速度越慢,也要依GPU而定。

graph_io.write_graph 將圖寫入到文件,在後續的代碼中能夠加載之。

完整的代碼請參考 opt/export_trt.py 文件。

測試數據

由於一些政策法規的限制,並無公開數據集可提供下載,不過在github上有一些開源項目,提供腳本,從網絡上進行下載。我使用的是 github.com/alexkimxyz/… 這個開源項目中的腳本。這個項目提供drawings、hentai、neutral、porn、sexy四種類別圖片,能夠劃分爲訓練集和測試集,並檢查圖片是否有效(由於從網絡爬取,有些連接不必定能訪問到)。

注意這個圖片下載量很是大,須要注意別把硬盤撐滿。雖然這個數據量夠大(幾萬張),能夠自行進行模型訓練,但和yahoo訓練open_nsfw模型的圖片量相比,仍是小巫見大巫,聽說yahoo訓練這個模型用了幾百萬張的圖片。

推理速度對比

在opt目錄下,我針對兩種模型的加載和推理添加了兩個腳本,分別是 benchmark_classify_nsfw.pybenchmark_classify_trt.py,細心的同窗可能會發現,這兩個腳本幾乎如出一轍,是的,除了 benchmark_classify_trt.py 多了一行代碼:

import tensorflow.contrib.tensorrt as trt
複製代碼

加入這行import語句,告訴tensorflow使用TensorRT框架,不然的話,會出現以下錯誤:

tensorflow.python.framework.errors_impl.NotFoundError: Op type not registered 'TRTEngineOp' in binary running on alex-550-279cn. Make sure the Op and Kernel are registered in the binary running in this process. Note that if you are loading a saved graph which used ops from tf.contrib, accessing (e.g.) `tf.contrib.resampler` should be done before importing the graph, as contrib ops are lazily registered when the module is first accessed.
複製代碼

取2000張測試圖片進行測試,在個人GTX 960上,推理速度以下:

未優化模型: 53 s
使用TensorRT優化模型: 54 s
複製代碼

若是你下載更大的數據集,能夠多測試一些圖片,看看優化效果。

在Google Colab上,我放了一份Jupter Notebook,有興趣的同窗能夠藉助Google Colab嘗試一下,文件地址:colab.research.google.com/drive/1vH-G… ,固然你也能夠訪問我github上完整的腳本及Notebook:

github.com/mogoweb/ten…

點擊閱讀原文,能夠跳轉到該項目。

題外話:

微信公衆號流量主的門檻已經大大下降,我在公衆號文章底部開通了廣告,但願沒有影響你們的閱讀體驗。我一直很好奇,這種廣告會有人點擊麼,過一段也許我會獲得答案。

你還能夠讀:

  1. 利用人工智能檢測色情圖片
  2. 谷歌GPU雲計算平臺,免費又好用
  3. Google Colab上安裝TensorRT

image
相關文章
相關標籤/搜索