1 回答
TA贡献1784条经验 获得超9个赞
我试图通过仅更改 model.fit() 操作中代码的最后一部分来复制和解析您在 TF 2.6 版本中的代码。这是编辑后的代码:
import tensorflow as tf
def dummy_image_float(w,h):
return tf.constant([0.,]*(h*w*3), shape=[1,w,h,3], dtype=tf.float32)
def dummy_result(w,h,nfeature):
return tf.constant([0,]*(h*w*nfeature), shape=[1,w,h,nfeature], dtype=tf.float32)
model = tf.keras.applications.ResNet101V2(
include_top=False,
#input_tensor=x1,
weights='imagenet',
input_shape=(224, 224, 3),
pooling=None
)
model.compile(optimizer='adam', loss="mean_squared_error", run_eagerly=True)
#train_ds = [ (dummy_image_float(224,224), dummy_result(7,7,2048)) ]
model.fit(dummy_image_float(224,224), dummy_result(7,7,2048), epochs=2)
添加回答
举报