我是 Tensorflow 的新手。我有一个图像数据集,其中一张图像有多个标签。据我了解,我需要使用tf.losses.sigmoid_cross_entropy(). 我尝试应用于tf.one_hot标签,但是当我尝试将它们传递给损失函数时,我得到错误,形状不兼容。我怎样才能解决这个问题?
1 回答
HUWWW
TA贡献1874条经验 获得超12个赞
你是对的tf.losses.sigmoid_cross_entropy
。所有你需要做的就是 wrap tf.one_hot
withtf.reduce_max
来减少这样的维度。
tf.reduce_max(tf.one_hot(labels, num_classes, dtype=tf.int32), axis=0)
那应该返回 shape 的张量(num_classes,)
,正是您的损失函数所需要的。
添加回答
举报
0/150
提交
取消