为了账号安全,请及时绑定邮箱和手机立即绑定

使用 GradientTape 计算相对于某些张量的预测梯度

使用 GradientTape 计算相对于某些张量的预测梯度

噜噜哒 2022-09-06 15:56:30
我正在尝试在TensorFlow 2.0中用GP实现WGAN。要计算梯度损失,您需要计算与输入图像相关的预测的梯度。现在,为了使它更易于处理,它不是计算相对于所有输入图像的预测梯度,而是沿着原始和假数据点的线计算插值数据点,并将其用作输入。为了实现这一点,我首先开发了一个函数,它将进行一些预测并返回相对于某些输入图像的梯度。首先,我想过这样做,但它在急切模式下不起作用。因此,我现在正试图使用.compute_gradientstf.keras.backend.gradientsGradientTape以下是我用来测试内容的代码:from tensorflow.keras import backend as Kfrom tensorflow.keras.layers import *from tensorflow.keras.models import *import tensorflow as tfimport numpy as np# Comes from Generative Deep Learning by David Fosterclass RandomWeightedAverage(tf.keras.layers.Layer):    def __init__(self, batch_size):        super().__init__()        self.batch_size = batch_size    """Provides a (random) weighted average between real and generated image samples"""    def call(self, inputs):        alpha = K.random_uniform((self.batch_size, 1, 1, 1))        return (alpha * inputs[0]) + ((1 - alpha) * inputs[1])# Dummy criticdef make_critic():    critic = Sequential()    inputShape = (28, 28, 1)    critic.add(Conv2D(32, (5, 5), padding="same", strides=(2, 2),        input_shape=inputShape))    critic.add(LeakyReLU(alpha=0.2))    critic.add(Conv2D(64, (5, 5), padding="same", strides=(2, 2)))    critic.add(LeakyReLU(alpha=0.2))    critic.add(Flatten())    critic.add(Dense(512))    critic.add(LeakyReLU(alpha=0.2))    critic.add(Dropout(0.3))    critic.add(Dense(1))    return critic# Gather dataset((X_train, _), (X_test, _)) = tf.keras.datasets.fashion_mnist.load_data()X_train = X_train.reshape(-1, 28, 28, 1)X_test = X_test.reshape(-1, 28, 28, 1)# Note that I am using test images as fake images for testing purposesinterpolated_img = RandomWeightedAverage(32)([X_train[0:32].astype("float"), X_test[32:64].astype("float")])# Compute gradients of the predictions with respect to the interpolated imagescritic = make_critic()with tf.GradientTape() as tape:    y_pred = critic(interpolated_img)gradients = tape.gradient(y_pred, interpolated_img)渐变即将成为 。我在这里错过了什么吗?None
查看完整描述

1 回答

?
开满天机

TA贡献1786条经验 获得超13个赞

相对于某些张量的预测梯度...我在这里错过了什么吗?


是的。您需要一个 :tape.watch(interpolated_img)


with tf.GradientTape() as tape:

    tape.watch(interpolated_img)

    y_pred = critic(interpolated_img)

GradientTape需要存储正向传递的中间值来计算梯度。通常,您需要渐变 WRT 变量。因此,它不会保留从张量开始的计算痕迹,可能是为了节省内存。


如果你想要一个渐变WRT一个张量,你需要明确地告诉.tape


查看完整回答
反对 回复 2022-09-06
  • 1 回答
  • 0 关注
  • 95 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信