我在 python/Keras 中创建了一个数据生成器,以在 batchesize=5 中提取文件名和标签。每次迭代都会获得相同的文件名和标签。我希望每次迭代都能获得新的(成功的)文件名和标签。我查看了许多示例并阅读了文档,但无法弄清楚。def datagenerator(imgfns, imglabels, batchsize, mode="train"): while True: images = [] labels = [] cnt=0 while len(images) < batchsize: images.append(imgfns[cnt]) labels.append(imglabels[cnt]) cnt=cnt+1 #for ii in range(batchsize): # #img = np.load(imgfns[ii]) # #images.append(img) # images.append(imgfns[ii]) # labels.append(imglabels[ii]) #for image, label in zip(imgfns, imglabels): # #img = np.load(image) # #images.append(img) # images.append(image) # labels.append(label) print(images) print(labels) print('********** cnt = ', cnt) yield images, labelstrain_gen = datagenerator(train_uxo_scrap, train_uxo_scrap_labels, batchsize=BS)valid_gen = datagenerator(test_uxo_scrap, test_uxo_scrap_labels, batchsize=BS)# train the networkH = model.fit_generator( train_gen, steps_per_epoch=NUM_TRAIN_IMAGES // BS, validation_data=valid_gen, validation_steps=NUM_TEST_IMAGES // BS, epochs=NUM_EPOCHS)这是我得到的输出示例。您可以看到,每次它通过生成器时,它都会获取相同的数据。“Epoch 1/10”之后的第一行,有5个文件名。下一行有5个标签(对应batchsize=5)。例如,您可以在每个输出中看到第一个文件名是“... 508.npy”等。并且每次迭代的标签都相同。
1 回答
牛魔王的故事
TA贡献1830条经验 获得超3个赞
问题是您正在设置cnt=0每次迭代。您获取 5 个文件名,生成它们,然后重复相同的事情,因此您总是获取前 5 个。您想要更改
def datagenerator(imgfns, imglabels, batchsize, mode="train"):
while True:
images = []
labels = []
cnt=0
到
def datagenerator(imgfns, imglabels, batchsize, mode="train"):
cnt=0
while True:
images = []
labels = []
您还需要确保cnt保持在列表的限制范围内。所以像
while len(images) < batchsize and cnt < len(imgfns):
# blah
添加回答
举报
0/150
提交
取消