我有两个张量,张量 a 和张量 b。我想获取张量 b 中值的所有索引。例如。a = torch.Tensor([1,2,2,3,4,4,4,5])b = torch.Tensor([1,2,4])1, 2, 4我想要张量 a的索引。我可以通过以下代码来做到这一点。a = torch.Tensor([1,2,2,3,4,4,4,5])b = torch.Tensor([1,2,4])mask = torch.zeros(a.shape).type(torch.bool)print(mask)for e in b: mask = mask + (a == e) print(mask)如果没有 ,我该怎么做for?
2 回答
繁花不似锦
由于 PyTorch
TA贡献1851条经验 获得超4个赞
由于 PyTorch1.10
和isin()
(isinf()
以及许多其他 numpy 等效项)也可用,因此您可以简单地执行以下操作:
torch.isin(a, b)
这会给你:
Out[4]: tensor([ True, True, True, False, True, True, True, False])
旧答案:
这是你想要的吗?:
np.in1d(a.numpy(), b.numpy())
将导致:
array([ True, True, True, False, True, True, True, False])
拉风的咖菲猫
TA贡献1995条经验 获得超2个赞
如果您只是不想使用 for 循环,则可以使用列表理解:
mask = [a[index] for index in b]
如果甚至不想使用“for”一词,您可以随时将张量转换为 numpy 并使用 numpy 索引。
mask = torch.tensor(a.numpy()[b.numpy()])
更新
可能误解了你的问题。在这种情况下,我想说实现这一点的最佳方法是通过列表理解。(切片可能无法实现这一点。
mask = [index for index,value in enumerate(a) if value in b.tolist()]
这会迭代 a 中的每个元素,获取它们的索引和值,如果该值在 b 内,则获取索引。
添加回答
举报
0/150
提交
取消