Ekstremalnie małe wartości lub NaN pojawiają się w uczącej sieci neuronowej


329

Próbuję zaimplementować architekturę sieci neuronowej w Haskell i używać jej na MNIST.

Używam hmatrixpakietu do algebry liniowej. Moja struktura szkoleniowa jest zbudowana przy użyciupipes pakietu.

Mój kod kompiluje się i nie ulega awarii. Ale problem polega na tym, że pewne kombinacje rozmiaru warstwy (powiedzmy, 1000), rozmiaru minibatchu i szybkości uczenia się powodują NaNwartości w obliczeniach. Po krótkiej inspekcji widzę, że bardzo małe wartości (rzędu1e-100 ) w końcu pojawiają się w aktywacjach. Ale nawet jeśli tak się nie stanie, trening nadal nie działa. Nie ma poprawy w zakresie utraty lub dokładności.

Sprawdziłem i ponownie sprawdziłem swój kod i nie wiem, co może być przyczyną problemu.

Oto trening wstecznej propagacji, który oblicza delty dla każdej warstwy:

backward lf n (out,tar) das = do
    let δout = tr (derivate lf (tar, out)) -- dE/dy
        deltas = scanr (\(l, a') δ ->
                         let w = weights l
                         in (tr a') * (w <> δ)) δout (zip (tail $ toList n) das)
    return (deltas)

lfjest funkcją strat, njest siecią ( weightmacierzą i biaswektorami dla każdej warstwy) outi tarsą faktycznym wyjściem sieci i target(pożądanym) wyjściem, i dassą pochodnymi aktywacji każdej warstwy.

W trybie wsadowym out, tarsą macierzami (wiersze są wektorami wyjściowymi) i dasjest listą macierzy.

Oto rzeczywiste obliczenie gradientu:

  grad lf (n, (i,t)) = do
    -- Forward propagation: compute layers outputs and activation derivatives
    let (as, as') = unzip $ runLayers n i
        (out) = last as
    (ds) <- backward lf n (out, t) (init as') -- Compute deltas with backpropagation
    let r  = fromIntegral $ rows i -- Size of minibatch
    let gs = zipWith (\δ a -> tr (δ <> a)) ds (i:init as) -- Gradients for weights
    return $ GradBatch ((recip r .*) <$> gs, (recip r .*) <$> squeeze <$> ds)

Tutaj, lfi nsą takie same jak powyżej, ijest wejściem i tjest wyjściem docelowym (oba w postaci wsadowej, jako macierze).

squeezeprzekształca macierz w wektor, sumując w każdym wierszu. Oznacza to, że dsjest to lista macierzy delt, gdzie każda kolumna odpowiada deltom wiersza minibatchu. Tak więc, gradienty odchyleń są średnią delt na całej minibatch. To samo dotyczy gs, co odpowiada gradientom wag.

Oto rzeczywisty kod aktualizacji:

move lr (n, (i,t)) (GradBatch (gs, ds)) = do
    -- Update function
    let update = (\(FC w b af) g δ -> FC (w + (lr).*g) (b + (lr).*δ) af)
        n' = Network.fromList $ zipWith3 update (Network.toList n) gs ds
    return (n', (i,t))

lrto współczynnik uczenia się. FCjest konstruktorem warstwy i affunkcją aktywacji tej warstwy.

Algorytm zstępowania gradientu zapewnia przekazanie ujemnej wartości szybkości uczenia się. Rzeczywisty kod opadania gradientu to po prostu pętla wokół kompozycji gradi moveze sparametryzowanym warunkiem zatrzymania.

Na koniec, oto kod funkcji średniej kwadratowej utraty błędu:

mse :: (Floating a) => LossFunction a a
mse = let f (y,y') = let gamma = y'-y in gamma**2 / 2
          f' (y,y') = (y'-y)
      in  Evaluator f f'

Evaluator po prostu łączy funkcję straty i jej pochodną (do obliczenia delty warstwy wyjściowej).

Reszta kodu znajduje się na GitHub: NeuralNetwork .

Byłbym więc wdzięczny, gdyby ktoś miał wgląd w problem lub choćby po prostu sprawdził poczytalność, czy poprawnie implementuję algorytm.


17
Dzięki, przyjrzę się temu. Ale nie sądzę, żeby to było normalne zachowanie. O ile wiem, inne implementacje tego, co próbuję zrobić (prosta, w pełni połączona sieć neuronowa ze sprzężeniem zwrotnym), czy to w języku Haskell, czy w innych językach, nie wydają się tego robić.
Charles Langlois

17
@Charles: Czy faktycznie wypróbowałeś własne sieci i zestawy danych z wymienionymi innymi implementacjami? Z własnego doświadczenia wynika, że ​​BP łatwo wpadnie w szał, gdy NN nie będzie dobrze dopasowany do problemu. Jeśli masz wątpliwości co do swojej implementacji BP, możesz porównać jego wynik z naiwnym obliczeniem gradientu (oczywiście na NN o rozmiarze zabawki) - co jest o wiele trudniejsze do popełnienia błędu niż BP.
shinobi

5
Czy MNIST nie jest typowym problemem klasyfikacyjnym? Dlaczego używasz MES? Powinieneś używać softmax crossentropy (obliczonego z logitów) nie?
mdaoust

6
@CharlesLanglois, to może nie być twój problem (nie mogę odczytać kodu), ale „średni kwadratowy błąd” nie jest wypukły w przypadku problemu z klasyfikacją, co może wyjaśniać utknięcie. „logity” to po prostu fantazyjny sposób na określenie log-szans: użyj ce = x_j - log(sum_i(exp(x)))obliczeń z tego miejsca, aby nie brać dziennika wykładniczego (który często generuje wartości NaN)
mdaoust

6
Gratulujemy uzyskania najwyżej głosowanego pytania (stan na styczeń '20) bez głosów za lub zaakceptowanych!
hongsy

Odpowiedzi:


2

Czy wiesz o „znikających” i „eksplodujących” gradientach we wstecznej propagacji? Nie jestem zbyt zaznajomiony z Haskellem, więc nie mogę łatwo zobaczyć, co dokładnie robi twoja tylna podpórka, ale wygląda na to, że używasz krzywej logistycznej jako funkcji aktywacji.

Jeśli spojrzysz na wykres tej funkcji, zobaczysz, że gradient tej funkcji jest prawie 0 na końcach (ponieważ wartości wejściowe stają się bardzo duże lub bardzo małe, nachylenie krzywej jest prawie płaskie), więc mnożenie lub dzielenie przez to podczas wstecznej propagacji spowoduje bardzo dużą lub bardzo małą liczbę. Powtarzanie tego podczas przechodzenia przez wiele warstw powoduje, że aktywacje zbliżają się do zera lub nieskończoności. Ponieważ backprop aktualizuje twoje wagi, robiąc to podczas treningu, w twojej sieci jest dużo zer lub nieskończoności.

Rozwiązanie: istnieje wiele metod, których możesz szukać, aby rozwiązać problem znikającego gradientu, ale jedną łatwą rzeczą do wypróbowania jest zmiana typu używanej funkcji aktywacji na nienasycającą. ReLU jest popularnym wyborem, ponieważ łagodzi ten konkretny problem (ale może wprowadzić inne).

Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.