Jak powiedzieć Kerasowi, że przestanie trenować na podstawie wartości strat?


82

Obecnie używam następującego kodu:

callbacks = [
    EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

Mówi Kerasowi, aby przestał trenować, jeśli straty nie poprawiły się przez 2 epoki. Ale chcę przestać trenować po tym, jak strata stała się mniejsza niż jakieś stałe „THR”:

if val_loss < THR:
    break

Widziałem w dokumentacji, że istnieje możliwość wykonania własnego oddzwonienia: http://keras.io/callbacks/ Ale nic nie znalazło sposobu na zatrzymanie procesu szkolenia. Potrzebuję rady.

Odpowiedzi:


85

Znalazłem odpowiedź. Zajrzałem do źródeł Keras i znalazłem kod do EarlyStopping. Na tej podstawie wykonałem własny callback:

class EarlyStoppingByLossVal(Callback):
    def __init__(self, monitor='val_loss', value=0.00001, verbose=0):
        super(Callback, self).__init__()
        self.monitor = monitor
        self.value = value
        self.verbose = verbose

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)

        if current < self.value:
            if self.verbose > 0:
                print("Epoch %05d: early stopping THR" % epoch)
            self.model.stop_training = True

I zastosowanie:

callbacks = [
    EarlyStoppingByLossVal(monitor='val_loss', value=0.00001, verbose=1),
    # EarlyStopping(monitor='val_loss', patience=2, verbose=0),
    ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
      shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
      callbacks=callbacks)

1
Tylko jeśli komuś się przyda - w moim przypadku użyłem monitor = 'loss', zadziałało dobrze.
QtRoS

15
Wygląda na to, że Keras został zaktualizowany. Funkcja wywołania zwrotnego EarlyStopping ma teraz wbudowaną min_delta. Nie ma już potrzeby hakowania kodu źródłowego, yay! stackoverflow.com/a/41459368/3345375
jkdev

3
Po ponownym przeczytaniu pytania i odpowiedzi muszę się poprawić: min_delta oznacza „Zatrzymaj się wcześnie, jeśli nie ma wystarczającej poprawy na epokę (lub na wiele epok)”. Jednak OP zapytał, jak „zatrzymać się wcześnie, gdy strata spadnie poniżej pewnego poziomu”.
jkdev

NameError: name 'Callback' nie jest zdefiniowana ... Jak to naprawić?
alyssaeliyah

2
Eliyah spróbuj tego: from keras.callbacks import Callback
ZFTurbo

26

Callback keras.callbacks.EarlyStopping ma argument min_delta. Z dokumentacji Keras:

min_delta: minimalna zmiana ilości monitorowanej kwalifikująca się jako poprawa, tj. zmiana bezwzględna mniejsza niż min_delta, będzie się liczyć jako brak poprawy.


3
Dla porównania, oto dokumentacja dla wcześniejszej wersji Keras (1.1.0), w której argument min_delta nie został jeszcze uwzględniony: faroit.github.io/keras-docs/1.1.0/callbacks/#earlystopping
jkdev

jak mogłem sprawić, by to nie ustało, dopóki min_deltanie będzie trwało przez wiele epok?
zyxue

jest jeszcze jeden parametr EarlyStopping zwany cierpliwością: liczba okresów bez poprawy, po których trening zostanie zatrzymany.
devin

13

Jednym z rozwiązań jest wywołanie model.fit(nb_epoch=1, ...)wewnątrz pętli for, a następnie umieszczenie instrukcji break wewnątrz pętli for i wykonanie dowolnego innego niestandardowego przepływu sterowania.


Byłoby miło, gdyby wykonali funkcję zwrotną, która przyjmuje jedną funkcję, która może to zrobić.
Uczciwość

8

Rozwiązałem ten sam problem, używając niestandardowego wywołania zwrotnego.

W poniższym niestandardowym kodzie wywołania zwrotnego przypisz THR wartość, przy której chcesz zatrzymać uczenie i dodać wywołanie zwrotne do swojego modelu.

from keras.callbacks import Callback

class stopAtLossValue(Callback):

        def on_batch_end(self, batch, logs={}):
            THR = 0.03 #Assign THR with the value at which you want to stop training.
            if logs.get('loss') <= THR:
                 self.model.stop_training = True

2

Chociaż byłem biorąc TensorFlow specjalizacji w praktyce , nauczyłem się bardzo elegancką technikę. Niewiele zmodyfikowano w stosunku do zaakceptowanej odpowiedzi.

Dajmy przykład naszym ulubionym danym MNIST.

import tensorflow as tf

class new_callback(tf.keras.callbacks.Callback):
    def epoch_end(self, epoch, logs={}): 
        if(logs.get('accuracy')> 0.90): # select the accuracy
            print("\n !!! 90% accuracy, no further training !!!")
            self.model.stop_training = True

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0 #normalize

callbacks = new_callback()

# model = tf.keras.models.Sequential([# define your model here])

model.compile(optimizer=tf.optimizers.Adam(),
          loss='sparse_categorical_crossentropy',
          metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, callbacks=[callbacks])

Więc tutaj ustawiam metrics=['accuracy'], a więc w klasie wywołania zwrotnego warunek jest ustawiony na 'accuracy'> 0.90.

Możesz wybrać dowolną metrykę i monitorować szkolenie, tak jak w tym przykładzie. Co najważniejsze, możesz ustawić różne warunki dla różnych danych i używać ich jednocześnie.

Mam nadzieję, że to pomoże!


nazwa funkcji powinna znajdować się on_epoch_end
xarion

0

Dla mnie model przestałby trenować tylko wtedy, gdybym dodał instrukcję return po ustawieniu parametru stop_training na True, ponieważ dzwoniłem po self.model.evaluate. Więc upewnij się, że na końcu funkcji umieścisz stop_training = True lub dodaj instrukcję return.

def on_epoch_end(self, batch, logs):
        self.epoch += 1
        self.stoppingCounter += 1
        print('\nstopping counter \n',self.stoppingCounter)

        #Stop training if there hasn't been any improvement in 'Patience' epochs
        if self.stoppingCounter >= self.patience:
            self.model.stop_training = True
            return

        # Test on additional set if there is one
        if self.testingOnAdditionalSet:
            evaluation = self.model.evaluate(self.val2X, self.val2Y, verbose=0)
            self.validationLoss2.append(evaluation[0])
            self.validationAcc2.append(evaluation[1])enter code here

0

Jeśli używasz niestandardowej pętli treningowej, możesz użyć collections.dequelisty „toczącej się”, którą można dołączyć, a pozycje po lewej stronie zostaną wyskakujące, gdy lista jest dłuższa niż maxlen. Oto linia:

loss_history = deque(maxlen=early_stopping + 1)

for epoch in range(epochs):
    fit(epoch)
    loss_history.append(test_loss.result().numpy())
    if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history)
            break

Oto pełny przykład:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.layers import Dense
from collections import deque

data, info = tfds.load('iris', split='train', as_supervised=True, with_info=True)

data = data.map(lambda x, y: (tf.cast(x, tf.int32), y))

train_dataset = data.take(120).batch(4)
test_dataset = data.skip(120).take(30).batch(4)

model = tf.keras.models.Sequential([
    Dense(8, activation='relu'),
    Dense(16, activation='relu'),
    Dense(info.features['label'].num_classes)])

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)
        loss = loss_object(labels, logits)

    gradients = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_acc(labels, logits)


@tf.function
def test_step(inputs, labels):
    logits = model(inputs, training=False)
    loss = loss_object(labels, logits)
    test_loss(loss)
    test_acc(labels, logits)


def fit(epoch):
    template = 'Epoch {:>2} Train Loss {:.3f} Test Loss {:.3f} ' \
               'Train Acc {:.2f} Test Acc {:.2f}'

    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()

    for X_train, y_train in train_dataset:
        train_step(X_train, y_train)

    for X_test, y_test in test_dataset:
        test_step(X_test, y_test)

    print(template.format(
        epoch + 1,
        train_loss.result(),
        test_loss.result(),
        train_acc.result(),
        test_acc.result()
    ))


def main(epochs=50, early_stopping=10):
    loss_history = deque(maxlen=early_stopping + 1)

    for epoch in range(epochs):
        fit(epoch)
        loss_history.append(test_loss.result().numpy())
        if len(loss_history) > early_stopping and loss_history.popleft() < min(loss_history):
            print(f'\nEarly stopping. No validation loss '
                  f'improvement in {early_stopping} epochs.')
            break

if __name__ == '__main__':
    main(epochs=250, early_stopping=10)
Epoch  1 Train Loss 1.730 Test Loss 1.449 Train Acc 0.33 Test Acc 0.33
Epoch  2 Train Loss 1.405 Test Loss 1.220 Train Acc 0.33 Test Acc 0.33
Epoch  3 Train Loss 1.173 Test Loss 1.054 Train Acc 0.33 Test Acc 0.33
Epoch  4 Train Loss 1.006 Test Loss 0.935 Train Acc 0.33 Test Acc 0.33
Epoch  5 Train Loss 0.885 Test Loss 0.846 Train Acc 0.33 Test Acc 0.33
...
Epoch 89 Train Loss 0.196 Test Loss 0.240 Train Acc 0.89 Test Acc 0.87
Epoch 90 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 91 Train Loss 0.195 Test Loss 0.239 Train Acc 0.89 Test Acc 0.87
Epoch 92 Train Loss 0.194 Test Loss 0.239 Train Acc 0.90 Test Acc 0.87

Early stopping. No validation loss improvement in 10 epochs.
Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.