Próbuję zrozumieć rolę tej Flatten
funkcji w Keras. Poniżej znajduje się mój kod, który jest prostą siecią dwuwarstwową. Pobiera dwuwymiarowe dane kształtu (3, 2) i generuje jednowymiarowe dane kształtu (1, 4):
model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')
x = np.array([[[1, 2], [3, 4], [5, 6]]])
y = model.predict(x)
print y.shape
To drukuje, że y
ma kształt (1, 4). Jeśli jednak usunę Flatten
linię, to wydrukuje się, która y
ma kształt (1, 3, 4).
Nie rozumiem tego. Z mojego rozumienia sieci neuronowych model.add(Dense(16, input_shape=(3, 2)))
wynika , że funkcja ta tworzy ukrytą, w pełni połączoną warstwę z 16 węzłami. Każdy z tych węzłów jest podłączony do każdego z elementów wejściowych 3x2. Dlatego 16 węzłów na wyjściu tej pierwszej warstwy jest już „płaskich”. Zatem wyjściowy kształt pierwszej warstwy powinien wynosić (1, 16). Następnie druga warstwa przyjmuje to jako dane wejściowe i wyprowadza dane kształtu (1, 4).
Jeśli więc wydruk pierwszej warstwy jest już „płaski” i ma kształt (1, 16), po co mam go dalej spłaszczać?
Dense(16, input_shape=(5,3)
czy każdy neuron wyjściowy z zestawu 16 (i dla wszystkich 5 zestawów tych neuronów) będzie połączony ze wszystkimi (3 x 5 = 15) neuronami wejściowymi? Czy też każdy neuron w pierwszym zestawie 16 będzie połączony tylko z 3 neuronami w pierwszym zestawie 5 neuronów wejściowych, a następnie każdy neuron w drugim zestawie 16 będzie podłączony tylko do 3 neuronów w drugim zestawie 5 wejść neurony, itp .... Nie wiem, który to jest!