Jakich parametrów należy użyć do wczesnego zatrzymania?


100

Trenuję sieć neuronową dla mojego projektu przy użyciu Keras. Keras zapewnia funkcję wczesnego zatrzymywania. Czy mogę wiedzieć, jakie parametry należy obserwować, aby uniknąć nadmiernego dopasowania mojej sieci neuronowej przez zastosowanie wczesnego zatrzymywania?

Odpowiedzi:


160

wczesne zatrzymanie

Wczesne zatrzymanie to po prostu zatrzymanie treningu, gdy strata zacznie rosnąć (innymi słowy, dokładność walidacji zaczyna spadać). Według dokumentów jest używany w następujący sposób;

keras.callbacks.EarlyStopping(monitor='val_loss',
                              min_delta=0,
                              patience=0,
                              verbose=0, mode='auto')

Wartości zależą od implementacji (problem, wielkość partii itp.), Ale ogólnie, aby zapobiec nadmiernemu dopasowaniu, użyłbym;

  1. Monitoruj utratę walidacji (musisz użyć walidacji krzyżowej lub przynajmniej trenować / testować zestawy), ustawiając monitor argument na 'val_loss'.
  2. min_deltajest progiem pozwalającym określić ilościowo stratę w pewnym okresie jako poprawę, czy nie. Jeżeli różnica strat jest niższa min_delta, określa się ją ilościowo jako brak poprawy. Lepiej zostawić wartość 0, ponieważ interesuje nas, kiedy strata się pogarsza.
  3. patienceArgument reprezentuje liczbę epok przed zakończeniem, gdy strata zacznie rosnąć (przestanie się poprawiać). Zależy to od implementacji, jeśli używasz bardzo małych partii lub dużego tempa uczenia się, twoja strata będzie zygzakowata (dokładność będzie bardziej hałaśliwa), więc lepiej ustaw duży patienceargument. Jeśli używasz dużych partii i mały wskaźnik uczenia się, Twoja strata będzie gładsza, więc możesz użyć mniejszego patienceargumentu. Tak czy inaczej zostawię to jako 2, więc dałbym modelowi więcej szans.
  4. verbose decyduje, co wydrukować, pozostaw to domyślne (0).
  5. modeArgument zależy od tego, w jakim kierunku ma monitorowana ilość (czy ma się zmniejszać, czy zwiększać), ponieważ monitorujemy stratę, możemy użyć min. Ale zostawmy to kerasowi, który zajmie się tym za nas i ustawmy to naauto

Więc użyłbym czegoś takiego i eksperymentowałbym, wykreślając utratę błędu z wczesnym zatrzymaniem i bez niego.

keras.callbacks.EarlyStopping(monitor='val_loss',
                              min_delta=0,
                              patience=2,
                              verbose=0, mode='auto')

Aby uzyskać ewentualną niejednoznaczność dotyczącą działania wywołań zwrotnych, spróbuję wyjaśnić więcej. Raz zadzwoniszfit(... callbacks=[es]) modelu Keras wywołuje określone obiekty wywołania zwrotnego z predefiniowanymi funkcjami. Funkcje te można nazwać on_train_begin, on_train_end, on_epoch_begin, on_epoch_endi on_batch_begin,on_batch_end . Wczesne zatrzymanie wywołania zwrotnego jest wywoływane na każdym końcu epoki, porównuje najlepiej monitorowaną wartość z bieżącą i zatrzymuje się, jeśli spełnione są warunki (ile epok minęło od momentu zaobserwowania najlepiej monitorowanej wartości i czy jest to coś więcej niż argument cierpliwości, różnica między ostatnia wartość jest większa niż min_delta itp.).

Jak wskazał @BrentFaust w komentarzach, uczenie modelu będzie kontynuowane, dopóki nie zostaną spełnione warunki wczesnego zatrzymania lub epochsparametr (domyślnie = 10) w fit()zostanie spełniony. Ustawienie wywołania zwrotnego wczesnego zatrzymania nie spowoduje, że model będzie trenował poza jego epochsparametrem. Więc wywołanie fit()funkcji z większymepochs wartości przyniosłoby większe korzyści z wywołania zwrotnego Early Stopping.


3
@AizuddinAzman close, min_deltajest progiem określającym ilościowo zmianę monitorowanej wartości jako poprawę, czy też nie. Więc tak, jeśli podamy, monitor = 'val_loss'to odnosi się to do różnicy między bieżącą utratą walidacji a poprzednią utratą walidacji. W praktyce, jeśli podasz min_delta=0.1spadek utraty walidacji (bieżący - poprzedni) mniejszy niż 0,1, nie będzie to kwantyfikować, a tym samym zatrzyma szkolenie (jeśli masz patience = 0).
umutto

3
Zauważ, że callbacks=[EarlyStopping(patience=2)]nie ma to żadnego efektu, chyba że podano epoki model.fit(..., epochs=max_epochs).
Brent Faust

1
@BrentFaust To również rozumiem, napisałem odpowiedź przy założeniu, że model jest trenowany z co najmniej 10 epokami (domyślnie). Po twoim komentarzu zdałem sobie sprawę, że może zaistnieć przypadek, w którym programista wywołuje fit with epoch=1w pętli for (dla różnych przypadków użycia), w którym to wywołanie zwrotne zakończy się niepowodzeniem. Jeśli w mojej odpowiedzi jest dwuznaczność, spróbuję to ująć w lepszy sposób.
umutto

4
@AdmiralWen Odkąd napisałem odpowiedź, kod nieco się zmienił. Jeśli korzystasz z najnowszej wersji Keras, możesz skorzystać z restore_best_weightsargumentu (jeszcze nie w dokumentacji), który po treningu wczytuje model z najlepszymi wagami. Ale dla twoich celów ModelCheckpointużyłbym callback z save_best_onlyargumentem. Możesz sprawdzić dokumentację, jest prosta w obsłudze, ale po treningu musisz ręcznie załadować najlepsze ciężarki.
umutto

1
@umutto Witaj, dziękuję za sugestię pliku restore_best_weights, jednak nie mogę go użyć, `es = EarlyStopping (monitor = 'val_acc', min_delta = 1e-4, cierpliwość = cierpliwość_, verbose = 1, restore_best_weights = True) TypeError: Funkcja __init __ () otrzymała nieoczekiwany argument słowa kluczowego „restore_best_weights”. Jakieś pomysły? keras 2.2.2, tf, 1.10 jaka jest twoja wersja?
Haramoz,
Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.