为了账号安全,请及时绑定邮箱和手机立即绑定

将 2d 数组与 3d 数组的每个切片相乘 - Numpy

将 2d 数组与 3d 数组的每个切片相乘 - Numpy

BIG阳 2021-06-29 13:57:46
我正在寻找一种优化的方法来计算 2d 数组与 3d 数组的每个切片的元素乘法(使用 numpy)。例如:w = np.array([[1,5], [4,9], [12,15]]) y = np.ones((3,2,3))我想得到一个 3d 数组的结果,其形状与y.不允许使用 * 运算符进行广播。就我而言,第三维很长,for 循环不方便。
查看完整描述

1 回答

?
慕后森

TA贡献1802条经验 获得超5个赞

给定数组


import numpy as np


w = np.array([[1,5], [4,9], [12,15]])


print(w)


[[ 1  5]

 [ 4  9]

 [12 15]]


y = np.ones((3,2,3))


print(y)


[[[ 1.  1.  1.]

  [ 1.  1.  1.]]


 [[ 1.  1.  1.]

  [ 1.  1.  1.]]


 [[ 1.  1.  1.]

  [ 1.  1.  1.]]]

我们可以直接对数组进行乘法运算,


z = ( y.transpose() * w.transpose() ).transpose()


print(z)


[[[  1.   1.   1.]

  [  5.   5.   5.]]


 [[  4.   4.   4.]

  [  9.   9.   9.]]


 [[ 12.  12.  12.]

  [ 15.  15.  15.]]]

我们可能会注意到,这会产生与 np.einsum('ij,ijk->ijk',w,y) 相同的结果,可能需要更少的努力和开销。


查看完整回答
反对 回复 2021-07-13
  • 1 回答
  • 0 关注
  • 185 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信