我需要一个函数来获取 numpy 数组中沿轴的最后一个元素。例如,如果我有一个数组,a = np.array([1, 2, 3])该功能应该像get_last_elements(a, axis=0)>>> [3]get_last_elements(a, axis=1)>>> [1, 2, 3]此函数也需要适用于多维数组:b = np.array([[1, 2], [3, 4]])get_last_elements(b, axis=0)>>> [[2], [4]]get_last_elements(b, axis=1)>>> [3, 4]有没有人有实现它的好主意?
1 回答
Helenr
TA贡献1780条经验 获得超3个赞
您可以使用np.take它:
def get_last_elements(a, axis=0):
shape = list(a.shape)
shape[axis] = 1
return np.take(a,-1,axis=axis).reshape(tuple(shape))
输出:
print(get_last_elements(b, axis=0))
[[3 4]]
print(get_last_elements(b, axis=1))
[[2]
[4]]
添加回答
举报
0/150
提交
取消