为了账号安全,请及时绑定邮箱和手机立即绑定

计算火炬张量数组的平均值和标准差

计算火炬张量数组的平均值和标准差

莫回无 2023-11-09 22:17:31
我正在尝试计算火炬张量数组的平均值和标准差。我的数据集有 720 张训练图像,每张图像都有 4 个地标,其中 X 和 Y 代表图像上的 2D 点。to_tensor = transforms.ToTensor()landmarks_arr = []for i in range(len(train_dataset)):    landmarks_arr.append(to_tensor(train_dataset[i]['landmarks']))                     mean = torch.mean(torch.stack(landmarks_arr, dim=0))#, dim=(0, 2, 3))std = torch.std(torch.stack(landmarks_arr, dim=0)) #, dim=(0, 2, 3))print(mean.shape)print("mean is {} and std is {}".format(mean, std))结果:torch.Size([])mean is nan and std is nan上面有几个问题:为什么 to_tensor 不转换 0 和 1 之间的值?如何正确计算平均值?我应该除以 255 吗?我有:len(landmarks_arr)    720和landmarks_arr[0].shapetorch.Size([1, 4, 2])和landmarks_arr[0]tensor([[[502.2869, 240.4949],         [688.0000, 293.0000],         [346.0000, 317.0000],         [560.8283, 322.6830]]], dtype=torch.float64)
查看完整描述

1 回答

?
aluckdog

TA贡献1847条经验 获得超7个赞

  1. 来自 ToTensor() 的 pytorch 文档:

如果 PIL 图像属于,则将 [0, 255] 范围内的 PIL 图像或 numpy.ndarray (H x W x C) 转换为 [0.0, 1.0] 范围内形状 (C x H x W) 的 torch.FloatTensor模式之一(L、LA、P、I、F、RGB、YCbCr、RGBA、CMYK、1)或者 numpy.ndarray 的 dtype = np.uint8

在其他情况下,返回张量而不进行缩放。

由于您的 Landmark 值不是 PIL 图像,并且不在 [0, 255] 范围内,因此不会应用缩放。

  1. 您的计算看起来是正确的。看起来,您的数据中可能有一些 NaN 值。

你可以尝试类似的东西

for i in range(len(train_dataset)):
    landmarks = to_tensor(train_dataset[i]['landmarks'])
    landmarks[landmarks != landmarks] = 0  # this will set all nan to zero
    landmarks_arr.append(landmarks)

在你的循环内。或者在循环中断言 for nan 以找到罪魁祸首:

for i in range(len(train_dataset)):
    landmarks = to_tensor(train_dataset[i]['landmarks'])    assert(not torch.isnan(landmarks).any()), f'nan encountered in sample {i}'  # will trigger if a landmark contains nan
    landmarks_arr.append(landmarks)
  1. 不,请参见 1)。如果您愿意,您可以除以地标的最大坐标,将它们限制为 [0, 1]。

  2. https://img1.sycdn.imooc.com/654cea2700015a5206330418.jpg

查看完整回答
反对 回复 2023-11-09
  • 1 回答
  • 0 关注
  • 123 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信