为了账号安全,请及时绑定邮箱和手机立即绑定

使用 estimator api 避免 tf.data.Dataset.from_tenso

使用 estimator api 避免 tf.data.Dataset.from_tenso

忽然笑 2021-06-10 14:01:07
我正在尝试找出将datasetapi 与 api 一起使用的推荐方法estimator。我在网上看到的一切都是这个的一些变体:def train_input_fn():   dataset = tf.data.Dataset.from_tensor_slices((features, labels))   return dataset然后可以将其传递给估算器的 train 函数: classifier.train(    input_fn=train_input_fn,    #... )但数据集指南警告说:上面的代码片段会将特征和标签数组作为 tf.constant() 操作嵌入到您的 TensorFlow 图中。这适用于小数据集,但会浪费内存——因为数组的内容将被多次复制——并且可能会遇到 tf.GraphDef 协议缓冲区的 2GB 限制。然后描述一种方法,该方法涉及定义占位符,然后用 填充feed_dict:features_placeholder = tf.placeholder(features.dtype, features.shape)labels_placeholder = tf.placeholder(labels.dtype, labels.shape)dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))sess.run(iterator.initializer, feed_dict={features_placeholder: features,                                          labels_placeholder: labels})但是,如果您使用的是estimatorapi,则不会手动运行会话。那么如何将datasetapi 与 estimators 一起使用,同时避免与 相关的问题from_tensor_slices()?
查看完整描述

1 回答

?
万千封印

TA贡献1891条经验 获得超3个赞

要使用可初始化或可重新初始化的迭代器,您必须创建一个继承自 tf.train.SessionRunHook 的类,该类可以在训练和评估步骤中多次访问会话。


然后,您可以使用这个新类来初始化迭代器,您通常会在经典设置中执行此操作。您只需要将这个新创建的钩子传递给训练/评估函数或正确的训练规范。


这是您可以适应您的需求的快速示例:


class IteratorInitializerHook(tf.train.SessionRunHook):

    def __init__(self):

        super(IteratorInitializerHook, self).__init__()

        self.iterator_initializer_func = None # Will be set in the input_fn


    def after_create_session(self, session, coord):

        # Initialize the iterator with the data feed_dict

        self.iterator_initializer_func(session) 



def get_inputs(X, y):

    iterator_initializer_hook = IteratorInitializerHook()


    def input_fn():

        X_pl = tf.placeholder(X.dtype, X.shape)

        y_pl = tf.placeholder(y.dtype, y.shape)


        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))

        dataset = ...

        ...


        iterator = dataset.make_initializable_iterator()

        next_example, next_label = iterator.get_next()



        iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(iterator.initializer,

                                                                                    feed_dict={X_pl: X, y_pl: y})


        return next_example, next_label


    return input_fn, iterator_initializer_hook


...


train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)

test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)


...


estimator.train(input_fn=train_input_fn,

                hooks=[train_iterator_initializer_hook]) # Don't forget to pass the hook !

estimator.evaluate(input_fn=test_input_fn,

                   hooks=[test_iterator_initializer_hook])



查看完整回答
反对 回复 2021-06-15
  • 1 回答
  • 0 关注
  • 149 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号