为了账号安全,请及时绑定邮箱和手机立即绑定

带有 Tensorflow 数据集 API 的 Keras 生成器 - IndexError:

带有 Tensorflow 数据集 API 的 Keras 生成器 - IndexError:

慕无忌1623718 2021-06-03 23:22:54
我需要开发一个 RNN 模型,并希望使用数据生成器来提供训练/评估循环。首先,我在从 csv 文件中获取数据时使用了这个帮助功能。RECORD_DEFAULTS_TRAIN = [[0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]def decode_csv(line):   parsed_line = tf.decode_csv(line, RECORD_DEFAULTS_TRAIN)   label =  parsed_line[-1]      # label is the last element of the list   del parsed_line[-1]           # delete the last element from the list   del parsed_line[0]            # even delete the first element bcz it is assumed NOT to be a feature   features = tf.stack(parsed_line)  # Stack features so that you can later vectorize forward prop., etc.   return features, label 这是我的数据生成器功能:def data_generator(file_path_list, batch_size):  filenames = tf.placeholder(tf.string, shape=[None])  dataset = tf.data.Dataset.from_tensor_slices(filenames)  dataset = dataset.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1).map(decode_csv))  dataset = dataset.shuffle(buffer_size=1000)  dataset = dataset.batch(batch_size)  iterator = dataset.make_initializable_iterator()  next_element = iterator.get_next()  with tf.Session() as sess:      while True:          sess.run(iterator.initializer, feed_dict={filenames: file_path_list})          while True:                          try:                batch_data, batch_labels = sess.run(next_element)                # Dimension of the data needs to be: (batch_size, length_of_each_sequence, nr_inputs_in_each_timestep)                # Since the last batch in a epoch can have a different size,                # "batch_data.shape[0]" is used instead of batch_size                batch_data = np.reshape(batch_data, (batch_data.shape[0], SEQUENCE_LEN, 1))              except tf.errors.OutOfRangeError:                break              yield (batch_data, batch_labels)
查看完整描述

1 回答

?
冉冉说

TA贡献1877条经验 获得超1个赞

解决了。我想解释这个问题而不是删除我的帖子,以便它也可以帮助其他人。


我只会给出evaluate_generator(...)函数的例子。这就是我调用函数的方式..


lstm_model.evaluate_generator(data_generator(TEST_FILE_PATHS, TEST_BATCH_SIZE), 

                             steps=(NR_TEST_EXAMPLES // TEST_BATCH_SIZE), 

                             verbose=1)

我将其更改如下:


test_data_generator = data_generator(TEST_FILE_PATHS, TEST_BATCH_SIZE)

lstm_model.evaluate_generator(test_data_generator, 

                              steps=(NR_TEST_EXAMPLES // TEST_BATCH_SIZE), 

                              verbose=1)

问题解决了。我在不同的地方看到了这两种用法,即使人们在网上找到的每一种信息都不一定是真的。我也不清楚为什么在更改上面的代码时可以解决它。如果有人知道,我会很高兴听到解释。


查看完整回答
反对 回复 2021-06-16
  • 1 回答
  • 0 关注
  • 186 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信