
How can I augment the training set with the tf.keras.utils.Sequence API?


喵喔喔 2023-07-27 16:31:06
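The TensorFlow documentation has the following example, which shows how to create a batch generator that feeds the training set to the model batch by batch when the training set is too large to fit in memory:

from skimage.io import imread
from skimage.transform import resize
import tensorflow as tf
import numpy as np
import math

# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.

class CIFAR10Sequence(tf.keras.utils.Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) *
        self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) *
        self.batch_size]

        return np.array([
            resize(imread(file_name), (200, 200))
               for file_name in batch_x]), np.array(batch_y)

My goal is to further increase the diversity of the training set by rotating each image three times by 90°. In each epoch of training, the model would first be fed the "0° training set", then the 90°, 180° and 270° rotated sets, in that order. How can I modify the code above to do this inside the CIFAR10Sequence() data generator?

Please do not use tf.keras.preprocessing.image.ImageDataGenerator(), so that the answer stays general enough to apply to similar problems of a different nature.

Note: the idea is to create the new data "on the fly" while the model is being fed, rather than creating in advance, and storing on disk, a new augmented training set larger than the original one, which would then be used (also in batches) during training.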

1 Answer

米脂


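Use a custom Callback and hook into on_epoch_end: after each epoch ends, change the rotation angle stored on the data iterator object, so that successive epochs cycle through 0°, 90°, 180° and 270°.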


Example (documented inline):

from skimage.io import imread
from skimage.transform import resize, rotate
import numpy as np

# Use tf.keras throughout; mixing standalone `keras` and `tensorflow.keras` objects can break model.fit
from tensorflow import keras
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Activation, Flatten, Dense


# Model architecture (dummy)
# input_shape=(15, 15, 4): 15x15 images with 4 channels, e.g. RGBA PNG files
model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=(15, 15, 4)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])


# Data iterator
class CIFAR10Sequence(Sequence):

    def __init__(self, filenames, labels, batch_size):
        super().__init__()
        self.filenames, self.labels = filenames, labels
        self.batch_size = batch_size
        self.angles = [0, 90, 180, 270]
        self.current_angle_idx = 0

    # Move on to the next available angle, wrapping back to 0° after 270°
    def change_angle(self):
        self.current_angle_idx = (self.current_angle_idx + 1) % len(self.angles)

    def __len__(self):
        return int(np.ceil(len(self.filenames) / float(self.batch_size)))

    # Read, resize and rotate the images and return one batch
    def __getitem__(self, idx):
        angle = self.angles[self.current_angle_idx]
        print(f"Rotating Angle: {angle}")

        batch_x = self.filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.labels[idx * self.batch_size:(idx + 1) * self.batch_size]

        # skimage's rotate() takes the angle in degrees, counter-clockwise
        return np.array([
            rotate(resize(imread(filename), (15, 15)), angle)
            for filename in batch_x]), np.array(batch_y)


# Custom callback to hook into the end of each epoch
class CustomCallback(keras.callbacks.Callback):

    def __init__(self, sequence):
        super().__init__()
        self.sequence = sequence

    # After each epoch, switch to the next rotation for the following epoch
    def on_epoch_end(self, epoch, logs=None):
        self.sequence.change_angle()


# Create the data reader: 10 copies of f1.PNG, labels alternating 0/1, batch size 8 (2 batches per epoch)
sequence = CIFAR10Sequence(["f1.PNG"] * 10, [0, 1] * 5, 8)

# Fit the model with the custom callback hooked in
model.fit(sequence, epochs=10, callbacks=[CustomCallback(sequence)])

Output:


Rotating Angle: 0

Epoch 1/10

Rotating Angle: 0

Rotating Angle: 0

2/2 [==============================] - 2s 755ms/step - loss: 1.0153 - accuracy: 0.5000

Epoch 2/10

Rotating Angle: 90

Rotating Angle: 90

2/2 [==============================] - 0s 190ms/step - loss: 0.6975 - accuracy: 0.5000

Epoch 3/10

Rotating Angle: 180

Rotating Angle: 180

2/2 [==============================] - 2s 772ms/step - loss: 0.6931 - accuracy: 0.5000

Epoch 4/10

Rotating Angle: 270

Rotating Angle: 270

2/2 [==============================] - 0s 197ms/step - loss: 0.6931 - accuracy: 0.5000

Epoch 5/10

Rotating Angle: 0

Rotating Angle: 0

2/2 [==============================] - 0s 189ms/step - loss: 0.6931 - accuracy: 0.5000

Epoch 6/10

Rotating Angle: 90

Rotating Angle: 90

2/2 [==============================] - 2s 757ms/step - loss: 0.6932 - accuracy: 0.5000

Epoch 7/10

Rotating Angle: 180

Rotating Angle: 180

2/2 [==============================] - 2s 757ms/step - loss: 0.6931 - accuracy: 0.5000

Epoch 8/10

Rotating Angle: 270

Rotating Angle: 270

2/2 [==============================] - 2s 761ms/step - loss: 0.6932 - accuracy: 0.5000

Epoch 9/10

Rotating Angle: 0

Rotating Angle: 0

2/2 [==============================] - 1s 744ms/step - loss: 0.6932 - accuracy: 0.5000

Epoch 10/10

Rotating Angle: 90

Rotating Angle: 90

2/2 [==============================] - 0s 192ms/step - loss: 0.6931 - accuracy: 0.5000

<tensorflow.python.keras.callbacks.History at 0x7fcbdf8bcdd8>
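A possibly simpler variant of the same per-epoch idea, sketched under assumptions: `tf.keras.utils.Sequence` has its own `on_epoch_end()` method, which Keras calls after every epoch, so the rotation counter can be advanced there without a separate Callback. The sketch also swaps skimage's `rotate()` for `np.rot90`, which turns an image by exact multiples of 90° without interpolation (for non-square images it also swaps height and width). The class name `EpochRotatingSequence` and the `load_image` parameter (a callable mapping a path to an H×W×C array, standing in for `imread` + `resize`) are hypothetical, and the `try/except` falls back to `object` only so the indexing logic can be tried without TensorFlow installed.

```python
import math

import numpy as np

try:  # use the real Keras base class when TensorFlow is installed
    from tensorflow.keras.utils import Sequence
except ImportError:  # fallback so the indexing logic can be tried without TensorFlow
    Sequence = object


class EpochRotatingSequence(Sequence):
    """Serves the whole set at one rotation per epoch: 0°, 90°, 180°, 270°, 0°, ..."""

    def __init__(self, x_set, y_set, batch_size, load_image):
        super().__init__()
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.load_image = load_image  # hypothetical hook: path -> HxWxC array
        self.k = 0  # number of 90° turns applied during the current epoch

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        lo, hi = idx * self.batch_size, (idx + 1) * self.batch_size
        # rot90 over the spatial axes (0, 1) turns each HxWxC image by exactly k * 90°
        batch_x = np.array([np.rot90(self.load_image(p), self.k, axes=(0, 1))
                            for p in self.x[lo:hi]])
        return batch_x, np.array(self.y[lo:hi])

    def on_epoch_end(self):
        # Keras calls this on the Sequence after each epoch; advance to the next quarter turn
        self.k = (self.k + 1) % 4
```

With TensorFlow and scikit-image installed it would be used roughly as `model.fit(EpochRotatingSequence(paths, labels, 8, load_image=lambda p: resize(imread(p), (15, 15))), epochs=10)`. Whether the new angle takes effect exactly at the epoch boundary may depend on the Keras version and on prefetching or multiprocessing workers, so a quick print of `self.k` in `__getitem__`, as the answer does, is a cheap check.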


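The answer uses a single angle per epoch. If, as the question asks, every epoch should contain all four versions in turn (the whole set at 0°, then at 90°, 180° and 270°), one option, offered as a sketch rather than a tested recipe, is to make `__len__` four times the number of original batches and split each index with `divmod` into a number of quarter turns and a batch of the original set. No callback or mutable state is needed. `FourRotationSequence` and `load_image` are hypothetical names, and the `object` fallback is there only so the logic runs without TensorFlow.

```python
import math

import numpy as np

try:  # use the real Keras base class when TensorFlow is installed
    from tensorflow.keras.utils import Sequence
except ImportError:  # fallback so the indexing logic can be tried without TensorFlow
    Sequence = object


class FourRotationSequence(Sequence):
    """One epoch = every batch at 0°, then every batch at 90°, 180° and 270°."""

    def __init__(self, x_set, y_set, batch_size, load_image):
        super().__init__()
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.load_image = load_image  # hypothetical hook: path -> HxWxC array
        self.n_batches = math.ceil(len(self.x) / self.batch_size)

    def __len__(self):
        # four passes over the original batches per epoch
        return 4 * self.n_batches

    def __getitem__(self, idx):
        # k = number of 90° turns (0..3), b = batch index within the original set
        k, b = divmod(idx, self.n_batches)
        lo, hi = b * self.batch_size, (b + 1) * self.batch_size
        batch_x = np.array([np.rot90(self.load_image(p), k, axes=(0, 1))
                            for p in self.x[lo:hi]])
        return batch_x, np.array(self.y[lo:hi])
```

One caveat, as far as I know for TF 2.x: `model.fit()` reorders a Sequence's batches when `shuffle=True`, which is the default, so passing `shuffle=False` should be needed to keep the strict 0° → 90° → 180° → 270° order inside each epoch.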

Answered 2023-07-27