C# WPF調用python-tensorflow2深度學習模型
python在研究深度學習人工智能領域十分強大,但在工業項目開發中仍經常使用C#和C++來作軟件,C++有Caffe深度學習框架,但C#尚且沒有成熟的深度學習框架(有個tensroflow.net尚在開發中,有興趣能夠去研究研究)。如今實驗室項目開發又要用C#,通過實踐最終決定在C#端利用OpencvSharp4的DNN模塊加載python端tensorflow2訓練的模型進行預測,其速度還能夠。
一 環境介紹
python:Python3.7 tensorflow2.1前端
c#: vs2017 .net framework 4.6.1python
二 tensorflow模型的訓練和生成
1 加載數據訓練模型
1.1 數據集採用貓狗二分類數據。
數據集網盤連接:連接:https://pan.baidu.com/s/15LR7-tgvglzwW9n4eFsFgg
提取碼:iz64
1.2 建立圖片數據輸入管道
代碼實現:
express
import tensorflow as tf import numpy as np import matplotlib.pyplot as plt import glob gpu = tf.config.experimental.list_physical_devices(device_type='GPU') tf.config.experimental.set_virtual_device_configuration( gpu[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=5120)]) image_path = glob.glob('./datasets/dc/train/*.jpg') image_label = [int(path.split('\\')[1].split('.')[0]=='cat') for path in image_path] def get_image_data(path, label): image = tf.io.read_file(path) image = tf.image.decode_jpeg(image, channels=3) image = tf.image.resize(image, (224, 224)) image = tf.cast(image, tf.float32)/255 label = tf.reshape(label, [1]) return image, label dataset = tf.data.Dataset.from_tensor_slices((image_path, image_label)) dataset = dataset.map(get_image_data) train_count = int(len(image_path)*0.8) test_count = len(image_path)-train_count train_dataset = dataset.skip(test_count) test_dataset = dataset.take(test_count) train_dataset = train_dataset.shuffle(len(image_path)).repeat().batch(BATCH_SIZE) test_dataset = test_dataset.batch(BATCH_SIZE)
1.3 搭建並訓練模型最後保存模型及參數,保存的格式的.h5文件,最終的準確度基本在99%以上。
代碼實現:
c#
MobileNet = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3)) model = tf.keras.Sequential() model.add(MobileNet) model.add(tf.keras.layers.GlobalAveragePooling2D()) model.add(tf.keras.layers.Dense(256, activation='relu')) model.add(tf.keras.layers.Dense(1, activation='sigmoid')) model.compile(optimizer='adam', loss=tf.keras.losses.binary_crossentropy, metrics=['acc']) model.fit(train_dataset, epochs=10, steps_per_epoch=train_count//BATCH_SIZE, validation_data=test_dataset, validation_steps=test_count//BATCH_SIZE) model.save('./model_h5/mobilenet.h5')
2 h5文件轉pb
Opencv的DNN模塊接收tensorflow模型文件爲pb文件,先將h5文件轉換成pb文件,在tensorflow2.0端完成文件類型的轉換。
轉換代碼:
後端
#參數1爲h5文件的路徑,參數2爲要將pb文件保存到那個文件夾的路徑,最後一個參數爲pb文件的名稱 def convert_h5to_pb(h5_path, pb_path, pb_name): model = tf.keras.models.load_model(h5_path, compile=False) model.summary() full_model = tf.function(lambda Input: model(Input)) full_model = full_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype)) frozen_func = convert_variables_to_constants_v2(full_model) frozen_func.graph.as_graph_def() layers = [op.name for op in frozen_func.graph.get_operations()] print("-" * 50) print("Frozen model layers: ") for layer in layers: print(layer) print("-" * 50) print("Frozen model inputs: ") print(frozen_func.inputs) print("Frozen model outputs: ") print(frozen_func.outputs) tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=pb_path, name=pb_name, as_text=False
二 C#加載模型並預測
1 vs2017環境搭建
在項目屬性中設置平臺目標爲x64,
目標框架選擇.net framework 4.6.1,沒有該框架的可去官網下載安裝。
進入NuGet程序包管理界面,搜索並下載以下三個包,有可能因爲網絡問題沒法下載,可根據提示網站進入下載。
網絡
2 調用模型
1 xml前端界面app
<Window xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:d="http://schemas.microsoft.com/expression/blend/2008" xmlns:mc="http://schemas.openxmlformats.org/markup-compatibility/2006" xmlns:telerik="http://schemas.telerik.com/2008/xaml/presentation" x:Class="WpfApp1.MainWindow" mc:Ignorable="d" Title="貓狗分類" Height="300" Width="500" WindowStartupLocation="CenterScreen"> <Grid> <Grid.ColumnDefinitions> <ColumnDefinition Width="200"/> <ColumnDefinition/> </Grid.ColumnDefinitions> <Grid Grid.Column="0"> <Grid.RowDefinitions> <RowDefinition Height="1*"/> <RowDefinition Height="3*"/> </Grid.RowDefinitions> <telerik:RadButton x:Name="read_image" Content="讀取圖片" Click="Read_image_Click" Margin="70,15,50,15"/> <Grid Grid.Row="1"> <Grid.ColumnDefinitions> <ColumnDefinition Width="70"/> <ColumnDefinition/> </Grid.ColumnDefinitions> <Grid.RowDefinitions> <RowDefinition/> <RowDefinition/> <RowDefinition/> <RowDefinition Height="20"/> </Grid.RowDefinitions> <Label Content="得分:" HorizontalAlignment="Center" VerticalAlignment="Center"/> <TextBox x:Name="score" HorizontalAlignment="Left" VerticalAlignment="Center" Grid.Row="0" Grid.Column="1" Width="120"/> <Label Content="類別:" HorizontalAlignment="Center" VerticalAlignment="Center" Grid.Row="1"/> <TextBox x:Name="classes" HorizontalAlignment="Left" VerticalAlignment="Center" Grid.Row="1" Grid.Column="1" Width="120"/> <Label Content="時間:" HorizontalAlignment="Center" VerticalAlignment="Center" Grid.Row="2"/> <TextBox x:Name="time" HorizontalAlignment="Left" VerticalAlignment="Center" Grid.Row="2" Grid.Column="1" Width="120"/> </Grid> </Grid> <Border BorderBrush="Black" BorderThickness="1" Grid.Column="1" HorizontalAlignment="Center" Height="214" VerticalAlignment="Center" Width="265"> <Image x:Name="img"/> </Border> </Grid> </Window>
2 C#後端實現預測框架
//引入OpencvSharp和Dnn模塊 using System; using System.Windows; using System.Windows.Media.Imaging; using OpenCvSharp.Dnn; using OpenCvSharp; using Microsoft.Win32; namespace WpfApp1 { /// <summary> /// MainWindow.xaml 的交互邏輯 /// </summary> public partial class MainWindow : System.Windows.Window { public MainWindow() { InitializeComponent(); } public void Dnn_Classification(Mat image) { String model_path = ".//mobilenet.pb";//模型路徑 Net net = CvDnn.ReadNetFromTensorflow(model_path);//加載模型 if (net.Empty()) { MessageBox.Show("pd文件錯誤"); return; } Mat input_image = CvDnn.BlobFromImage(image, 1 / 255.0, new OpenCvSharp.Size(224, 224)); //圖片歸一化和resize net.SetInput(input_image); Mat result = net.Forward();//載入圖片並前向計算 float result_score = result.Get<float>(0, 0);//得到計算結果 score.Text = result_score.ToString(); if (result_score >= 0.5) { classes.Text = "Cat"; } else { classes.Text = "Dog"; } } private void Read_image_Click(object sender, RoutedEventArgs e) { OpenFileDialog ofd = new OpenFileDialog(); ofd.InitialDirectory = @"C:\Users\LemonQiu\Desktop"; ofd.Filter = "JPG圖片|*.jpg|PNG圖片|*.png"; if (ofd.ShowDialog() == true) { img.Source = new BitmapImage(new Uri(ofd.FileName)); Mat image = Cv2.ImRead(ofd.FileName); System.Diagnostics.Stopwatch watch = new System.Diagnostics.Stopwatch(); watch.Start(); Dnn_Classification(image); watch.Stop(); TimeSpan timespan = watch.Elapsed; time.Text = (timespan.TotalMilliseconds).ToString() + "ms"; } else { MessageBox.Show("沒有選擇圖片"); } } } }
三 最終效果
檢測時間基本穩定在100ms每張。
學習