让我们称呼我正在寻找的函数“ magic_combine”,它可以组合我赋予它的张量的连续尺寸。更具体地说,我希望它执行以下操作:a = torch.zeros(1, 2, 3, 4, 5, 6) b = a.magic_combine(2, 5) # combine dimension 2, 3, 4 print(b.size()) # should be (1, 2, 60, 6)我知道torch.view()可以做类似的事情。但我只是想知道是否还有其他更优雅的方法可以达成目标?
2 回答

犯罪嫌疑人X
TA贡献2080条经验 获得超4个赞
我不确定“更优雅的方式”是什么想法,但是Tensor.view()优点是不为视图重新分配数据(原始张量和视图共享相同的数据),从而使此操作轻巧。
如@UmangGupta所述,但是包装此函数以实现您想要的内容很简单,例如:
import torch
def magic_combine(x, dim_begin, dim_end):
combined_shape = list(x.shape[:dim_begin]) + [-1] + list(x.shape[dim_end:])
return x.view(combined_shape)
a = torch.zeros(1, 2, 3, 4, 5, 6)
b = magic_combine(a, 2, 5) # combine dimension 2, 3, 4
print(b.size())
# torch.Size([1, 2, 60, 6])
添加回答
举报
0/150
提交
取消