我只想在自定义层中进行一些数值验证。假设我们有一个非常简单的自定义层:class test_layer(keras.layers.Layer): def __init__(self, **kwargs): super(test_layer, self).__init__(**kwargs) def build(self, input_shape): self.w = K.variable(1.) self._trainable_weights.append(self.w) super(test_layer, self).build(input_shape) def call(self, x, **kwargs): m = x * x # Set break point here n = self.w * K.sqrt(x) return m + n和主程序:import tensorflow as tfimport kerasimport keras.backend as Kinput = keras.layers.Input((100,1))y = test_layer()(input)model = keras.Model(input,y)model.predict(np.ones((100,1)))如果我在该行上设置了断点调试,则m = x * x执行时程序将在此处暂停y = test_layer()(input),这是因为生成了图形,因此call()调用了该方法。但是当我使用model.predict()它来赋予它真正的价值,并且想在图层内部查看它是否工作正常时,它并不会停在那一行m = x * x我的问题是:是call()计算图形正在兴建时,方法只叫什么名字?(提供实际价值时不会调用它吗?)给实数输入时,如何调试(或在何处插入断点)以查看变量的值?
2 回答

holdtom
TA贡献1805条经验 获得超10个赞
是的。该
call()
方法仅用于构建计算图。至于调试。我更喜欢使用
TFDBG
,这是针对tensorflow的推荐调试工具,尽管它不提供断点功能。
对于Keras,您可以将以下行添加到脚本中以使用TFDBG
import tf.keras.backend as K
from tensorflow.python import debug as tf_debug
sess = K.get_session()
sess = tf_debug.LocalCLIDebugWrapperSession(sess)
K.set_session(sess)
添加回答
举报
0/150
提交
取消