W jaki sposób opadanie gradientu minibatch aktualizuje wagi dla każdego przykładu w partii?


12

Jeśli przetwarzamy powiedzmy 10 przykładów w partii, rozumiem, że możemy zsumować stratę dla każdego przykładu, ale jak działa propagacja wsteczna w odniesieniu do aktualizacji wag dla każdego przykładu?

Na przykład:

  • Przykład 1 -> strata = 2
  • Przykład 2 -> strata = -2

Powoduje to średnią stratę 0 (E = 0), więc w jaki sposób zaktualizowałaby każdą wagę i zbiegała się? Czy to po prostu przez losowość partii, które „miejmy nadzieję” zbiegną się wcześniej czy później? Czy to również nie oblicza gradientu dla pierwszego zestawu wag dla ostatniego przetworzonego przykładu?

Odpowiedzi:


15

Spadek gradientu nie działa tak, jak sugerowałeś, ale może wystąpić podobny problem.

Nie obliczamy średniej straty z partii, obliczamy średnie gradienty funkcji straty. Gradienty są pochodną straty w stosunku do masy, aw sieci neuronowej gradient dla jednej masy zależy od danych wejściowych z tego konkretnego przykładu i zależy również od wielu innych wag w modelu.

Jeśli twój model ma 5 ciężarków, a rozmiar mini-partii wynosi 2, możesz otrzymać:

Przykład 1. Strata = 2,gradients=(1.5,2.0,1.1,0.4,0.9)

Przykład 2. Strata = 3,gradients=(1.2,2.3,1.1,0.8,0.7)

Obliczane są średnie gradientów w tej mini-partii, wynoszą one(1.35,0.15,0,0.2,0.8)

Zaletą uśrednienia dla kilku przykładów jest to, że zmienność gradientu jest mniejsza, więc uczenie się jest bardziej spójne i mniej zależne od specyfiki jednego przykładu. Zauważ, że średni gradient dla trzeciej wagi wynosi , ta waga nie zmieni tej aktualizacji wagi, ale prawdopodobnie będzie różna od zera dla kolejnych wybranych przykładów, które zostaną obliczone z różnymi wagami.0

edytuj w odpowiedzi na komentarze:

W moim przykładzie powyżej obliczana jest średnia gradientów. Dla wielkości mini partii której obliczamy stratę dla każdego przykładu, dążymy do uzyskania średniego gradientu straty w stosunku do wagi .L i w jkLiwj

Sposób, w jaki napisałem to w moim przykładzie, uśredniłem każdy gradient, np .:Lwj=1ki=1kLiwj

Kod samouczka, do którego prowadzisz link w komentarzach, wykorzystuje Tensorflow w celu zminimalizowania średniej straty.

Tensorflow ma na celu zminimalizowanie1ki=1kLi

Aby to zminimalizować, oblicza gradienty średniej straty w odniesieniu do każdej masy i wykorzystuje gradient opadający do aktualizacji wag:

Lwj=wj1ki=1kLi

Zróżnicowanie można wprowadzić do sumy, więc jest takie samo jak wyrażenie z podejścia w moim przykładzie.

wj1ki=1kLi=1ki=1kLiwj


Gotcha Nadal chcesz uśrednić stratę w stosunku do rozmiaru partii_właściwej? Nie jestem pewien, czy znasz tensorflow, ale starałem się pogodzić moje zrozumienie z tym samouczkiem: tensorflow.org/get_started/mnist/beginners Widać, że strata jest uśredniana dla partii (kod redukcyjny). Przypuszczam, że tensorflow zachowuje wewnętrzną liczbę / średnie masy?
obliczony na podstawie węgla

1
@carboncomputed O tak, masz rację, uśredniają stratę, więc gdy Tensorflow oblicza gradienty średniej straty, skutecznie oblicza średnią gradientów dla każdej straty. Zmodyfikuję swoją odpowiedź, aby pokazać matematykę.
Hugh

Ciekawy. Dziękuję za wyjaśnienie. Tak więc, aby kopać nieco głębiej, czy gradienty masy są obliczane na przykład podczas przechodzenia do przodu i przechowywane, czy są obliczane podczas procesu optymalizacji w tensorflow? Przypuszczam, że po prostu brakuje mi „gdzie” są te gradienty w przepływie tensorowym? Widzę przejście do przodu i stratę, więc tensorflow wykonuje dla mnie te obliczenia gradientu / uśrednianie pod maską?
obliczony na podstawie węgla

1
@carboncomputed Taki urok Tensorflow wykorzystuje matematykę symboliczną i potrafi rozróżniać pod maską
Hugh

Dziękuję za zgrabną odpowiedź. Jednak nie udało mi się zrozumieć, jak TF wie, jak wykonać kopię propagować ze średnią stratę, jak pokazano w tym przykładzie , code line 170?
grzesznik

-1

Powodem używania mini-partii jest dobry przykład szkolenia, tak aby ewentualny hałas został zmniejszony przez uśrednienie ich efektów, ale nie jest to również pełna partia, która dla wielu zestawów danych może wymagać dużej ilości pamięci. Jednym ważnym faktem jest to, że błąd, który oceniasz, jest zawsze odległościąmiędzy przewidywaną wydajnością a rzeczywistą wydajnością: oznacza to, że nie może być ona ujemna, więc nie można, jak powiedziano, błędu 2 i -2, który się anuluje, ale zamiast tego stałby się błędem 4 Następnie oceniasz gradient błędu w odniesieniu do wszystkich wag, abyś mógł obliczyć, która zmiana wag zmniejszyłaby go najbardziej. Kiedy to zrobisz, zrobisz „krok” w tym kierunku, w zależności od wielkości twojego współczynnika uczenia się alfa. (To są podstawowe pojęcia, nie będę szczegółowo omawiał propagacji wstecznej dla głębokiego NN) Po uruchomieniu tego szkolenia w zbiorze danych dla pewnej liczby epok, możesz oczekiwać, że twoja sieć zbiegnie się, jeśli twój krok nauki nie jest zbyt duży spraw, by się rozchodził. Nadal możesz skończyć w lokalnym minimummożna tego uniknąć, inicjując różne wagi w inny sposób, używając optymalizatorów różnic i próbując uregulować.


Wystarczy dodać: używamy mini-partii głównie w celu zwiększenia wydajności obliczeniowej. Mamy kompromis między dokładnością zejścia a częstotliwością aktualizacji wag. Dane muszą być bardzo duże, aby nie mieściły się w pamięci.
Łukasz Grad

Rozumiem każdy, ale jak zaktualizować nasze wagi dla konkretnej partii? Czy gradienty masy są również sumowane dla każdego przykładu?
obliczony na podstawie węgla

Nie, istnieje tylko jeden gradient, który jest wektorem pochodnych, na całkowitym błędzie wsadu. Oznacza to, że raz aktualizujemy nasze wagi w oparciu o gradient, tj. Kierunek aktualizacji, który powoduje, że błąd w tej mini partii zmniejsza się najbardziej. Gradient składa się z częściowych pochodnych, to jest pochodnych małego błędu wsadowego w odniesieniu do każdej masy: to mówi nam, czy każda waga powinna być mniejsza lub większa i ile. Wszystkie wagi otrzymują jedną aktualizację dla partii, aby zmniejszyć błąd w tej mini partii, która jest niezależna od innych mini partii.
dante
Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.