TensorFlow Slim

TensorFlow 的 slim 包集成了不少 TensorFlow 中的高效函數，其實就是原版 TensorFlow 的瘦身（slim）版 API。本文基於 TensorFlow 1.4。

1. 導入

1 import tensorflow.contrib.slim as slim

2. 定義（slim 包的 __init__.py）

 1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 2 #
 3 # Licensed under the Apache License, Version 2.0 (the "License");
 4 # you may not use this file except in compliance with the License.
 5 # You may obtain a copy of the License at
 6 #
 7 #     http://www.apache.org/licenses/LICENSE-2.0
 8 #
 9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
15 """Slim is an interface to contrib functions, examples and models.
16 
17 TODO(nsilberman): flesh out documentation.
18 """
19 
20 from __future__ import absolute_import
21 from __future__ import division
22 from __future__ import print_function
23 
24 # pylint: disable=unused-import,line-too-long,g-importing-member,wildcard-import
25 # TODO(jart): Delete non-slim imports
26 from tensorflow.contrib import losses
27 from tensorflow.contrib import metrics
28 from tensorflow.contrib.framework.python.ops.arg_scope import *
29 from tensorflow.contrib.framework.python.ops.variables import *
30 from tensorflow.contrib.layers.python.layers import *
31 from tensorflow.contrib.layers.python.layers.initializers import *
32 from tensorflow.contrib.layers.python.layers.regularizers import *
33 from tensorflow.contrib.slim.python.slim import evaluation
34 from tensorflow.contrib.slim.python.slim import learning
35 from tensorflow.contrib.slim.python.slim import model_analyzer
36 from tensorflow.contrib.slim.python.slim import queues
37 from tensorflow.contrib.slim.python.slim import summaries
38 from tensorflow.contrib.slim.python.slim.data import data_decoder
39 from tensorflow.contrib.slim.python.slim.data import data_provider
40 from tensorflow.contrib.slim.python.slim.data import dataset
41 from tensorflow.contrib.slim.python.slim.data import dataset_data_provider
42 from tensorflow.contrib.slim.python.slim.data import parallel_reader
43 from tensorflow.contrib.slim.python.slim.data import prefetch_queue
44 from tensorflow.contrib.slim.python.slim.data import tfexample_decoder
45 from tensorflow.python.util.all_util import make_all
46 # pylint: enable=unused-import,line-too-long,g-importing-member,wildcard-import
47 
48 __all__ = make_all(__name__)

可以看到，slim 只是對其他模塊中函數的簡化引用。再舉一個類似的例子（tensorflow.contrib.framework 包的 __init__.py）：

 1 from __future__ import absolute_import
 2 from __future__ import division
 3 from __future__ import print_function
 4 
 5 # pylint: disable=unused-import,wildcard-import
 6 from tensorflow.contrib.framework.python.framework import *
 7 from tensorflow.contrib.framework.python.ops import *
 8 # pylint: enable=unused-import,wildcard-import
 9 
10 from tensorflow.python.framework.ops import prepend_name_scope
11 from tensorflow.python.framework.ops import strip_name_scope
12 
13 from tensorflow.python.util.all_util import remove_undocumented
14 
15 _allowed_symbols = ['nest']
16 
17 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)

 

1 import tf.contrib.framework.arg_scope as arg_scope
2 import tf.contrib.slim.arg_scope as arg_scope

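上面兩種寫法引用的實際上是同一個文件（arg_scope.py）中的同一個函數，只是導入路徑不同。注意：tf 只是 tensorflow 的別名，不能出現在 import 語句的路徑中，而且 arg_scope 是函數而不是模塊。因此實際應寫成 `from tensorflow.contrib.framework import arg_scope`，或者直接使用 `slim.arg_scope`。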

3. slim 的簡化使用例子，以及與原版 TensorFlow 寫法的對比

參見：https://www.jianshu.com/p/e2ada4ddae9a

 4. 函數

slim 中的函數彙總如下：

1 import tensorflow as tf
2 import tf.contrib.slim as slim

 1) slim.losses

 1 __all__ = ["absolute_difference",     @deprecated("2016-12-30", "Use tf.losses.absolute_difference instead.")
 2            "add_loss",           @deprecated("2016-12-30", "Use tf.losses.add_loss instead.")
 3            "cosine_distance",       @deprecated("2016-12-30", "Use tf.losses.cosine_distance instead.")
 4            "compute_weighted_loss",     @deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
 5            "get_losses",          @deprecated("2016-12-30", "Use tf.losses.get_losses instead.")
 6            "get_regularization_losses", @deprecated("2016-12-30", "Use tf.losses.get_regularization_losses instead.")
 7            "get_total_loss",        @deprecated("2016-12-30", "Use tf.losses.get_total_loss instead.")
 8            "hinge_loss",  @deprecated("2016-12-30","Use tf.losses.hinge_loss instead. Note that the order of the logits and labels arguments has been changed, and to stay unweighted, reduction=Reduction.NONE")
 9            "log_loss",    @deprecated("2016-12-30",Use tf.losses.log_loss instead. Note that the order of the predictions and labels arguments has been changed.")
10            "mean_pairwise_squared_error",  @deprecated("2016-12-30","Use tf.losses.mean_pairwise_squared_error instead. Note that the order of the predictions and labels arguments has been changed.")
11            "mean_squared_error",    @deprecated("2016-12-30", "Use tf.losses.mean_squared_error instead.")
12            "sigmoid_cross_entropy",  @deprecated("2016-12-30","Use tf.losses.sigmoid_cross_entropy instead. Note that the order of the predictions and labels arguments has been changed.")
13            "softmax_cross_entropy",  @deprecated("2016-12-30","Use tf.losses.softmax_cross_entropy instead. Note that the order of the logits and labels arguments has been changed.")
14            "sparse_softmax_cross_entropy"] @deprecated("2016-12-30","Use tf.losses.sparse_softmax_cross_entropy instead. Note that the order of the logits and labels arguments has been changed.")

 

 也就是說，slim.losses 中的函數都已被標記為棄用（deprecated），應全部改用 tf.losses 中對應的操作。

爲什麼？可能是想讓模塊劃分更有層次吧：loss 模塊就該只處理 loss，而不是混雜 cosine_distance 之類的東西。

 

1 @@add_loss
2 @@get_losses
3 @@get_regularization_loss
4 @@get_regularization_losses
5 @@get_total_loss

 其他的 tf.losses 函數定義在另外兩個文件中。

 2) slim.metrics

 

 1 @@streaming_accuracy
 2 @@streaming_mean
 3 @@streaming_recall
 4 @@streaming_recall_at_thresholds
 5 @@streaming_precision
 6 @@streaming_precision_at_thresholds
 7 @@streaming_auc
 8 @@streaming_curve_points
 9 @@streaming_recall_at_k
10 @@streaming_mean_absolute_error
11 @@streaming_mean_iou
12 @@streaming_mean_relative_error
13 @@streaming_mean_squared_error
14 @@streaming_mean_tensor
15 @@streaming_root_mean_squared_error
16 @@streaming_covariance
17 @@streaming_pearson_correlation
18 @@streaming_mean_cosine_distance
19 @@streaming_percentage_less
20 @@streaming_sensitivity_at_specificity
21 @@streaming_sparse_average_precision_at_k
22 @@streaming_sparse_average_precision_at_top_k
23 @@streaming_sparse_precision_at_k
24 @@streaming_sparse_precision_at_top_k
25 @@streaming_sparse_recall_at_k
26 @@streaming_specificity_at_sensitivity
27 @@streaming_concat
28 @@streaming_false_negatives
29 @@streaming_false_negatives_at_thresholds
30 @@streaming_false_positives
31 @@streaming_false_positives_at_thresholds
32 @@streaming_true_negatives
33 @@streaming_true_negatives_at_thresholds
34 @@streaming_true_positives
35 @@streaming_true_positives_at_thresholds
36 @@sparse_recall_at_top_k
37 @@auc_using_histogram
38 @@accuracy
39 @@aggregate_metrics
40 @@aggregate_metric_map
41 @@confusion_matrix
42 @@set_difference
43 @@set_intersection
44 @@set_size
45 @@set_union

 

 3) slim.arg_scope

1 __all__ = ['arg_scope',
2            'add_arg_scope',
3            'has_arg_scope',
4            'arg_scoped_arguments']

 

 4) slim.variables

 1 __all__ = ['add_model_variable',
 2            'assert_global_step',      @deprecated(None, "Please switch to tf.train.assert_global_step")
 3            'assert_or_get_global_step',
 4            'assign_from_checkpoint',
 5            'assign_from_checkpoint_fn',
 6            'assign_from_values',
 7            'assign_from_values_fn',
 8            'create_global_step',      @deprecated(None, "Please switch to tf.train.create_global_step")
 9            'filter_variables',
10            'get_global_step',        @deprecated(None, "Please switch to tf.train.get_global_step")
11            'get_or_create_global_step',  @deprecated(None, "Please switch to tf.train.get_or_create_global_step")
12            'get_local_variables',
13            'get_model_variables',
14            'get_trainable_variables',
15            'get_unique_variable',
16            'get_variables_by_name',
17            'get_variables_by_suffix',
18            'get_variable_full_name',
19            'get_variables_to_restore',
20            'get_variables',
21            'local_variable',
22            'model_variable',
23            'variable',
24            'VariableDeviceChooser',
25            'zero_initializer']

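 主要變化是：slim.variables 中與 global_step 相關的函數（assert_global_step、create_global_step、get_global_step、get_or_create_global_step）仍然保留，但都已被標記為棄用（deprecated），應改用 tf.train 中對應的 global_step 函數。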

 5) from tensorflow.contrib.layers.python.layers import *

 1 from tensorflow.contrib.layers.python.layers.embedding_ops import *
 2 from tensorflow.contrib.layers.python.layers.encoders import *
 3 from tensorflow.contrib.layers.python.layers.feature_column import *
 4 from tensorflow.contrib.layers.python.layers.feature_column_ops import *
 5 from tensorflow.contrib.layers.python.layers.initializers import *
 6 from tensorflow.contrib.layers.python.layers.layers import *
 7 from tensorflow.contrib.layers.python.layers.normalization import *
 8 from tensorflow.contrib.layers.python.layers.optimizers import *
 9 from tensorflow.contrib.layers.python.layers.regularizers import *
10 from tensorflow.contrib.layers.python.layers.summaries import *
11 from tensorflow.contrib.layers.python.layers.target_column import *
12 from tensorflow.contrib.layers.python.ops.bucketization_op import *
13 from tensorflow.contrib.layers.python.ops.sparse_feature_cross_op import *

 

embedding_ops.py

1 __all__ = [
2     "safe_embedding_lookup_sparse", "scattered_embedding_lookup",
3     "scattered_embedding_lookup_sparse", "embedding_lookup_unique",
4     "embedding_lookup_sparse_with_distributed_aggregation"
5 ]

encoders.py

1__all__ = ['bow_encoder', 'embed_sequence']

 initializers.py

1 __all__ = ['xavier_initializer', 'xavier_initializer_conv2d',
2            'variance_scaling_initializer']

 

 layers.py

 1 __all__ = ['avg_pool2d',
 2            'avg_pool3d',
 3            'batch_norm',
 4            'bias_add',
 5            'conv2d',
 6            'conv3d',
 7            'conv2d_in_plane',
 8            'conv2d_transpose',
 9            'conv3d_transpose',
10            'convolution',
11            'convolution2d',
12            'convolution2d_in_plane',
13            'convolution2d_transpose',
14            'convolution3d',
15            'convolution3d_transpose',
16            'dropout',
17            'elu',
18            'flatten',
19            'fully_connected',
20            'GDN',
21            'gdn',
22            'layer_norm',
23            'linear',
24            'pool',
25            'max_pool2d',
26            'max_pool3d',
27            'one_hot_encoding',
28            'relu',
29            'relu6',
30            'repeat',
31            'scale_gradient',
32            'separable_conv2d',
33            'separable_convolution2d',
34            'softmax',
35            'spatial_softmax',
36            'stack',
37            'unit_norm',
38            'legacy_fully_connected',
39            'legacy_linear',
40            'legacy_relu',
41            'maxout']

 

 normalization.py

1 __all__ = [
2     'instance_norm',
3 ]

 

optimizers.py

1 __all__ = [
2         'optimize_loss',
3         'adaptive_clipping_fn',
4         'gradient_clipping',
5 ]

 

 regularizers.py

1 __all__ = ['l1_regularizer',
2            'l2_regularizer', 3 'l1_l2_regularizer', 4 'sum_regularizer', 5 'apply_regularization']

 summaries.py

 1 __all__ = [
 2     'summarize_tensor',
 3     'summarize_activation',
 4     'summarize_tensors',
 5     'summarize_collection',
 6     'summarize_variables',
 7     'summarize_weights',
 8     'summarize_biases',
 9     'summarize_activations',
10 ]

 

 target_column.py

1 @deprecated("2016-11-12", "This file will be removed after the deprecation date. Please switch to third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")

 

utils.py

1 __all__ = ['collect_named_outputs',
2            'constant_value',
3            'static_cond',
4            'smart_cond',
5            'get_variable_collections',
6            'two_element_tuple',
7            'n_positive_integers',
8            'channel_dimension',
9            'last_dimension']

 

feature_column.py、feature_column_ops.py（內容較多，此處未列出）

6) 

slim.evaluation

1 __all__ = [
2     'evaluate_once',
3     'evaluation_loop',
4     'wait_for_new_checkpoint',
5     'checkpoints_iterator',
6 ]

 slim.learning

1 __all__ = [
2     'add_gradients_summaries', 'clip_gradient_norms', 'multiply_gradients',
3     'create_train_op', 'train_step', 'train'
4 ]

 slim.model_analyzer

1 __all__ = [
2         'tensor_description',
3         'analyze_ops',
4         'analyze_vars',
5 ]

 slim.queues

1 __all__ = [
2     'NestedQueueRunnerError',
3     'QueueRunners',
4 ]

 slim.summaries

 1 __all__ = [
 2         'add_histogram_summary',
 3         'add_image_summary',
 4         'add_scalar_summary',
 5         'add_zero_fraction_summary',
 6         'add_histogram_summaries',
 7         'add_image_summaries',
 8         'add_scalar_summaries',
 9         'add_zero_fraction_summaries',
10 ]

 

7) 

slim.data_decoder

slim.data_provider

class DataProvider(object):
    """Interface outline of slim's ``DataProvider`` class.

    Only the public method signatures are summarized here. The real
    implementation lives in
    ``tensorflow.contrib.slim.python.slim.data.data_provider``. Each method
    body raises ``NotImplementedError`` so that this outline is valid,
    importable Python (bare ``def`` lines without a body are a SyntaxError).
    """

    def get(self, items):
        """Return the data for the requested ``items``.

        NOTE(review): presumably one tensor per requested item name --
        confirm against slim's data_provider.py.
        """
        raise NotImplementedError

    def list_items(self):
        """Return the item names this provider can supply (outline only)."""
        raise NotImplementedError

    def num_samples(self):
        """Return the number of samples in the source (outline only)."""
        raise NotImplementedError

 

slim.dataset

1 class Dataset(object):
2   """Represents a Dataset specification."""
3 
4   def __init__(self, data_sources, reader, decoder, num_samples,
5                items_to_descriptions, **kwargs):

 

slim.dataset_data_provider

slim.parallel_reader

slim.prefetch_queue

slim.tfexample_decoder

8) 

1 from tensorflow.contrib.slim.python.slim.nets import alexnet
2 from tensorflow.contrib.slim.python.slim.nets import inception
3 from tensorflow.contrib.slim.python.slim.nets import overfeat
4 from tensorflow.contrib.slim.python.slim.nets import resnet_utils
5 from tensorflow.contrib.slim.python.slim.nets import resnet_v1
6 from tensorflow.contrib.slim.python.slim.nets import resnet_v2
7 from tensorflow.contrib.slim.python.slim.nets import vgg
8 from tensorflow.python.util.all_util import make_all

 

 1 alexnet.py
 2 inception.py
 3 inception_v1.py
 4 inception_v2.py
 5 inception_v3.py
 6 overfeat.py
 7 resnet_utils.py
 8 resnet_v1.py
 9 resnet_v2.py
10 vgg.py

 

9) slim.make_all

 

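make_all(module_name) 來自 tensorflow.python.util.all_util。它會掃描模塊的文檔字符串，把其中以 @@ 標註、且在該模塊中確實存在的名稱收集起來，作為 __all__ 返回。上文 losses、metrics 中的 @@ 列表就是這種標註。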

相關文章
相關標籤/搜索