为了账号安全,请及时绑定邮箱和手机立即绑定

Tensorflow Estimator:缓存瓶颈

Tensorflow Estimator:缓存瓶颈

有只小跳蛙 2021-06-04 10:53:11
在遵循 tensorflow 图像分类教程时,首先它会缓存每个图像的瓶颈:定义:cache_bottlenecks())我已经使用 tensorflow 的Estimator. 这确实简化了所有代码。但是我想在这里缓存瓶颈功能。这是我的model_fn. 我想缓存dense层的结果,这样我就可以对实际训练进行更改,而不必每次都计算瓶颈。我怎样才能做到这一点?def model_fn(features, labels, mode, params):    is_training = mode == tf.estimator.ModeKeys.TRAIN    num_classes = len(params['label_vocab'])    module = hub.Module(params['module_spec'], trainable=is_training and params['train_module'])    bottleneck_tensor = module(features['image'])    with tf.name_scope('final_retrain_ops'):        logits = tf.layers.dense(bottleneck_tensor, units=num_classes, trainable=is_training)  # save this?    def train_op_fn(loss):        optimizer = tf.train.AdamOptimizer()        return optimizer.minimize(loss, global_step=tf.train.get_global_step())    head = tf.contrib.estimator.multi_class_head(n_classes=num_classes, label_vocabulary=params['label_vocab'])    return head.create_estimator_spec(        features, mode, logits, labels, train_op_fn=train_op_fn    )
查看完整描述

2 回答

?
慕仙森

TA贡献1827条经验 获得超8个赞

TF 无法在您编码时工作。你应该:

  1. 从原始网络导出瓶颈到文件。

  2. 使用瓶颈结果作为输入,使用另一个网络来训练您的数据。


查看完整回答
反对 回复 2021-06-06
?
守着一只汪

TA贡献1872条经验 获得超3个赞

这样的事情应该工作(未经测试):


# Serialize the data into two tfrecord files

tf.enable_eager_execution()

feature_extractor = ...

features_file = tf.python_io.TFRecordWriter('features.tfrec')

label_file = tf.python_io.TFRecordWriter('labels.tfrec')


for images, labels in dataset:

  features = feature_extractor(images)

  features_file.write(tf.serialize_tensor(features))

  label_file.write(tf.serialize_tensor(labels))

# Parse the files and zip them together

def parse(type, shape):

  _def parse(x):

    result = tf.parse_tensor(x, out_type=shape)

    result = tf.reshape(result, FEATURE_SHAPE)

    return result

  return parse


features_ds = tf.data.TFRecordDataset('features.tfrec')

features_ds = features_ds.map(parse(tf.float32, FEATURE_SHAPE), num_parallel_calls=AUTOTUNE)


labels_ds = tf.data.TFRecordDataset('labels.tfrec')

labels_ds = labels_ds.map(parse(tf.float32, FEATURE_SHAPE), num_parallel_calls=AUTOTUNE)


ds = tf.data.Dataset.zip(features_ds, labels_ds)

ds = ds.unbatch().shuffle().repeat().batch().prefetch()...

您也可以使用 来完成它Dataset.cache,但我不是 100% 确定细节。


查看完整回答
反对 回复 2021-06-06
  • 2 回答
  • 0 关注
  • 169 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信