为了账号安全,请及时绑定邮箱和手机立即绑定

类型错误:“numpy.int64”类型的对象没有 len()

类型错误:“numpy.int64”类型的对象没有 len()

蛊毒传说 2021-09-24 15:43:55
我正在制作一个DataLoaderfrom DataSetin PyTorch。从加载DataFrame所有dtype作为一个np.float64result = pd.read_csv('dummy.csv', header=0, dtype=DTYPE_CLEANED_DF)这是我的数据集类。from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset):    def __init__(self, result):        headers = list(result)        headers.remove('classes')        self.x_data = result[headers]        self.y_data = result['classes']        self.len = self.x_data.shape[0]    def __getitem__(self, index):        x = torch.tensor(self.x_data.iloc[index].values, dtype=torch.float)        y = torch.tensor(self.y_data.iloc[index], dtype=torch.float)        return (x, y)    def __len__(self):        return self.len准备 train_loader and test_loadertrain_size = int(0.5 * len(full_dataset))test_size = len(full_dataset) - train_sizetrain_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True, num_workers=1)test_loader = DataLoader(dataset=train_dataset)这是我的csv 文件如何解决pandas这里的问题?
查看完整描述

3 回答

?
守着星空守着你

TA贡献1799条经验 获得超8个赞

我总共有 2298 张图片。所以如果我按照以下方式做

[int(len(data)*0.8),int(len(data)*0.2)]

它抛出有问题的错误。作为

[int(len(data)*0.8)+int(len(data)*0.2)]=2297

所以我做的是floorceil功能

[int(np.floor(len(data)*0.8)),int(np.ceil(len(data)*0.2))])

结果是 2298 并且错误消失了


查看完整回答
反对 回复 2021-09-24
  • 3 回答
  • 0 关注
  • 406 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信