1 回答
![?](http://img1.sycdn.imooc.com/545864490001b5bd02200220-100-100.jpg)
TA贡献1836条经验 获得超3个赞
使用自定义 Callback 并挂钩到 on_epoch_end：在每个 epoch 结束后更改数据迭代器对象的旋转角度。
示例（说明见代码中的注释）：
from skimage.io import imread
from skimage.transform import resize, rotate
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.utils import Sequence
from keras.models import Sequential
from keras.layers import Conv2D, Activation, Flatten, Dense
# Model architecture (dummy): one 3x3 conv layer on 15x15x4 inputs, flattened
# into a single sigmoid unit for binary classification.
model = Sequential([
    Conv2D(32, (3, 3), input_shape=(15, 15, 4)),
    Activation('relu'),
    Flatten(),
    Dense(1),
    Activation('sigmoid'),
])
# Sigmoid output + binary cross-entropy, trained with RMSprop.
model.compile(
    optimizer='rmsprop',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
# Data iterator
class CIFAR10Sequence(Sequence):
    """Keras ``Sequence`` yielding batches of resized, rotated images.

    Each item is an ``(images, labels)`` pair of NumPy arrays. Every image is
    read from disk with ``imread``, resized to 15x15 and rotated by the
    *current* angle. The angle only changes when :meth:`change_angle` is
    called (this example does that from a callback at the end of each epoch),
    so all batches fetched in between share the same rotation.

    Args:
        filenames: Image paths; ``filenames[i]`` is labelled ``labels[i]``.
        labels: Targets, aligned index-for-index with ``filenames``.
        batch_size: Samples per batch; the final batch may be smaller.
        angles: Rotation angles (degrees) to cycle through, in order.
            Defaults to ``[0, 90, 180, 270]``.
        **kwargs: Forwarded unchanged to the base ``Sequence`` initializer.

    Raises:
        ValueError: If ``batch_size`` is not positive or ``angles`` is empty.
    """

    def __init__(self, filenames, labels, batch_size, angles=None, **kwargs):
        # Newer Keras releases expect subclasses to run the base initializer;
        # calling it is harmless where the base class does nothing.
        super().__init__(**kwargs)
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.filenames, self.labels = filenames, labels
        self.batch_size = batch_size
        self.angles = [0, 90, 180, 270] if angles is None else list(angles)
        if not self.angles:
            # An empty list would only fail later, inside __getitem__.
            raise ValueError("angles must contain at least one angle")
        self.current_angle_idx = 0

    def change_angle(self):
        """Advance to the next angle, wrapping back to the first after the last."""
        self.current_angle_idx = (self.current_angle_idx + 1) % len(self.angles)

    def __len__(self):
        """Return the number of batches per epoch, i.e. ceil(samples / batch_size)."""
        # Integer ceiling division: exact, and no float round-trip via np.ceil.
        return -(-len(self.filenames) // self.batch_size)

    def __getitem__(self, idx):
        """Read, resize and rotate the images of batch ``idx``.

        Returns:
            ``(images, labels)`` as NumPy arrays.
        """
        angle = self.angles[self.current_angle_idx]
        print(f"Rotating Angle: {angle}")
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_x = self.filenames[start:stop]
        batch_y = self.labels[start:stop]
        images = [rotate(resize(imread(filename), (15, 15)), angle)
                  for filename in batch_x]
        return np.array(images), np.array(batch_y)
# Custom callback that hooks into the end of every training epoch
class CustomCallback(keras.callbacks.Callback):
    """Advance the rotation angle of a data sequence after every epoch.

    Args:
        sequence: Object exposing a ``change_angle()`` method (here a
            ``CIFAR10Sequence``); it is mutated in place.

    NOTE(review): when ``fit`` prefetches batches with extra workers, some
    batches of the next epoch may be generated before this hook runs and so
    still use the previous angle -- confirm if workers > 1 is used.
    """

    def __init__(self, sequence):
        # Run the base Callback initializer so its own state is set up;
        # overriding __init__ without this is fragile across Keras versions.
        super().__init__()
        self.sequence = sequence

    def on_epoch_end(self, epoch, logs=None):
        """Switch the sequence to its next angle for the following epoch."""
        self.sequence.change_angle()
# Data reader: ten copies of one image with alternating 0/1 labels, served in
# batches of 8 (so two batches per epoch).
sequence = CIFAR10Sequence(["f1.PNG"] * 10, [0, 1] * 5, 8)
# Train, letting the callback rotate the images once per epoch.
rotation_callback = CustomCallback(sequence)
model.fit(sequence, epochs=10, callbacks=[rotation_callback])
输出:
Rotating Angle: 0
Epoch 1/10
Rotating Angle: 0
Rotating Angle: 0
2/2 [==============================] - 2s 755ms/step - loss: 1.0153 - accuracy: 0.5000
Epoch 2/10
Rotating Angle: 90
Rotating Angle: 90
2/2 [==============================] - 0s 190ms/step - loss: 0.6975 - accuracy: 0.5000
Epoch 3/10
Rotating Angle: 180
Rotating Angle: 180
2/2 [==============================] - 2s 772ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 4/10
Rotating Angle: 270
Rotating Angle: 270
2/2 [==============================] - 0s 197ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 5/10
Rotating Angle: 0
Rotating Angle: 0
2/2 [==============================] - 0s 189ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 6/10
Rotating Angle: 90
Rotating Angle: 90
2/2 [==============================] - 2s 757ms/step - loss: 0.6932 - accuracy: 0.5000
Epoch 7/10
Rotating Angle: 180
Rotating Angle: 180
2/2 [==============================] - 2s 757ms/step - loss: 0.6931 - accuracy: 0.5000
Epoch 8/10
Rotating Angle: 270
Rotating Angle: 270
2/2 [==============================] - 2s 761ms/step - loss: 0.6932 - accuracy: 0.5000
Epoch 9/10
Rotating Angle: 0
Rotating Angle: 0
2/2 [==============================] - 1s 744ms/step - loss: 0.6932 - accuracy: 0.5000
Epoch 10/10
Rotating Angle: 90
Rotating Angle: 90
2/2 [==============================] - 0s 192ms/step - loss: 0.6931 - accuracy: 0.5000
<tensorflow.python.keras.callbacks.History at 0x7fcbdf8bcdd8>
添加回答
举报