为了账号安全,请及时绑定邮箱和手机立即绑定

调用函数中的变量batch_size

调用函数中的变量batch_size

富国沪深 2022-09-06 21:27:04
我正在尝试用TensorFlow 2实现一个注意力网络。因此,对于每个图像,我只想瞥见一些,即图像的一小部分。为此,我从tensorflow.keras.models.Model中实现了一个子类,这里有一个片段。class RecurrentAttentionModel(models.Model):# ...def call(self, inputs):    l = tf.random.uniform((40,2,), minval=0, maxval=1)    for _ in range(0, self.glimpses):        glimpse = tf.image.extract_glimpse(inputs, size=(self.retina_size, self.retina_size), offsets=l, centered=False, normalized=True)        # some other code...        # update l to take a glimpse somewhere else    return result           现在,上面的代码可以完美地工作和训练,但我的问题是,我有硬编码的40,这是我在数据集中定义的batch_size。我无法在调用方法中读取/获取batch_size,因为变量“inputs”的形式是batch_size似乎是预期行为。当我只用下面的代码初始化l(没有batch_size)Tensor("input_1_77:0", shape=(None, 250, 500, 1), dtype=float32)Nonel = tf.random.uniform((2,), minval=0, maxval=1)它抛出此错误ValueError: Shape must be rank 2 but is rank 1 for 'recurrent_attention_model_86/ExtractGlimpse' (op: 'ExtractGlimpse') with input shapes: [?,250,500,1], [2], [2]我完全理解,但我不知道如何根据batch_size实现初始值。
查看完整描述

1 回答

?
守着一只汪

TA贡献1872条经验 获得超3个赞

您可以使用 动态提取批大小维度。tf.shape

l = tf.random.normal(tf.stack([tf.shape(inputs)[0], 2]), minval=0, maxval=1))


查看完整回答
反对 回复 2022-09-06
  • 1 回答
  • 0 关注
  • 114 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信