为了账号安全,请及时绑定邮箱和手机立即绑定

在 MATLAB 上改进 Numpy 中复数矩阵求和元素的更好方法

在 MATLAB 上改进 Numpy 中复数矩阵求和元素的更好方法

哔哔one 2021-06-25 18:44:14
我试图将下面的 MATLAB 代码重写为 Python,发现我的 Python 代码(2.7 秒)比 MATLAB(1.2 秒)慢。我尝试了许多不同的方法,包括模块 numba,但还没有运气。如何使 Python 代码更快?MATLAB 代码:szA=[1024,1280]; HfszA=[512,640];[aPx,aPy]=meshgrid(-HfszA(2):HfszA(2)-1,-HfszA(1):HfszA(1)-1);img=randi(255,1024,1280);fx=rand(); fy=rand();ticfor i=1:20    F=abs(sum(sum(img.*exp(-1i*2*pi*(fx*aPx+fy*aPy)))));endtoc蟒蛇代码:import numpy as npimport timeszA=[1024,1280]; HfszA=[512,640]aPx,aPy=np.meshgrid(np.arange(-HfszA[1],HfszA[1]),np.arange(-HfszA[0],HfszA[0]))img=np.array(np.random.randint(256,size=(1024,1280)))fx=np.random.rand()fy=np.random.rand()start = time.time()for i in range(20):    F=abs(np.sum(img*np.exp(-1j*2*np.pi*(fx*aPx+fy*aPy))))end = time.time()print("Elapsed (after compilation) = %s" % (end - start))print(F)
查看完整描述

2 回答

?
胡说叔叔

TA贡献1804条经验 获得超8个赞

请始终发布您迄今为止尝试过的内容。关于您的 Numba 版本,我认为您做了一些导致性能不佳的事情。


例子


import numpy as np

import numba as nb

import time


@nb.njit(fastmath=True)

def your_function(fx,fy,aPx,aPy,img):

  pi=np.pi

  sum=0.

  for i in range(aPx.shape[0]):

    for j in range(aPx.shape[1]):

      sum+=img[i,j]*np.exp(-1j*2*pi*(fx*aPx[i,j]+fy*aPy[i,j]))

  return np.abs(sum)


@nb.njit(fastmath=True,parallel=True)

def your_function_p(fx,fy,aPx,aPy,img):

  pi=np.pi

  sum=0.

  for i in nb.prange(aPx.shape[0]):

    for j in range(aPx.shape[1]):

      sum+=img[i,j]*np.exp(-1j*2*pi*(fx*aPx[i,j]+fy*aPy[i,j]))

  return np.abs(sum)


#The function gets compiled at the first call

#you may also use cache=True, which only works in single threaded code

F=your_function(fx,fy,aPx,aPy,img)

start = time.time()

for i in range(20):

    F=your_function(fx,fy,aPx,aPy,img)


end = time.time()

print("Elapsed (after compilation) = %s" % (end - start))

print(F)


F=your_function_p(fx,fy,aPx,aPy,img)

start = time.time()

for i in range(20):

    F=your_function_p(fx,fy,aPx,aPy,img)


end = time.time()

print("Elapsed (after compilation) = %s" % (end - start))

print(F)

计时(4C/8T)


your_version: 2.45s

Numba single threaded: 0.17s

Numba parallel: 0.07s


查看完整回答
反对 回复 2021-06-29
  • 2 回答
  • 0 关注
  • 255 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号