我正在尝试将一个简单的代码片段从 TensorFlow 1.x 转换为 TensorFlow 2:# ########## TensorFlow 1.x code: ##########import numpy as npimport tensorflow as tfcoefficients = np.array([[1.], [-10.], [25.]])w = tf.Variable(0, dtype=tf.float32)x = tf.placeholder(tf.float32, [3, 1])cost = (x[0][0] * (w**2)) + (x[1][0]*w) + (x[2][0])train = tf.train.GradientDescentOptimizer(0.05).minimize(cost)if __name__ == '__main__': init = tf.global_variables_initializer() session = tf.Session() session.run(init) for i in range(100): session.run(train, feed_dict={x: coefficients}) print(session.run(w))大部分旧的 API 已在 TF2 中被替换(例如GradientDescentOptimizer替换为keras.optimizers.SGD),并且我能够弄清楚如何重构我的大部分代码,但我不知道如何重构tf.placeholder以及feed_dict这两者如何交互。TF2 中是否简单地避免使用占位符?
1 回答
鸿蒙传说
TA贡献1865条经验 获得超7个赞
通常,您使用@tf.function占位符并将其转换为函数参数。sess.runthen 被替换为调用该函数。过去用于返回操作的东西(比如minimize)现在只在函数内部调用。这是转换后的代码片段:
coefficients = tf.constant([[1.], [-10.], [25.]])
w = tf.Variable(0.0)
@tf.function
def train(x):
cost = (x[0][0] * (w**2)) + (x[1][0]*w) + (x[2][0])
tf.compat.v1.train.GradientDescentOptimizer(0.05).minimize(cost, var_list=[w])
for i in range(100):
train(coefficients)
print(w)
正如您所提到的,train.GradientDescentOptimizer已弃用,因此升级该部分将需要更多更改。
添加回答
举报
0/150
提交
取消