我正在尝试用TensorFlow 2实现一个注意力网络。因此,对于每个图像,我只想瞥见一些,即图像的一小部分。为此,我从tensorflow.keras.models.Model中实现了一个子类,这里有一个片段。class RecurrentAttentionModel(models.Model):# ...def call(self, inputs): l = tf.random.uniform((40,2,), minval=0, maxval=1) for _ in range(0, self.glimpses): glimpse = tf.image.extract_glimpse(inputs, size=(self.retina_size, self.retina_size), offsets=l, centered=False, normalized=True) # some other code... # update l to take a glimpse somewhere else return result 现在,上面的代码可以完美地工作和训练,但我的问题是,我有硬编码的40,这是我在数据集中定义的batch_size。我无法在调用方法中读取/获取batch_size,因为变量“inputs”的形式是batch_size似乎是预期行为。当我只用下面的代码初始化l(没有batch_size)Tensor("input_1_77:0", shape=(None, 250, 500, 1), dtype=float32)Nonel = tf.random.uniform((2,), minval=0, maxval=1)它抛出此错误ValueError: Shape must be rank 2 but is rank 1 for 'recurrent_attention_model_86/ExtractGlimpse' (op: 'ExtractGlimpse') with input shapes: [?,250,500,1], [2], [2]我完全理解,但我不知道如何根据batch_size实现初始值。
1 回答
守着一只汪
TA贡献1872条经验 获得超3个赞
您可以使用 动态提取批大小维度。tf.shape
l = tf.random.normal(tf.stack([tf.shape(inputs)[0], 2]), minval=0, maxval=1))
添加回答
举报
0/150
提交
取消