为了账号安全,请及时绑定邮箱和手机立即绑定

为什么我不能在 PyTorch 张量后面附加 torch.cat?

为什么我不能在 PyTorch 张量后面附加 torch.cat?

慕娘9325324 2022-09-13 19:51:18
我有:import torchinput_sliced = torch.rand(180, 161)output_sliced = torch.rand(180,)batched_inputs = torch.Tensor()batched_outputs = torch.Tensor()print('input_sliced.size', input_sliced.size())print('output_sliced.size', output_sliced.size())batched_inputs = torch.cat((batched_inputs, input_sliced))batched_outputs = torch.cat((batched_outputs, output_sliced))print('batched_inputs.size', batched_inputs.size())print('batched_outputs.size', batched_outputs.size())此输出:input_sliced.size torch.Size([180, 161])output_sliced.size torch.Size([180])batched_inputs.size torch.Size([180, 161])batched_outputs.size torch.Size([180])我需要附加那些,但不起作用。我做错了什么?batchedtorch.cat
查看完整描述

1 回答

?
Helenr

TA贡献1780条经验 获得超4个赞

假设你在循环中这样做,我会说最好这样做:


import torch


batch_input, batch_output = [], []

for i in range(10):  # assuming batch_size=10

    batch_input.append(torch.rand(180, 161))

    batch_output.append(torch.rand(180,))


batch_input = torch.stack(batch_input)

batch_output = torch.stack(batch_output)


print(batch_input.shape)   # output: torch.Size([10, 180, 161])

print(batch_output.shape)  # output: torch.Size([10, 180])

如果您先验地知道结果形状,则可以预先分配最终形状,只需将每个样品分配到批次中的相应位置即可。这将更节省内存。batch_*Tensor


查看完整回答
反对 回复 2022-09-13
  • 1 回答
  • 0 关注
  • 83 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号