1 回答
TA贡献1815条经验 获得超10个赞
使用很容易take_along_axis:
>>> np.take_along_axis(x, ind[:, None, :], 2)
array([[[97, 57, 28, 59, 24],
[67, 77, 94, 50, 97],
[89, 55, 25, 76, 56],
[21, 50, 1, 24, 88]],
[[54, 83, 81, 64, 12],
[89, 49, 26, 15, 97],
[94, 97, 55, 32, 79],
[24, 63, 15, 63, 40]],
[[64, 21, 84, 41, 99],
[43, 28, 85, 12, 9],
[10, 0, 48, 75, 98],
[22, 63, 37, 93, 94]]])
如果您使用的是 1.15 之前的 numpy,则可以执行以下操作:
>>> m,n,k = x.shape
>>> m,n,k = np.ogrid[:m, :n, :k]
>>> x[m,n,ind[:, None, :]]
array([[[97, 57, 28, 59, 24],
[67, 77, 94, 50, 97],
[89, 55, 25, 76, 56],
[21, 50, 1, 24, 88]],
[[54, 83, 81, 64, 12],
[89, 49, 26, 15, 97],
[94, 97, 55, 32, 79],
[24, 63, 15, 63, 40]],
[[64, 21, 84, 41, 99],
[43, 28, 85, 12, 9],
[10, 0, 48, 75, 98],
[22, 63, 37, 93, 94]]])
添加回答
举报