为了账号安全,请及时绑定邮箱和手机立即绑定

Pytorch张量获取具有特定值的元素的索引?

Pytorch张量获取具有特定值的元素的索引?

海绵宝宝撒 2023-10-26 16:36:28
我有两个张量,张量 a 和张量 b。我想获取张量 b 中值的所有索引。例如。a = torch.Tensor([1,2,2,3,4,4,4,5])b = torch.Tensor([1,2,4])1, 2, 4我想要张量 a的索引。我可以通过以下代码来做到这一点。a = torch.Tensor([1,2,2,3,4,4,4,5])b = torch.Tensor([1,2,4])mask = torch.zeros(a.shape).type(torch.bool)print(mask)for e in b:    mask = mask + (a == e)    print(mask)如果没有 ,我该怎么做for?
查看完整描述

2 回答

?
繁花不似锦

TA贡献1851条经验 获得超4个赞

由于 PyTorch1.10isin()isinf()以及许多其他 numpy 等效项)也可用,因此您可以简单地执行以下操作:

torch.isin(a, b)

这会给你:

Out[4]: tensor([ True,  True,  True, False,  True,  True,  True, False])

旧答案:

这是你想要的吗?:

np.in1d(a.numpy(), b.numpy())

将导致:

array([ True,  True,  True, False,  True,  True,  True, False])


查看完整回答
反对 回复 2023-10-26
?
拉风的咖菲猫

TA贡献1995条经验 获得超2个赞

如果您只是不想使用 for 循环,则可以使用列表理解:

mask = [a[index] for index in b]

如果甚至不想使用“for”一词,您可以随时将张量转换为 numpy 并使用 numpy 索引。

mask = torch.tensor(a.numpy()[b.numpy()])

更新

可能误解了你的问题。在这种情况下,我想说实现这一点的最佳方法是通过列表理解。(切片可能无法实现这一点。

mask = [index for index,value in enumerate(a) if value in b.tolist()]

这会迭代 a 中的每个元素,获取它们的索引和值,如果该值在 b 内,则获取索引。


查看完整回答
反对 回复 2023-10-26
  • 2 回答
  • 0 关注
  • 192 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信