为了账号安全,请及时绑定邮箱和手机立即绑定

如何在 Tensorflow 中复制 PyTorch 的 nn.function.unfold 函数

如何在 Tensorflow 中复制 PyTorch 的 nn.function.unfold 函数

Cats萌萌 2023-12-29 15:52:59
我想用tensorflow重写pytorch的torch.nn.functional.unfold函数:#input x:[16, 1, 50, 36]x = torch.nn.functional.unfold(x, kernel_size=(5, 36), stride=3)#output x:[16, 180, 16]我尝试使用该功能tf.extract_image_patches():x = tf.extract_image_patches(x,ksizes=[1, 1,5, 98],strides=[1, 1, 3, 1], rates=[1, 1, 1, 1],padding='VALID')输入x.shape:[16,1,64,98]我得到输出x.shape:[16,1,20,490]然后我将 重塑X为[16,490,20],这正是我所期望的。但是当我输入数据时出现错误:UnimplementedError (see above for traceback): Only support ksizes across space.[[Node:hcn/ExtractImagePatches = ExtractImagePatches[T=DT_FLOAT, ksizes=[1, 1, 5, 98], padding="VALID", rates=[1, 1, 1, 1], strides=[1, 1, 3, 1], _device="/job:localhost/replica:0/task:0/device:GPU:0"](hcn/Reshape)]]我如何使用tensorflow重写pytorchtorch.nn.functional.unfold函数来更改X?
查看完整描述

1 回答

?
小怪兽爱吃肉

TA贡献1852条经验 获得超1个赞

x = tf.reshape(x, [16, 50, 36, 1])
x = tf.extract_image_patches(x, ksizes=[1, 4, 98, 1], strides=[1, 4, 1, 1], rates=[1, 1, 1, 1], padding='VALID')


查看完整回答
反对 回复 2023-12-29
  • 1 回答
  • 0 关注
  • 135 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信