TensorFlow-Slim image classification model library
TF-Slim is a new lightweight high-level API of TensorFlow (tensorflow.contrib.slim) for defining, training and evaluating complex models. This directory contains code for training and evaluating several widely used Convolutional Neural Network (CNN) image classification models using TF-Slim. It contains scripts that will allow you to train models from scratch or fine-tune them from pre-trained network weights. It also contains code for downloading standard image datasets, converting them to TensorFlow's native TFRecord format and reading them in using TF-Slim's data reading and queueing utilities. You can easily train any model on any of these datasets, as we demonstrate below. We've also included a Jupyter notebook, which provides working examples of how to use TF-Slim for image classification. For developing or modifying your own models, see also the main TF-Slim page.
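1. Installation
Installing the latest version of TF-Slim (it ships with TensorFlow)
TF-Slim is available as tf.contrib.slim via TensorFlow 1.0. To test that your installation is working, execute the following command; it should run without raising any errors.
python -c "import tensorflow.contrib.slim as slim; eval = slim.evaluation.evaluate_once" # python -c runs the given Python statement from the command line
Installing the TF-Slim image models library (the models repository must be downloaded separately)
To use TF-Slim for image classification, you also have to install the TF-Slim image models library, which is not part of the core TF library. To do this, check out the tensorflow/models repository as follows: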
git clone https://github.com/tensorflow/models/
5. Fine-tuning a model from an existing checkpoint
Rather than training from scratch, we'll often want to start from a pre-trained model and fine-tune it. To indicate a checkpoint from which to fine-tune, we'll call training with the --checkpoint_path flag and assign it an absolute path to a checkpoint file.
When fine-tuning a model, we need to be careful about restoring checkpoint weights. In particular, when we fine-tune a model on a new task with a different number of output labels, we won't be able to restore the final logits (classifier) layer. For this, we'll use the --checkpoint_exclude_scopes flag, which prevents certain variables from being loaded. When fine-tuning on a classification task with a different number of classes than the trained model, the new model will have a final 'logits' layer whose dimensions differ from the pre-trained model. For example, if fine-tuning an ImageNet-trained model on Flowers, the pre-trained logits layer will have dimensions [2048 x 1001] but our new logits layer will have dimensions [2048 x 5]. Consequently, this flag tells TF-Slim to avoid loading these weights from the checkpoint.
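To make the effect of --checkpoint_exclude_scopes concrete, here is a minimal sketch in plain Python of excluding variables by scope prefix. It is not the code from train_image_classifier.py: the helper name variables_to_restore is hypothetical, and the variable names and shapes in the toy model are illustrative assumptions.

```python
# Hypothetical helper; illustrates prefix-based exclusion, not TF-Slim's own code.
def variables_to_restore(model_variables, exclude_scopes):
    """Return the names of variables that should be loaded from the checkpoint.

    model_variables: dict mapping variable name -> shape tuple.
    exclude_scopes: comma-separated scope prefixes, as given on the command line.
    """
    exclusions = [s.strip() for s in exclude_scopes.split(",") if s.strip()]
    return [
        name for name in model_variables
        if not any(name.startswith(prefix) for prefix in exclusions)
    ]


# Toy stand-in for the new Flowers model (names and shapes are assumptions).
flowers_model = {
    "InceptionV3/Conv2d_1a_3x3/weights": (3, 3, 3, 32),
    "InceptionV3/Logits/Conv2d_1c_1x1/weights": (1, 1, 2048, 5),
    "InceptionV3/AuxLogits/Conv2d_2b_1x1/weights": (1, 1, 768, 5),
}

restored = variables_to_restore(
    flowers_model, "InceptionV3/Logits,InceptionV3/AuxLogits")
print(restored)  # only the convolutional variable is left to restore
```

The two classifier variables are skipped because their 5-class shapes could not match the 1001-class tensors stored in the ImageNet checkpoint; everything else is restored as usual.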
Keep in mind that warm-starting from a checkpoint affects the model's weights only during the initialization of the model. Once a model has started training, a new checkpoint will be created in ${TRAIN_DIR}. If the fine-tuning training is stopped and restarted, this new checkpoint will be the one from which weights are restored, not the one given by --checkpoint_path. Consequently, the flags --checkpoint_path and --checkpoint_exclude_scopes are only used during the 0-th global step (model initialization). Typically, when fine-tuning, one only wants to train a subset of layers, so the flag --trainable_scopes lets you specify which subsets of layers should be trained; the rest remain frozen.
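The two rules in this paragraph can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the training script's implementation: both function names are hypothetical, and checkpoints and variables are represented by plain strings.

```python
# Hypothetical helpers that mirror the behaviour described above.
def choose_init_checkpoint(latest_in_train_dir, checkpoint_path):
    """Pick the checkpoint to restore from when training (re)starts.

    latest_in_train_dir: newest checkpoint found in TRAIN_DIR, or None.
    checkpoint_path: value of --checkpoint_path, or None.
    """
    if latest_in_train_dir is not None:
        # Training has already begun: resume and ignore --checkpoint_path.
        return latest_in_train_dir
    # Global step 0: warm-start from --checkpoint_path (None means from scratch).
    return checkpoint_path


def select_trainable(variable_names, trainable_scopes):
    """Keep only the variables under one of the comma-separated scopes."""
    if not trainable_scopes:
        return list(variable_names)  # flag not set: every variable is trained
    scopes = [s.strip() for s in trainable_scopes.split(",") if s.strip()]
    return [n for n in variable_names if any(n.startswith(s) for s in scopes)]


# First run: nothing in TRAIN_DIR yet, so the pre-trained weights are used.
print(choose_init_checkpoint(None, "/tmp/my_checkpoints/inception_v3.ckpt"))
# Restart: the checkpoint written during fine-tuning takes precedence.
print(choose_init_checkpoint("/tmp/flowers-models/inception_v3/model.ckpt-1000",
                             "/tmp/my_checkpoints/inception_v3.ckpt"))
```

The path model.ckpt-1000 is made up for the example; the point is only that a checkpoint in the training directory wins over the one named on the command line.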
Below we give an example of fine-tuning inception_v3 on flowers. inception_v3 was trained on ImageNet with 1000 class labels, but the flowers dataset only has 5 classes. Since the dataset is quite small, we will only train the new layers.
$ DATASET_DIR=/tmp/flowers
$ TRAIN_DIR=/tmp/flowers-models/inception_v3
$ CHECKPOINT_PATH=/tmp/my_checkpoints/inception_v3.ckpt
$ python train_image_classifier.py \
    --train_dir=${TRAIN_DIR} \
    --dataset_dir=${DATASET_DIR} \
    --dataset_name=flowers \
    --dataset_split_name=train \
    --model_name=inception_v3 \
    --checkpoint_path=${CHECKPOINT_PATH} \
    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
    --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
7. Exporting the Inference Graph
The export script saves out a GraphDef (a .pb file) containing the architecture of the model.
To use it with a model name defined by slim, run:
$ python export_inference_graph.py \
    --alsologtostderr \
    --model_name=inception_v3 \
    --output_file=/tmp/inception_v3_inf_graph.pb

$ python export_inference_graph.py \
    --alsologtostderr \
    --model_name=mobilenet_v1 \
    --image_size=224 \
    --output_file=/tmp/mobilenet_v1_224.pb
Freezing the exported Graph
If you then want to use the resulting model with your own or pretrained checkpoints as part of a mobile model, you can run freeze_graph to get a graph def with the variables inlined as constants using:
bazel build tensorflow/python/tools:freeze_graph

bazel-bin/tensorflow/python/tools/freeze_graph \
  --input_graph=/tmp/inception_v3_inf_graph.pb \
  --input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \
  --input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \
  --output_node_names=InceptionV3/Predictions/Reshape_1
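In general, freeze_graph takes the .pb file produced by tf.train.write_graph() and the checkpoint file produced by tf.train.Saver(), bakes the weights into the graph, and writes out a new .pb file:
# Note: the input .pb file was saved with tf.train.write_graph, so it contains only the graph structure.
bazel-bin/tensorflow/python/tools/freeze_graph \
  --input_graph=/path/to/graph.pb \
  --input_checkpoint=/path/to/model.ckpt \
  --output_node_names=output/predict \
  --output_graph=/path/to/frozen.pb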
How tensorflow/tensorflow/python/tools/freeze_graph.py works (excerpt):
output_graph_def = graph_util.convert_variables_to_constants(  # the key function
    sess,
    input_graph_def,
    output_node_names.replace(" ", "").split(","),
    variable_names_whitelist=variable_names_whitelist,
    variable_names_blacklist=variable_names_blacklist)
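One detail in the excerpt is easy to miss: the value of --output_node_names is a single comma-separated string, and spaces are stripped before it is split into a list. The small sketch below reproduces just that expression in plain Python; the wrapper function name is hypothetical, and the node names are the ones used in the commands above.

```python
# Hypothetical wrapper around the expression used in the excerpt.
def parse_output_node_names(flag_value):
    """Turn 'a, b,c' into ['a', 'b', 'c'] by dropping spaces and splitting on commas."""
    return flag_value.replace(" ", "").split(",")


names = parse_output_node_names("InceptionV3/Predictions/Reshape_1, output/predict")
print(names)
```

So several output nodes can be frozen at once by listing them with commas, with or without spaces after each comma.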
The output node names will vary depending on the model, but you can inspect and estimate them using the summarize_graph tool:
bazel build tensorflow/tools/graph_transforms:summarize_graph

bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
  --in_graph=/tmp/inception_v3_inf_graph.pb