为了账号安全,请及时绑定邮箱和手机立即绑定

从现有 Torchvision 数据集创建缩减数据集

从现有 Torchvision 数据集创建缩减数据集

饮歌长啸 2021-10-26 15:32:09
我们都知道常见的 MNIST 数据集,包含在torchvision.datasets包中。想象一下,我想创建一个仅包含1和0 的数据集的简化版本,以仅对这两个数字进行分类,而不是对所有 10 个值进行分类。我已经看到可以在继承所需数据集的类中创建自定义数据集,所以__getitem__,它返回给定索引处的项目。所以我这样做了:class MNIST01(MNIST):    def __getitem__(self, idx):        image, label = super().__getitem__(idx)        if label.item() <= 1:            return image, label        else:            return None问题是,我似乎无法返回 None 值,因为它必须是“包含张量、数字、字典或列表;找到类“NoneType””。有没有一种简单的方法可以以类似的方式轻松获得此数据集的简化版本?
查看完整描述

1 回答

?
缥缈止盈

TA贡献2041条经验 获得超4个赞

我终于设法处理了 NoneType 问题。保留问题中定义的函数。


class MNIST01(MNIST):

    def __getitem__(self, idx):

        features, target = super(MNIST01, self).__getitem__(idx)

        if target.item() <= 1:

            return features, target

我们现在需要为我们的数据加载器定义一个自定义整理函数 collate_fn,它处理样本列表以形成一个批次。在这个函数中,我们可以应用过滤器来处理None值并忽略它们。


from torch.utils.data.dataloader import default_collate


def filter_collate(batch):

    batch = list(filter(lambda x: x is not None, batch))

    return default_collate(batch)

然后我们只需要将这个函数传递给DataLoader:


from torch.utils.data import DataLoader


train_loader = DataLoader(train_dataset, collate_fn=filter_collate, **kwargs)

test_loader = DataLoader(test_dataset, collate_fn=filter_collate, **kwargs)

版本 2


比第一个容易多了,避免了访问数据时的一些问题。只需从类的实例中直接过滤train_data和train_label属性(以及对应于测试集)MNIST。


train_dataset.train_data = train_dataset.train_data[train_dataset.train_labels <= 1]

train_dataset.train_labels = train_dataset.train_labels[train_dataset.train_labels <= 1]


查看完整回答
反对 回复 2021-10-26
  • 1 回答
  • 0 关注
  • 195 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号