为了账号安全,请及时绑定邮箱和手机立即绑定

在 Tensorflow/Keras 中为重复元素创建掩码

在 Tensorflow/Keras 中为重复元素创建掩码

小唯快跑啊 2023-12-05 15:36:01
我正在尝试为人员重新识别任务编写一个自定义损失函数,该函数在多任务学习设置和对象检测中进行训练。过滤后的标签值的形状为(batch_size, num_boxes)。我想创建一个掩码,以便仅考虑在暗淡 1 中重复的值进行进一步计算。如何在 TF/Keras 后端执行此操作?简短示例:Input labels = [[0,0,0,0,12,12,3,3,4], [0,0,10,10,10,12,3,3,4]] Required output: [[0,0,0,0,1,1,1,1,0],[0,0,1,1,1,0,1,1,0]](基本上我只想过滤掉重复项并丢弃损失函数的唯一标识)。我想可以使用 tf.unique 和 tf.scatter 的组合,但我不知道如何使用。
查看完整描述

1 回答

?
森林海

TA贡献2011条经验 获得超2个赞

这段代码的工作原理:


x = tf.constant([[0,0,0,0,12,12,3,3,4], [0,0,10,10,10,12,3,3,4]])

def mark_duplicates_1D(x):

  y, idx, count = tf.unique_with_counts(x)

  comp = tf.math.greater(count, 1)

  comp = tf.cast(comp, tf.int32)

  res = tf.gather(comp, idx)

  mult = tf.math.not_equal(x, 0)

  mult = tf.cast(mult, tf.int32)

  res *= mult

  return res

res = tf.map_fn(fn=mark_duplicates_1D, elems=x)


查看完整回答
反对 回复 2023-12-05
  • 1 回答
  • 0 关注
  • 106 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信