为了账号安全,请及时绑定邮箱和手机立即绑定

Tensorflow Keras:评估时如何在自定义层中设置断点(调试)?

Tensorflow Keras:评估时如何在自定义层中设置断点(调试)?

鸿蒙传说 2021-05-13 18:13:25
我只想在自定义层中进行一些数值验证。假设我们有一个非常简单的自定义层:class test_layer(keras.layers.Layer):    def __init__(self, **kwargs):        super(test_layer, self).__init__(**kwargs)    def build(self, input_shape):        self.w = K.variable(1.)        self._trainable_weights.append(self.w)        super(test_layer, self).build(input_shape)    def call(self, x, **kwargs):        m = x * x            # Set break point here        n = self.w * K.sqrt(x)        return m + n和主程序:import tensorflow as tfimport kerasimport keras.backend as Kinput = keras.layers.Input((100,1))y = test_layer()(input)model = keras.Model(input,y)model.predict(np.ones((100,1)))如果我在该行上设置了断点调试,则m = x * x执行时程序将在此处暂停y = test_layer()(input),这是因为生成了图形,因此call()调用了该方法。但是当我使用model.predict()它来赋予它真正的价值,并且想在图层内部查看它是否工作正常时,它并不会停在那一行m = x * x我的问题是:是call()计算图形正在兴建时,方法只叫什么名字?(提供实际价值时不会调用它吗?)给实数输入时,如何调试(或在何处插入断点)以查看变量的值?
查看完整描述

2 回答

?
holdtom

TA贡献1805条经验 获得超10个赞

  1. 是的。该call()方法仅用于构建计算图。

  2. 至于调试。我更喜欢使用TFDBG,这是针对tensorflow的推荐调试工具,尽管它不提供断点功能。

对于Keras,您可以将以下行添加到脚本中以使用TFDBG

import tf.keras.backend as K

from tensorflow.python import debug as tf_debug

sess = K.get_session()

sess = tf_debug.LocalCLIDebugWrapperSession(sess)

K.set_session(sess)


查看完整回答
反对 回复 2021-05-18
  • 2 回答
  • 0 关注
  • 382 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号