tf.function我正在尝试使用贪婪解码方法保存模型。该代码经过测试并按预期在急切模式(调试)下工作。但是,它在非急切执行中不起作用。该方法被调用namedtuple,Hyp如下所示:Hyp = namedtuple( 'Hyp', field_names='score, yseq, encoder_state, decoder_state, decoder_output')while 循环的调用方式如下:_, hyp = tf.while_loop( cond=condition_, body=body_, loop_vars=(tf.constant(0, dtype=tf.int32), hyp), shape_invariants=( tf.TensorShape([]), tf.nest.map_structure(get_shape_invariants, hyp), ))这是以下的相关部分body_:def body_(i_, hypothesis_: Hyp): # [:] Collapsed some code .. def update_from_next_id_(): return Hyp( # Update values .. ) # The only place where I generate a new hypothesis_ namedtuple hypothesis_ = tf.cond( tf.not_equal(next_id, blank), true_fn=lambda: update_from_next_id_(), false_fn=lambda: hypothesis_ ) return i_ + 1, hypothesis_我得到的是ValueError:ValueError: Input tensor 'hypotheses:0' enters the loop with shape (), but has shape <unknown> after one iteration. To allow the shape to vary across iterations, use the 形状不变量 argument of tf.while_loop to specify a less-specific shape.这里可能有什么问题?以下是如何input_signature定义tf.function我想序列化的。这self.greedy_decode_impl是实际的实现 - 我知道这有点难看,但这self.greedy_decode就是我所说的。
1 回答
MM们
TA贡献1886条经验 获得超2个赞
好吧,事实证明
tf.concat([hypothesis_.yseq, next_id], axis=0),
本来应该是
tf.concat([hypothesis_.yseq, next_id], axis=-1),
公平地说,错误消息有点提示您在哪里查看,但“有帮助”不足以描述它。我TensorSpec
通过连接错误的轴来违反了,仅此而已,但 Tensorflow 还无法直接指向受影响的张量。
添加回答
举报
0/150
提交
取消