Keras的fit_generator()模型方法期望生成器生成形状(输入,目标)的元组,其中两个元素都是NumPy数组。该文档似乎暗示着,如果我将Dataset迭代器简单地包装在生成器中,并确保将Tensors转换为NumPy数组,那我应该很好。这段代码给我一个错误:import numpy as npimport osimport keras.backend as Kfrom keras.layers import Dense, Inputfrom keras.models import Modelimport tensorflow as tffrom tensorflow.contrib.data import Datasetos.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'with tf.Session() as sess: def create_data_generator(): dat1 = np.arange(4).reshape(-1, 1) ds1 = Dataset.from_tensor_slices(dat1).repeat() dat2 = np.arange(5, 9).reshape(-1, 1) ds2 = Dataset.from_tensor_slices(dat2).repeat() ds = Dataset.zip((ds1, ds2)).batch(4) iterator = ds.make_one_shot_iterator() while True: next_val = iterator.get_next() yield sess.run(next_val)datagen = create_data_generator()input_vals = Input(shape=(1,))output = Dense(1, activation='relu')(input_vals)model = Model(inputs=input_vals, outputs=output)model.compile('rmsprop', 'mean_squared_error')model.fit_generator(datagen, steps_per_epoch=1, epochs=5, verbose=2, max_queue_size=2)这是我得到的错误:Using TensorFlow backend.Epoch 1/5Exception in thread Thread-1:Traceback (most recent call last): File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 270, in __init__ fetch, allow_tensor=True, allow_operation=True)) File "/home/jsaporta/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2708, in as_graph_element return self._as_graph_element_locked(obj, allow_tensor, allow_operation)奇怪的是,next(datagen)在我初始化的位置之后直接添加包含一行datagen的代码会使代码运行正常,没有错误。为什么我的原始代码不起作用?将行添加到代码中后,为什么它开始起作用?是否有一种更有效的方式将TensorFlow的Dataset API与Keras结合使用,而无需将Tensors转换为NumPy数组然后再次返回?
添加回答
举报
0/150
提交
取消