2 Answers
With Keras EarlyStopping, training stops when the monitored quantity has stopped improving. From your question it is not clear exactly what your stopping condition should be. If you want to monitor a value the way EarlyStopping does, but stop after a single batch in which that value did not improve, you can adapt the EarlyStopping class and implement its logic in on_batch_end instead of on_epoch_end:
import warnings

import numpy as np
from keras.callbacks import Callback


class EarlyBatchStopping(Callback):
    def __init__(self,
                 monitor='val_loss',
                 min_delta=0,
                 patience=0,
                 verbose=0,
                 mode='auto',
                 baseline=None,
                 restore_best_weights=False):
        super(EarlyBatchStopping, self).__init__()

        self.monitor = monitor
        self.baseline = baseline
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_batch = 0
        self.restore_best_weights = restore_best_weights
        self.best_weights = None

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('EarlyBatchStopping mode %s is unknown, '
                          'fallback to auto mode.' % mode,
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        # Keep min_delta positive for "greater is better" metrics,
        # flip its sign otherwise, so the comparison below works in both modes.
        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1

    def on_train_begin(self, logs=None):
        # Allow instances to be re-used
        self.wait = 0
        self.stopped_batch = 0
        if self.baseline is not None:
            self.best = self.baseline
        else:
            self.best = np.inf if self.monitor_op == np.less else -np.inf

    def on_batch_end(self, batch, logs=None):
        # Note: at batch level only training metrics (e.g. 'loss', 'acc')
        # are present in `logs`; validation metrics are computed per epoch.
        current = self.get_monitor_value(logs)
        if current is None:
            return

        if self.monitor_op(current - self.min_delta, self.best):
            self.best = current
            self.wait = 0
            if self.restore_best_weights:
                self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_batch = batch
                self.model.stop_training = True
                if self.restore_best_weights:
                    if self.verbose > 0:
                        print('Restoring model weights from the end of '
                              'the best batch')
                    self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_batch > 0 and self.verbose > 0:
            print('Batch %05d: early stopping' % (self.stopped_batch + 1))

    def get_monitor_value(self, logs):
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            warnings.warn(
                'Early stopping conditioned on metric `%s` '
                'which is not available. Available metrics are: %s' %
                (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
            )
        return monitor_value
If you have some other stopping condition, you can still use on_batch_end and set self.model.stop_training = True based on your own logic, but I think you get the idea.
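For completeness, here is a minimal usage sketch for the callback above. The model name and the training arrays (model, X_train, y_train) are assumptions, not part of the original answer, and the monitored metric must be one that appears in the batch-level logs, such as 'loss':

# Hypothetical usage sketch: `model` is assumed to be a compiled Keras model
# and `X_train`/`y_train` existing training arrays.
early_batch_stop = EarlyBatchStopping(monitor='loss',      # batch-level metric
                                      min_delta=0.001,
                                      patience=5,
                                      verbose=1,
                                      restore_best_weights=True)

model.fit(X_train, y_train,
          epochs=10,
          batch_size=32,
          callbacks=[early_batch_stop])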
You can use model.stop_training to stop training.
For example, if we want to stop training at batch 3 of epoch 2 (indices are zero-based inside the callback), you can do something like the following.
import keras
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
import numpy as np
import pandas as pd


class My_Callback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs={}):
        # Remember the current epoch index so on_batch_end can use it
        self.epoch = epoch

    def on_batch_end(self, batch, logs={}):
        # Epoch and batch indices are zero-based
        if self.epoch == 1 and batch == 3:
            print(f"\nStopping at Epoch {self.epoch}, Batch {batch}")
            self.model.stop_training = True


X_train = np.random.random((100, 3))
y_train = pd.get_dummies(np.argmax(X_train[:, :3], axis=1)).values

clf = Sequential()
clf.add(Dense(9, activation='relu', input_dim=3))
clf.add(Dense(3, activation='softmax'))
clf.compile(loss='categorical_crossentropy', optimizer=SGD())

clf.fit(X_train, y_train, epochs=10, batch_size=16, callbacks=[My_Callback()])
Output:
Epoch 1/10
100/100 [==============================] - 0s 337us/step - loss: 1.0860
Epoch 2/10
16/100 [===>..........................] - ETA: 0s - loss: 1.0830
Stopping at Epoch 1, Batch 3
<keras.callbacks.callbacks.History at 0x7ff2e3eeee10>