为了账号安全,请及时绑定邮箱和手机立即绑定

Keras有没有办法立即停止训练?

Keras有没有办法立即停止训练?

MYYA 2022-12-20 16:25:18
我正在为我的tf.keras训练编写自定义提前停止回调。为此,我可以self.model.stop_training = True在其中一个回调函数中设置变量,例如on_epoch_end(). 然而,Keras 仅在当前时期完成时才停止训练,即使我在一个时期的训练中设置了这个变量,例如在on_batch_end().因此我的问题是:Keras 有没有办法立即停止训练,即使是在当前时代的进展中?
查看完整描述

2 回答

?
人到中年有点甜

TA贡献1895条经验 获得超7个赞

在 kerasEarlyStopping中,当受监控的数量停止改善时,您会停止。从您的问题来看,您不清楚要停止的条件是什么。如果您只想监视一个值,EarlyStopping但只想在一批后停止,如果该值没有提高,您可以重写EarlyStopping类并实现逻辑 inon_batch_end而不是on_epoch_end


class EarlyBatchStopping(Callback):



    def __init__(self,

                 monitor='val_loss',

                 min_delta=0,

                 patience=0,

                 verbose=0,

                 mode='auto',

                 baseline=None,

                 restore_best_weights=False):

        super(EarlyStopping, self).__init__()


        self.monitor = monitor

        self.baseline = baseline

        self.patience = patience

        self.verbose = verbose

        self.min_delta = min_delta

        self.wait = 0

        self.stopped_epoch = 0

        self.restore_best_weights = restore_best_weights

        self.best_weights = None


        if mode not in ['auto', 'min', 'max']:

            warnings.warn('EarlyStopping mode %s is unknown, '

                          'fallback to auto mode.' % mode,

                          RuntimeWarning)

            mode = 'auto'


        if mode == 'min':

            self.monitor_op = np.less

        elif mode == 'max':

            self.monitor_op = np.greater

        else:

            if 'acc' in self.monitor:

                self.monitor_op = np.greater

            else:

                self.monitor_op = np.less


        if self.monitor_op == np.greater:

            self.min_delta *= 1

        else:

            self.min_delta *= -1


    def on_train_begin(self, logs=None):

        # Allow instances to be re-used

        self.wait = 0

        self.stopped_epoch = 0

        if self.baseline is not None:

            self.best = self.baseline

        else:

            self.best = np.Inf if self.monitor_op == np.less else -np.Inf


    def on_batch_end(self, epoch, logs=None):

        current = self.get_monitor_value(logs)

        if current is None:

            return


        if self.monitor_op(current - self.min_delta, self.best):

            self.best = current

            self.wait = 0

            if self.restore_best_weights:

                self.best_weights = self.model.get_weights()

        else:

            self.wait += 1

            if self.wait >= self.patience:

                self.stopped_epoch = epoch

                self.model.stop_training = True

                if self.restore_best_weights:

                    if self.verbose > 0:

                        print('Restoring model weights from the end of '

                              'the best epoch')

                    self.model.set_weights(self.best_weights)


    def on_train_end(self, logs=None):

        if self.stopped_epoch > 0 and self.verbose > 0:

            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))


    def get_monitor_value(self, logs):

        monitor_value = logs.get(self.monitor)

        if monitor_value is None:

            warnings.warn(

                'Early stopping conditioned on metric `%s` '

                'which is not available. Available metrics are: %s' %

                (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning

            )

        return monitor_value

如果您有其他逻辑,则可以根据您的逻辑使用on_batch_end和设置,但我认为您明白了。self.model.stop_training = True


查看完整回答
反对 回复 2022-12-20
?
红颜莎娜

TA贡献1842条经验 获得超12个赞

您可以使用model.stop_training参数来停止训练。


例如,如果我们想在第 2 轮第 3 批次停止训练,那么您可以执行如下操作。


import keras

from keras.models import Sequential

from keras.layers import Dense

from keras.optimizers import SGD

import numpy as np

import pandas as pd


class My_Callback(keras.callbacks.Callback):

    def on_epoch_begin(self, epoch, logs={}):

      self.epoch = epoch


    def on_batch_end(self, batch, logs={}):

        if self.epoch == 1 and batch == 3:

          print (f"\nStopping at Epoch {self.epoch}, Batch {batch}")

          self.model.stop_training = True



X_train = np.random.random((100, 3))

y_train = pd.get_dummies(np.argmax(X_train[:, :3], axis=1)).values


clf = Sequential()

clf.add(Dense(9, activation='relu', input_dim=3))

clf.add(Dense(3, activation='softmax'))

clf.compile(loss='categorical_crossentropy', optimizer=SGD())


clf.fit(X_train, y_train, epochs=10, batch_size=16, callbacks=[My_Callback()])

输出:


Epoch 1/10

100/100 [==============================] - 0s 337us/step - loss: 1.0860

Epoch 2/10

 16/100 [===>..........................] - ETA: 0s - loss: 1.0830

Stopping at Epoch 1, Batch 3

<keras.callbacks.callbacks.History at 0x7ff2e3eeee10>


查看完整回答
反对 回复 2022-12-20
  • 2 回答
  • 0 关注
  • 162 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信