为了账号安全,请及时绑定邮箱和手机立即绑定

使用数据集生成器的 Tensorflow model.fit()

使用数据集生成器的 Tensorflow model.fit()

慕勒3428872 2021-12-17 16:53:16
我正在使用数据集 API 生成训练数据并将其分类为 NN 的批次。这是我的代码的最小工作示例:import tensorflow as tfimport numpy as npimport randomdef my_generator():    while True:        x = np.random.rand(4, 20)        y = random.randint(0, 11)        label = tf.one_hot(y, depth=12)        yield x.reshape(4, 20, 1), labeldef my_input_fn():    dataset = tf.data.Dataset.from_generator(lambda: my_generator(),                                             output_types=(tf.float64, tf.int32))    dataset = dataset.batch(32)    iterator = dataset.make_one_shot_iterator()    batch_features, batch_labels = iterator.get_next()    return batch_features, batch_labelsif __name__ == "__main__":    tf.enable_eager_execution()    model = tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(4, 20, 1)),                                 tf.keras.layers.Dense(128, activation=tf.nn.relu),                                 tf.keras.layers.Dense(12, activation=tf.nn.softmax)])    model.compile(optimizer='adam',                  loss='categorical_crossentropy',                  metrics=['accuracy'])    data_generator = my_input_fn()    model.fit(data_generator)但batch_size不是 的公认关键字fit_generator()。我对这些错误消息感到困惑,如果有人能对它们有所了解,或者指出我做错了什么,我将不胜感激。
查看完整描述

1 回答

?
FFIVE

TA贡献1797条经验 获得超6个赞

虽然错误的根源仍然模糊不清,但我已经找到了使代码工作的解决方案。我会把它贴在这里,以防它对处于类似情况的任何人有用。


基本上,我将其更改my_input_fn()为生成器并使用model.fit_generator()如下:


import tensorflow as tf

import numpy as np

import random



def my_generator(total_items):

    i = 0

    while i < total_items:

        x = np.random.rand(4, 20)

        y = random.randint(0, 11)

        label = tf.one_hot(y, depth=12)

        yield x.reshape(4, 20, 1), label

        i += 1


def my_input_fn(total_items, epochs):

    dataset = tf.data.Dataset.from_generator(lambda: my_generator(total_items),

                                             output_types=(tf.float64, tf.int64))


    dataset = dataset.repeat(epochs)

    dataset = dataset.batch(32)



    iterator = dataset.make_one_shot_iterator()

    while True:

        batch_features, batch_labels = iterator.get_next()

        yield batch_features, batch_labels


if __name__ == "__main__":

    tf.enable_eager_execution()


    model = tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(4, 20, 1)),

                                 tf.keras.layers.Dense(64, activation=tf.nn.relu),

                                 tf.keras.layers.Dense(12, activation=tf.nn.softmax)])


    model.compile(optimizer='adam',

                  loss='categorical_crossentropy',

                  metrics=['accuracy'])


    total_items = 200

    batch_size = 32

    epochs = 10

    num_batches = int(total_items/batch_size)

    train_data_generator = my_input_fn(total_items, epochs)

    model.fit_generator(generator=train_data_generator, steps_per_epoch=num_batches, epochs=epochs, verbose=1)

编辑


正如 giser_yugang 在评论中暗示的那样,也可以将其my_input_fn()作为返回dataset而不是单个批次的函数来执行。


def my_input_fn(total_items, epochs):

    dataset = tf.data.Dataset.from_generator(lambda: my_generator(total_items),

                                             output_types=(tf.float64, tf.int64))


    dataset = dataset.repeat(epochs)

    dataset = dataset.batch(32)

    return dataset


if __name__ == "__main__":

    tf.enable_eager_execution()


    model = tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(4, 20, 1)),

                                 tf.keras.layers.Dense(64, activation=tf.nn.relu),

                                 tf.keras.layers.Dense(12, activation=tf.nn.softmax)])


    model.compile(optimizer='adam',

                  loss='categorical_crossentropy',

                  metrics=['accuracy'])


    total_items = 100

    batch_size = 32

    epochs = 10

    num_batches = int(total_items/batch_size)

    dataset = my_input_fn(total_items, epochs)

    model.fit_generator(dataset, epochs=epochs, steps_per_epoch=num_batches)

这些方法之间似乎没有任何平均性能差异。


查看完整回答
反对 回复 2021-12-17
  • 1 回答
  • 0 关注
  • 251 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信