为了账号安全,请及时绑定邮箱和手机立即绑定

Python 中“矩阵乘法”的更快定义

Python 中“矩阵乘法”的更快定义

慕尼黑的夜晚无繁华 2023-02-15 15:23:19
我需要从头开始定义矩阵乘法,而不是将每个常数相乘,每个常数实际上是另一个数组,任何两个数组都需要“卷积”在一起(我认为没有必要在这里定义卷积) .我制作了一张图片,希望能更好地解释我想说的话:我必须使用的代码是这样的:for row in range(arr1.shape[2]):    for column in range(arr2.shape[3]):        for index in range(arr2.shape[2]): # Could also be "arr1.shape[3]"            out[:, :, row, column] += convolve(                arr2[:, :, :  , column][:, :, index],                arr1[:, :, row, :     ][:, :, index]            )然而,这种方法对我来说非常慢,所以我想知道是否有更快的方法来做到这一点。
查看完整描述

1 回答

?
慕容森

TA贡献1853条经验 获得超18个赞

如果中间适合内存,则以下内容应该相当有效


import numpy as np

from scipy.signal import fftconvolve,convolve


# example

rng = np.random.default_rng()

A = rng.random((5,6,2,3))                    

B = rng.random((4,3,3,4))                    


# custom matmul


Ae,Be = A[...,None],B[:,:,None]

shsh = np.maximum(Ae.shape[2:],Be.shape[2:])

Ae = np.broadcast_to(Ae,(*Ae.shape[:2],*shsh))

Be = np.broadcast_to(Be,(*Be.shape[:2],*shsh))

C = fftconvolve(Ae,Be,axes=(0,1),mode='valid').sum(3)         


# original loop for reference


out = np.zeros_like(C)

for row in range(A.shape[2]):

    for column in range(B.shape[3]):

        for index in range(B.shape[2]): # Could also be "A.shape[3]"

            out[:, :, row, column] += convolve(

                B[:, :, :  , column][:, :, index],

                A[:, :, row, :     ][:, :, index],

                mode='valid'

            )


print(np.allclose(C,out))


# True

通过批量进行卷积,我们减少了我们必须做的 fft 的总数。


如果需要,可以通过使用 对傅里叶空间进行总和缩减,进一步优化速度和内存einsum。不过,这需要手动进行 fft 卷积。


查看完整回答
反对 回复 2023-02-15
  • 1 回答
  • 0 关注
  • 99 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信