(I): Go to GitHub and download the model code --> download link
Since we need the slim module, copy the slim folder out of the downloaded package and work from that copy.
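For reference, the slim code is distributed inside the tensorflow/models repository on GitHub, so the download typically amounts to a clone (note: the exact location of the slim folder inside that repository has moved between releases):

```
git clone https://github.com/tensorflow/models.git
```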
(1): Inside slim, create an images folder to hold the image set
(2): Create a model folder to hold the trained model
(3): Create a myimages.py file in the datasets folder
```python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset.

The dataset scripts used to create the dataset can be found at:
tensorflow/models/slim/datasets/download_and_convert_flowers.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf

from datasets import dataset_utils

slim = tf.contrib.slim

_FILE_PATTERN = 'image_%s_*.tfrecord'

SPLITS_TO_SIZES = {'train': 3500, 'test': 500}  # change these to match your own training set

_NUM_CLASSES = 5

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and 4',
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading flowers.

  Args:
    split_name: A train/validation split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/validation split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)
```
(4): Modify dataset_factory.py to register the new dataset
```python
from datasets import myimages

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'myimages': myimages,  # this entry is the addition
}
```
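As a quick check that the registration works, here is a minimal sketch of my own (not from the original post; the dataset path is assumed to match the one used below, and it requires the TFRecord files produced in part (II)) that loads the new dataset through dataset_factory and reads one decoded image with slim's DatasetDataProvider:

```python
import tensorflow as tf
from datasets import dataset_factory

slim = tf.contrib.slim

# Look up the registered 'myimages' dataset; the path is an assumption.
dataset = dataset_factory.get_dataset(
    'myimages', 'train', 'C:/Users/FELIX/Desktop/tensor_study/slim/images')

# A DatasetDataProvider queues up decoded examples from the tfrecord shards.
provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
image, label = provider.get(['image', 'label'])

with tf.Session() as sess:
    # slim.queues.QueueRunners starts and stops the input queues for us.
    with slim.queues.QueueRunners(sess):
        img, lbl = sess.run([image, label])
        print(img.shape, lbl)
```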
(II): Process the images and generate files in TFRecord format. The script below expects images/ to contain one subfolder per class; the folder name becomes the class label.
```python
import tensorflow as tf
import os
import random
import math
import sys

# Number of validation (test) images
_NUM_TEST = 500
# Random seed
_RANDOM_SEED = 0
# Number of shards
_NUM_SHARDS = 5
# Dataset path
DATASET_DIR = "C:/Users/FELIX/Desktop/tensor_study/slim/images/"
# Label file name
LABELS_FILENAME = ''.join([DATASET_DIR, 'labels.txt'])


# Build the path + name of a tfrecord file
def _get_dataset_filename(dataset_dir, split_name, shard_id):
    output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name, shard_id, _NUM_SHARDS)
    return os.path.join(dataset_dir, output_filename)


# Check whether the tfrecord files already exist
def _dataset_exists(dataset_dir):
    for split_name in ['train', 'test']:
        for shard_id in range(_NUM_SHARDS):
            # Path + name of the tfrecord file
            output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
            if not tf.gfile.Exists(output_filename):
                return False
    return True


# Collect all image files and their classes
def _get_filenames_and_classes(dataset_dir):
    # Data directories
    directories = []
    # Class names
    class_names = []
    for filename in os.listdir(dataset_dir):
        path = os.path.join(dataset_dir, filename)
        # Each subdirectory is one class
        if os.path.isdir(path):
            directories.append(path)
            class_names.append(filename)

    photo_filenames = []
    # Walk each class folder and collect the image paths
    for directory in directories:
        for filename in os.listdir(directory):
            path = os.path.join(directory, filename)
            photo_filenames.append(path)

    return photo_filenames, class_names


def int64_feature(values):
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))


def image_to_tfexample(image_data, image_format, class_id):
    # Abstract base class for protocol messages.
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
    }))


def write_label_file(labels_to_class_names, dataset_dir, filename=LABELS_FILENAME):
    labels_filename = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(labels_filename, 'w') as f:
        for label in labels_to_class_names:
            class_name = labels_to_class_names[label]
            f.write('%d:%s\n' % (label, class_name))


# Convert the data to TFRecord format
def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
    assert split_name in ['train', 'test']
    # Number of images per shard
    num_per_shard = int(len(filenames) / _NUM_SHARDS)
    with tf.Graph().as_default():
        with tf.Session() as sess:
            for shard_id in range(_NUM_SHARDS):
                # Path + name of the tfrecord file
                output_filename = _get_dataset_filename(dataset_dir, split_name, shard_id)
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    # First index of this shard
                    start_ndx = shard_id * num_per_shard
                    # Last index of this shard
                    end_ndx = min((shard_id + 1) * num_per_shard, len(filenames))
                    for i in range(start_ndx, end_ndx):
                        try:
                            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (i + 1, len(filenames), shard_id))
                            sys.stdout.flush()
                            # Read the image; 'rb' is required here, otherwise a decoding error occurs
                            image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
                            # The class name is the name of the image's parent directory
                            class_name = os.path.basename(os.path.dirname(filenames[i]))
                            # Map the class name to its id
                            class_id = class_names_to_ids[class_name]
                            # Build the tfrecord example and write it out
                            example = image_to_tfexample(image_data, b'jpg', class_id)
                            tfrecord_writer.write(example.SerializeToString())
                        except IOError as e:
                            print("Could not read:", filenames[i])
                            print("Error:", e)
                            print("Skip it\n")
    sys.stdout.write('\n')
    sys.stdout.flush()


if __name__ == '__main__':
    # Skip conversion if the tfrecord files already exist
    if _dataset_exists(DATASET_DIR):
        print('tfrecord files already exist')
    else:
        # Collect all images and their classes
        photo_filenames, class_names = _get_filenames_and_classes(DATASET_DIR)
        # Map class names to ids, e.g. {'house': 3, 'flower': 1, 'plane': 4, 'guitar': 2, 'animal': 0}
        class_names_to_ids = dict(zip(class_names, range(len(class_names))))
        # Split the data into training and test sets
        random.seed(_RANDOM_SEED)
        random.shuffle(photo_filenames)
        training_filenames = photo_filenames[_NUM_TEST:]
        testing_filenames = photo_filenames[:_NUM_TEST]
        # Convert the data
        _convert_dataset('train', training_filenames, class_names_to_ids, DATASET_DIR)
        _convert_dataset('test', testing_filenames, class_names_to_ids, DATASET_DIR)
        # Write the labels file
        labels_to_class_names = dict(zip(range(len(class_names)), class_names))
        write_label_file(labels_to_class_names, DATASET_DIR)
```
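To confirm the conversion worked, a small sanity check (my addition, not from the original post; the shard file name follows the pattern produced by _get_dataset_filename above) can read one record back and print its fields:

```python
import tensorflow as tf

# First shard of the training split, named per _get_dataset_filename above.
path = 'C:/Users/FELIX/Desktop/tensor_study/slim/images/image_train_00000-of-00005.tfrecord'

for record in tf.python_io.tf_record_iterator(path):
    example = tf.train.Example.FromString(record)
    feature = example.features.feature
    print('label :', feature['image/class/label'].int64_list.value[0])
    print('format:', feature['image/format'].bytes_list.value[0])
    print('bytes :', len(feature['image/encoded'].bytes_list.value[0]))
    break  # one record is enough for a sanity check
```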
(III): Create a batch file and start training the model
```
python C:/Users/FELIX/Desktop/tensor_study/slim/train_image_classifier.py ^
--train_dir=C:/Users/FELIX/Desktop/tensor_study/slim/model ^
--dataset_name=myimages ^
--dataset_split_name=train ^
--dataset_dir=C:/Users/FELIX/Desktop/tensor_study/slim/images ^
--batch_size=10 ^
--max_number_of_steps=10000 ^
--model_name=inception_v3
pause
```
Notes:
Line 1 runs the training script; give the full path.
Line 2 is the directory where the trained model (checkpoints) will be saved.
Line 3 is the dataset name we registered as myimages.
Line 4 is the split to train on.
Line 5 is the directory containing the dataset.
Line 6 is the batch size; the default is 32, adjust for your GPU, I use 10.
Line 7 is the number of training steps; by default training runs indefinitely.
Line 8 is the name of the model architecture to use.
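Since the conversion script also produced a test split, the trained model can be evaluated with slim's eval_image_classifier.py in an analogous batch file. This is a sketch of my own (paths reused from above, flag names taken from the slim package), not part of the original post:

```
python C:/Users/FELIX/Desktop/tensor_study/slim/eval_image_classifier.py ^
--checkpoint_path=C:/Users/FELIX/Desktop/tensor_study/slim/model ^
--eval_dir=C:/Users/FELIX/Desktop/tensor_study/slim/model/eval ^
--dataset_name=myimages ^
--dataset_split_name=test ^
--dataset_dir=C:/Users/FELIX/Desktop/tensor_study/slim/images ^
--model_name=inception_v3
pause
```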