例如,代码是input = torch.randn(3, 10)result = torch.argmax(input, dim=0, keepdim=True)input 是tensor([[ 1.5742, 0.8183, -2.3005, -1.1650, -0.2451], [ 1.0553, 0.6021, -0.4938, -1.5379, -1.2054], [-0.1728, 0.8372, -1.9181, -0.9110, 0.2422]])并且result是tensor([[ 0, 2, 1, 2, 2]])但是,我想要这样的结果tensor([[ 1, 0, 0, 0, 0], [ 0, 0, 1, 0, 0], [ 0, 1, 0, 1, 1]])
添加回答
举报
0/150
提交
取消