我正在使用 cupy 通过 pytorch 运行 cuda 代码。我的环境是ubuntu 20,anaconda-python 3.7.6,nvidia-driver 440,cuda 10.2,cupy-cuda102,torch 1.4.0首先我写了一个简单的主要代码import data_load_testfrom tqdm import tqdmimport torchfrom torch.utils.data import DataLoaderdef main(): dataset = data_load_test.DataLoadTest() training_loader = DataLoader(dataset, batch_size=1) with torch.cuda.device(0): pbar = tqdm(training_loader) for epoch in range(3): for i, img in enumerate(pbar): print("see the message")if __name__ == "__main__": main()和这样的数据加载器。from torch.utils.data import Datasetimport cv2import cupy as cpdef read_cuda_file(cuda_path): f = open(cuda_path, 'r') source_line = "" while True: line = f.readline() if not line: break source_line = source_line + line f.close() return source_lineclass DataLoadTest(Dataset): def __init__(self): source = read_cuda_file("cuda/cuda_code.cu") cuda_source = '''{}'''.format(source) module = cp.RawModule(code=cuda_source) self.myfunc = module.get_function('myfunc') self.input = cp.asarray(cv2.imread("hi.png",-1), cp.uint8) h, w, c = self.input.shape self.h = h self.w = w self.output = cp.zeros((w, h, 3), dtype=cp.uint8) self.block_size = (32, 32) self.grid_size = (h // self.block_size[1], w // self.block_size[0]) def __len__(self): return 1 def __getitem__(self, idx): self.myfunc(self.grid_size, self.block_size, (self.input, self.output, self.h, self.w)) return cp.asnumpy(self.output)
1 回答
慕田峪7331174
TA贡献1828条经验 获得超13个赞
在 main() 中,当实例化 dataLoadTest() 类时,它发生在默认设备 0 上,因此 cuPy 在那里编译 myFunc() 。
下一行“with torch.cuda.device(0):”是在失败的版本中切换到设备1的位置?
如果你打电话会发生什么
cuPy.cuda.Device(1).use()
作为 main() 中的第一行,以确保 myFunc() 在设备 1 上实例化?
添加回答
举报
0/150
提交
取消