我正在制作一个DataLoaderfrom DataSetin PyTorch。从加载DataFrame所有dtype作为一个np.float64result = pd.read_csv('dummy.csv', header=0, dtype=DTYPE_CLEANED_DF)这是我的数据集类。from torch.utils.data import Dataset, DataLoaderclass MyDataset(Dataset): def __init__(self, result): headers = list(result) headers.remove('classes') self.x_data = result[headers] self.y_data = result['classes'] self.len = self.x_data.shape[0] def __getitem__(self, index): x = torch.tensor(self.x_data.iloc[index].values, dtype=torch.float) y = torch.tensor(self.y_data.iloc[index], dtype=torch.float) return (x, y) def __len__(self): return self.len准备 train_loader and test_loadertrain_size = int(0.5 * len(full_dataset))test_size = len(full_dataset) - train_sizetrain_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True, num_workers=1)test_loader = DataLoader(dataset=train_dataset)这是我的csv 文件如何解决pandas这里的问题?
3 回答
守着星空守着你
TA贡献1799条经验 获得超8个赞
我总共有 2298 张图片。所以如果我按照以下方式做
[int(len(data)*0.8),int(len(data)*0.2)]
它抛出有问题的错误。作为
[int(len(data)*0.8)+int(len(data)*0.2)]=2297
所以我做的是floor
和ceil
功能
[int(np.floor(len(data)*0.8)),int(np.ceil(len(data)*0.2))])
结果是 2298 并且错误消失了
添加回答
举报
0/150
提交
取消