我正在尝试为人员重新识别任务编写一个自定义损失函数,该函数在多任务学习设置和对象检测中进行训练。过滤后的标签值的形状为(batch_size, num_boxes)。我想创建一个掩码,以便仅考虑在暗淡 1 中重复的值进行进一步计算。如何在 TF/Keras 后端执行此操作?简短示例:Input labels = [[0,0,0,0,12,12,3,3,4], [0,0,10,10,10,12,3,3,4]]
Required output: [[0,0,0,0,1,1,1,1,0],[0,0,1,1,1,0,1,1,0]](基本上我只想过滤掉重复项并丢弃损失函数的唯一标识)。我想可以使用 tf.unique 和 tf.scatter 的组合,但我不知道如何使用。
1 回答
森林海
TA贡献2011条经验 获得超2个赞
这段代码的工作原理:
x = tf.constant([[0,0,0,0,12,12,3,3,4], [0,0,10,10,10,12,3,3,4]])
def mark_duplicates_1D(x):
y, idx, count = tf.unique_with_counts(x)
comp = tf.math.greater(count, 1)
comp = tf.cast(comp, tf.int32)
res = tf.gather(comp, idx)
mult = tf.math.not_equal(x, 0)
mult = tf.cast(mult, tf.int32)
res *= mult
return res
res = tf.map_fn(fn=mark_duplicates_1D, elems=x)
添加回答
举报
0/150
提交
取消