TensorFlow-Slim image classification model library
TF-slim is a new lightweight high-level API of TensorFlow (tensorflow.contrib.slim) for defining, training and evaluating complex models. This directory contains code for training and evaluating several widely used Convolutional Neural Network (CNN) image classification models using TF-slim. It contains scripts that will allow you to train models from scratch or fine-tune them from pre-trained network weights. It also contains code for downloading standard image datasets, converting them to TensorFlow's native TFRecord format and reading them in using TF-Slim's data reading and queueing utilities. You can easily train any model on any of these datasets, as we demonstrate below. We've also included a Jupyter notebook, which provides working examples of how to use TF-Slim for image classification. For developing or modifying your own models, see also the main TF-Slim page.
TF-Slim is available as tf.contrib.slim via TensorFlow 1.0. To test that your installation is working, execute the following command; it should run without raising any errors.
python -c "import tensorflow.contrib.slim as slim; eval = slim.evaluation.evaluate_once"  # python -c runs a Python statement directly from the command line
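To use TF-Slim for image classification, you also have to install the TF-Slim image models library, which is not part of the core TF library. To do this, check out the tensorflow/models repository as follows: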
git clone https://github.com/tensorflow/models/
5. Fine-tuning a model from an existing checkpoint
Rather than training from scratch, we'll often want to start from a pre-trained model and fine-tune it. To indicate a checkpoint from which to fine-tune, we'll call training with the --checkpoint_path flag and assign it an absolute path to a checkpoint file.
When fine-tuning a model, we need to be careful about restoring checkpoint weights. In particular, when we fine-tune a model on a new task with a different number of output labels, we won't be able to restore the final logits (classifier) layer. For this, we'll use the --checkpoint_exclude_scopes flag. This flag prevents certain variables from being loaded. When fine-tuning on a classification task using a different number of classes than the trained model, the new model will have a final 'logits' layer whose dimensions differ from those of the pre-trained model. For example, if fine-tuning an ImageNet-trained model on Flowers, the pre-trained logits layer will have dimensions [2048 x 1001] but our new logits layer will have dimensions [2048 x 5]. Consequently, this flag tells TF-Slim not to load these weights from the checkpoint.
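To make the exclusion concrete, here is a minimal pure-Python sketch. It assumes, as we understand train_image_classifier.py to do, that a variable is skipped whenever its name starts with any of the comma-separated scopes passed to --checkpoint_exclude_scopes. The helper variables_to_restore, the variable names, and the shapes are hypothetical and chosen for illustration only (real shapes depend on the model definition), and the sketch works on name strings rather than tf.Variable objects.

```python
# Sketch (assumption): plain prefix matching on variable names, modeled on how
# --checkpoint_exclude_scopes is described above. Not the script's actual code.

def variables_to_restore(model_variable_names, checkpoint_exclude_scopes):
    """Return the variable names that should be loaded from the checkpoint.

    A variable is skipped if its name starts with any excluded scope.
    """
    exclusions = [s.strip() for s in checkpoint_exclude_scopes.split(",") if s.strip()]
    restore = []
    for name in model_variable_names:
        if not any(name.startswith(scope) for scope in exclusions):
            restore.append(name)
    return restore


if __name__ == "__main__":
    # Hypothetical names and shapes: the pre-trained checkpoint was built for
    # 1001 ImageNet classes, the new model for 5 flower classes.
    checkpoint_shapes = {
        "InceptionV3/Conv2d_1a_3x3/weights": (3, 3, 3, 32),
        "InceptionV3/Logits/Conv2d_1c_1x1/weights": (1, 1, 2048, 1001),
    }
    model_shapes = {
        "InceptionV3/Conv2d_1a_3x3/weights": (3, 3, 3, 32),
        "InceptionV3/Logits/Conv2d_1c_1x1/weights": (1, 1, 2048, 5),
    }
    restore = variables_to_restore(model_shapes, "InceptionV3/Logits,InceptionV3/AuxLogits")
    print(restore)  # ['InceptionV3/Conv2d_1a_3x3/weights']
    # Every variable we still restore has the same shape in model and checkpoint.
    assert all(checkpoint_shapes[n] == model_shapes[n] for n in restore)
```

Because matching is by prefix, listing a scope such as InceptionV3/Logits excludes every variable nested under it in one entry.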
Keep in mind that warm-starting from a checkpoint affects the model's weights only during the initialization of the model. Once a model has started training, a new checkpoint will be created in ${TRAIN_DIR}. If the fine-tuning training is stopped and restarted, this new checkpoint will be the one from which weights are restored, not the one given by --checkpoint_path. Consequently, the flags --checkpoint_path and --checkpoint_exclude_scopes are only used during the 0th global step (model initialization). Typically, when fine-tuning, one only wants to train a subset of layers, so the flag --trainable_scopes allows one to specify which subsets of layers should be trained; the rest remain frozen.
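The two rules above (resume from ${TRAIN_DIR} if it already holds a checkpoint, otherwise warm-start from --checkpoint_path; update only the variables under --trainable_scopes) can be sketched in plain Python. This is an illustration under assumptions, not the script's code: init_source and variables_to_train are hypothetical helpers, the boolean argument stands in for a check such as tf.train.latest_checkpoint(train_dir) returning a path, and scopes are matched as plain name prefixes.

```python
# Sketch (assumption): models the warm-start decision and --trainable_scopes
# filtering described above. Helper names and variable names are hypothetical.

def init_source(has_checkpoint_in_train_dir, checkpoint_path):
    """Decide where the initial weights come from."""
    if has_checkpoint_in_train_dir:
        # Restarted run: resume from ${TRAIN_DIR}; --checkpoint_path and
        # --checkpoint_exclude_scopes no longer matter.
        return "train_dir"
    if checkpoint_path:
        # First run (global step 0): warm-start from the pre-trained checkpoint.
        return "checkpoint_path"
    return "scratch"  # No checkpoint at all: random initialization.


def variables_to_train(trainable_variable_names, trainable_scopes=None):
    """Return the names of the variables the optimizer should update.

    Without --trainable_scopes every trainable variable is updated; otherwise
    only names starting with one of the listed scopes, the rest stay frozen.
    """
    if not trainable_scopes:
        return list(trainable_variable_names)
    scopes = [s.strip() for s in trainable_scopes.split(",") if s.strip()]
    selected = []
    for scope in scopes:
        selected.extend(n for n in trainable_variable_names if n.startswith(scope))
    return selected


if __name__ == "__main__":
    # Hypothetical variable names, loosely modeled on inception_v3 scopes.
    names = [
        "InceptionV3/Conv2d_1a_3x3/weights",
        "InceptionV3/Logits/Conv2d_1c_1x1/weights",
        "InceptionV3/AuxLogits/Conv2d_2b_1x1/weights",
    ]
    print(init_source(False, "/tmp/my_checkpoints/inception_v3.ckpt"))  # checkpoint_path
    print(init_source(True, "/tmp/my_checkpoints/inception_v3.ckpt"))   # train_dir
    print(variables_to_train(names, "InceptionV3/Logits,InceptionV3/AuxLogits"))
    # ['InceptionV3/Logits/Conv2d_1c_1x1/weights', 'InceptionV3/AuxLogits/Conv2d_2b_1x1/weights']
```

Note that InceptionV3/Logits does not prefix-match InceptionV3/AuxLogits, which is why both scopes are listed separately in the example command.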
Below we give an example of fine-tuning inception_v3 on Flowers. inception_v3 was trained on ImageNet with 1000 class labels, but the Flowers dataset only has 5 classes. Since the dataset is quite small, we will only train the new layers.
$ DATASET_DIR=/tmp/flowers
$ TRAIN_DIR=/tmp/flowers-models/inception_v3
$ CHECKPOINT_PATH=/tmp/my_checkpoints/inception_v3.ckpt
$ python train_image_classifier.py \
    --train_dir=${TRAIN_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --dataset_name=flowers \
    --dataset_split_name=train \
    --model_name=inception_v3 \
    --checkpoint_path=${CHECKPOINT_PATH} \
    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
    --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
7. Exporting the Inference Graph
export_inference_graph.py saves out a GraphDef containing the architecture of the model.
To use it with a model name defined by slim, run:
$ python export_inference_graph.py \
    --alsologtostderr \
    --model_name=inception_v3 \
    --output_file=/tmp/inception_v3_inf_graph.pb

$ python export_inference_graph.py \
    --alsologtostderr \
    --model_name=mobilenet_v1 \
    --image_size=224 \
    --output_file=/tmp/mobilenet_v1_224.pb
Freezing the exported graph
If you then want to use the resulting model with your own or pretrained checkpoints as part of a mobile model, you can run freeze_graph as follows to get a GraphDef with the variables inlined as constants:
bazel build tensorflow/python/tools:freeze_graph

bazel-bin/tensorflow/python/tools/freeze_graph \
    --input_graph=/tmp/inception_v3_inf_graph.pb \
    --input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \
    --input_binary=true \
    --output_graph=/tmp/frozen_inception_v3.pb \
    --output_node_names=InceptionV3/Predictions/Reshape_1
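freeze_graph takes the .pb file produced by tf.train.write_graph() together with the checkpoint file produced by tf.train.Saver(), fuses them, and writes out a new .pb file.
--input_graph=/path/to/graph.pb  # Note: this .pb file was saved with tf.train.write_graph, so at this point it contains only the graph structure (no weights)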
How freeze_graph.py (tensorflow/python/tools/freeze_graph.py in the tensorflow/tensorflow repository) works:
output_graph_def = graph_util.convert_variables_to_constants(  # key function: turns variables into constants
    sess, input_graph_def, output_node_names.replace(" ", "").split(","),
    variable_names_whitelist=variable_names_whitelist,
    variable_names_blacklist=variable_names_blacklist)
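To illustrate the idea behind convert_variables_to_constants, here is a toy model in plain Python. It assumes a graph represented as a list of dicts with name, op, and inputs keys rather than a real GraphDef proto, and uses "VariableV2" and "Const" only as op-type labels; freeze is a hypothetical helper, not TensorFlow's implementation. It keeps only the nodes the output nodes depend on and replaces each selected variable with a constant holding its checkpoint value, honoring a whitelist and blacklist in the same spirit as the real call.

```python
# Toy model (assumption): dict-based nodes, not tf GraphDef protos.

def _node_name(input_name):
    """Strip a control-dependency prefix '^' and an output suffix ':N'."""
    return input_name.lstrip("^").split(":")[0]


def freeze(nodes, checkpoint, output_node_names, whitelist=None, blacklist=None):
    by_name = {n["name"]: n for n in nodes}
    # 1. Walk backwards from the outputs; only these nodes are needed for inference.
    needed, stack = set(), list(output_node_names)
    while stack:
        name = stack.pop()
        if name in needed:
            continue
        needed.add(name)
        stack.extend(_node_name(i) for i in by_name[name]["inputs"])
    frozen = []
    for node in nodes:  # preserve the original node order
        name = node["name"]
        if name not in needed:
            continue
        convert = (node["op"] == "VariableV2"
                   and (whitelist is None or name in whitelist)
                   and (blacklist is None or name not in blacklist))
        if convert:
            # 2. Inline the checkpoint value as a constant with no inputs.
            frozen.append({"name": name, "op": "Const", "inputs": [], "value": checkpoint[name]})
        else:
            frozen.append(dict(node))
    return frozen


if __name__ == "__main__":
    graph = [
        {"name": "input", "op": "Placeholder", "inputs": []},
        {"name": "w", "op": "VariableV2", "inputs": []},
        {"name": "matmul", "op": "MatMul", "inputs": ["input", "w:0"]},
        {"name": "Predictions", "op": "Softmax", "inputs": ["matmul"]},
        {"name": "loss", "op": "Mean", "inputs": ["matmul"]},  # training-only node
    ]
    frozen = freeze(graph, {"w": [[0.5]]}, ["Predictions"])
    print([(n["name"], n["op"]) for n in frozen])
    # [('input', 'Placeholder'), ('w', 'Const'), ('matmul', 'MatMul'), ('Predictions', 'Softmax')]
```

The training-only loss node is dropped because Predictions does not depend on it, which is also why --output_node_names must be chosen correctly.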
The output node names will vary depending on the model, but you can inspect and estimate them using the summarize_graph tool:
bazel build tensorflow/tools/graph_transforms:summarize_graph

bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
    --in_graph=/tmp/inception_v3_inf_graph.pb
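A rough way to estimate output nodes is to look for nodes that no other node consumes. The sketch below applies that heuristic to a toy dict-based graph (an assumption, not a GraphDef); candidate_outputs and its default list of ignored op types are hypothetical and are not summarize_graph's exact rules. Input strings are normalized by dropping a "^" control-dependency prefix and a ":N" output suffix, following GraphDef's input naming convention.

```python
# Sketch (assumption): "outputs are nodes nobody consumes" heuristic on toy nodes.

def _node_name(input_name):
    """Strip a control-dependency prefix '^' and an output suffix ':N'."""
    return input_name.lstrip("^").split(":")[0]


def candidate_outputs(nodes, ignore_ops=("Const", "NoOp", "Placeholder")):
    """Return names of nodes that are never used as an input by another node."""
    consumed = {_node_name(i) for n in nodes for i in n["inputs"]}
    return [n["name"] for n in nodes
            if n["name"] not in consumed and n["op"] not in ignore_ops]


if __name__ == "__main__":
    # Hypothetical, heavily simplified tail of an inception_v3 inference graph.
    graph = [
        {"name": "input", "op": "Placeholder", "inputs": []},
        {"name": "InceptionV3/Logits/SpatialSqueeze", "op": "Squeeze", "inputs": ["input"]},
        {"name": "InceptionV3/Predictions/Reshape", "op": "Reshape",
         "inputs": ["InceptionV3/Logits/SpatialSqueeze"]},
        {"name": "InceptionV3/Predictions/Softmax", "op": "Softmax",
         "inputs": ["InceptionV3/Predictions/Reshape"]},
        {"name": "InceptionV3/Predictions/Reshape_1", "op": "Reshape",
         "inputs": ["InceptionV3/Predictions/Softmax"]},
    ]
    print(candidate_outputs(graph))  # ['InceptionV3/Predictions/Reshape_1']
```

The heuristic only narrows the search; a graph can have several unconsumed nodes, so the result still needs a human check before being passed to --output_node_names.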