1 回答

TA贡献1805条经验 获得超9个赞
你可以得到3维指数的最大值从max_idx。中的值max_idx是最大值沿轴 1 的索引。有六个值,因为您的其他轴是 3 和 2 (3 x 2 = 6)。您只需要了解 numpy 通过它们获取其他每个轴的索引的顺序。您首先遍历最后一个轴:
d0, d1, d2 = A.shape
a0 = [i for i in range(d0) for _ in range(d2)] # [0, 0, 1, 1, 2, 2]
a1 = max_idx.flatten() # [2, 2, 0, 2, 0, 1]
a2 = [k for _ in range(d0) for k in range(d2)] # [0, 1, 0, 1, 0, 1]
B[a0, a1, a2] = A[a0, a1, a2]
输出:
array([[[0. , 0. ],
[0. , 0. ],
[0.94485653, 0.9264881 ]],
[[0.95446736, 0. ],
[0. , 0. ],
[0. , 0.36436023]],
[[0.56911013, 0. ],
[0. , 0.96278067],
[0. , 0. ]]])
添加回答
举报