为了账号安全,请及时绑定邮箱和手机立即绑定

Keras 生成器和 fit_generator,如何构建生成器以避免“函数形状”错误

Keras 生成器和 fit_generator,如何构建生成器以避免“函数形状”错误

繁星淼淼 2021-12-21 16:22:38
我正在为 Keras 构建一个生成器,以便能够加载我的数据集图像,因为它对我的 ram 来说有点大。我像这样构建了生成器:# import the necessary packagesimport tensorflowfrom tensorflow import kerasfrom keras.preprocessing.image import ImageDataGeneratorimport matplotlib.pyplot as pltfrom sklearn.preprocessing import OneHotEncoderimport numpy as npimport pandas as pdfrom tqdm import tqdm#loadingpath_to_txt = "/content/test/leafsnap-dataset/leafsnap-dataset- images_improved.txt"df = pd.read_csv(path_to_txt ,sep='\t')arr = np.array(df)#epochs and steps:NUM_TRAIN_IMAGES = 0NUM_EPOCHS = 30def image_generator(arr, bs, mode="train", aug=None):  while True:    images = []    labels = []    for row in arr:      if len(images) < bs:        img = (cv2.resize(cv2.imread("/content/test/leafsnap-dataset/" +         row[0]),(224,224)))        images.append(img)        labels.append([row[2]])        NUM_TRAIN_IMAGES += 1      else:        break  if aug is not None:    (images, labels) = next(aug.flow(np.array(images),labels,      batch_size=bs))  obj = OneHotEncoder()  values = obj.fit_transform(labels).toarray()  yield (np.array(images), labels)然后我从顺序模型中调用 fit_generator (cnn 一直工作,直到出现 OOM 错误)#create the augmentation function: aug = ImageDataGenerator(rotation_range=20, zoom_range=0.15,    width_shift_range=0.2, height_shift_range=0.2, shear_range=0.15,    horizontal_flip=True, fill_mode="nearest")#create the generator:gen = image_generator(arr, bs = 32, mode = "train", aug = aug)history = model.fit_generator(image_generator,    steps_per_epoch = NUM_TRAIN_IMAGES,    epochs = NUM_EPOCHS)从这里,我收到此错误:# Create generator from NumPy or EagerTensor Input.--> 377   num_samples = int(nest.flatten(data)[0].shape[0])378   if batch_size is None:379     raise ValueError('You must specify `batch_size`')AttributeError: 'function' object has no attribute 'shape'
查看完整描述

1 回答

?
慕森王

TA贡献1777条经验 获得超3个赞

我在这里看到两个主要错误。

首先,您的生成器函数的内存效率不高。因为您首先加载所有图像(while 循环)。您应该遍历图像文件并在循环内产生带有标签的图像的 np.array。

其次,当您应该使用其返回的对象 - gen 时,您将生成器函数名称传递给 fit_generator。



查看完整回答
反对 回复 2021-12-21
  • 1 回答
  • 0 关注
  • 191 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信