Najpierw wyprowadzam błąd dla poniższej warstwy splotowej dla uproszczenia dla jednowymiarowej tablicy (danych wejściowych), którą można łatwo przenieść na wielowymiarowy:
Zakładamy tutaj, że yl - 1 długości N. są wejściami l - 1-ty konw. warstwa,m jest wagą wielkości jądra w oznaczając każdą wagę przez wja i wynik jest xl.
Dlatego możemy napisać (zanotuj sumę od zera):
xlja=∑a = 0m - 1wzayl - 1a + i
gdzie
ylja= f(xlja) i
fafunkcja aktywacji (np. sigmoidalna). Mając to pod ręką, możemy teraz rozważyć pewną funkcję błędu
mi oraz funkcja błędu w warstwie splotowej (tej z poprzedniej warstwy) podana przez
∂mi/ ∂ylja. Chcemy teraz dowiedzieć się, jaka jest zależność błędu w jednym z wag poprzednich warstw:
∂mi∂wza=∑a = 0N.- m∂mi∂xlja∂xlja∂wza=∑a = 0N.- m∂mi∂wzayl - 1i + a
gdzie mamy sumę nad wszystkimi wyrażeniami, w których
wza występuje, które są
N.- m. Należy również pamiętać, że wiemy, że ostatni termin wynika z faktu, że
∂xlja∂wza=yl - 1i + aco widać z pierwszego równania.
Aby obliczyć gradient, musimy znać pierwszy termin, który można obliczyć:
∂mi∂xlja=∂mi∂ylja∂ylja∂xlja=∂mi∂ylja∂∂xljafa(xlja)
gdzie znowu pierwszym terminem jest błąd w poprzedniej warstwie i
fa nieliniowa funkcja aktywacji.
Mając wszystkie niezbędne byty, jesteśmy w stanie obliczyć błąd i skutecznie propagować go z powrotem do cennej warstwy:
δl - 1za=∂mi∂yl - 1ja=∑a = 0m - 1∂mi∂xli - a∂xli - a∂yl - 1ja=∑a = 0m - 1∂mi∂xli - awfal i p p e dza
Pamiętaj, że ostatni krok można łatwo zrozumieć, zapisując
xli-s wrt
yl−1i-s. The
flipped odnosi się do transponowanej masy maxtrix (
T).
Dlatego możesz po prostu obliczyć błąd w następnej warstwie przez (teraz w notacji wektorowej):
δl=(wl)Tδl+1f′(xl)
która staje się warstwą splotową i podpróbkowania:
δl=upsample((wl)Tδl+1)f′(xl)
gdzie
upsample Operacja propaguje błąd przez maksymalną warstwę puli.
Dodaj mnie lub popraw!
Odniesienia patrz:
http://ufldl.stanford.edu/tutorial/supervised/ConvolutionalNeuralNetwork/
http://andrew.gibiansky.com/blog/machine-learning/convolutional-neural-networks/
i dla implementacji C ++ (bez konieczności instalacji):
https://github.com/nyanp/tiny-cnn#supported-networks