I już ułożyła jakiś kod przykład tensorflow aby pomóc wyjaśnić (pełny kod pracy jest w tym GIST ). Ten kod implementuje sieć kapsuł z pierwszej części sekcji 2 w dokumencie, który połączyłeś:
N_REC_UNITS = 10
N_GEN_UNITS = 20
N_CAPSULES = 30
# input placeholders
img_input_flat = tf.placeholder(tf.float32, shape=(None, 784))
d_xy = tf.placeholder(tf.float32, shape=(None, 2))
# translate the image according to d_xy
img_input = tf.reshape(img_input_flat, (-1, 28, 28, 1))
trans_img = image.translate(img_input, d_xy)
flat_img = tf.layers.flatten(trans_img)
capsule_img_list = []
# build several capsules and store the generated output in a list
for i in range(N_CAPSULES):
# hidden recognition layer
h_rec = tf.layers.dense(flat_img, N_REC_UNITS, activation=tf.nn.relu)
# inferred xy values
xy = tf.layers.dense(h_rec, 2) + d_xy
# inferred probability of feature
p = tf.layers.dense(h_rec, 1, activation=tf.nn.sigmoid)
# hidden generative layer
h_gen = tf.layers.dense(xy, N_GEN_UNITS, activation=tf.nn.relu)
# the flattened generated image
cap_img = p*tf.layers.dense(h_gen, 784, activation=tf.nn.relu)
capsule_img_list.append(cap_img)
# combine the generated images
gen_img_stack = tf.stack(capsule_img_list, axis=1)
gen_img = tf.reduce_sum(gen_img_stack, axis=1)
Czy ktoś wie, jak powinno działać mapowanie między pikselami wejściowymi a kapsułkami?
To zależy od struktury sieci. W pierwszym eksperymencie w tym artykule (i powyższym kodzie) każda kapsułka ma pole odbiorcze, które obejmuje cały obraz wejściowy. To najprostszy układ. W takim przypadku jest to w pełni połączona warstwa między obrazem wejściowym a pierwszą ukrytą warstwą w każdej kapsułce.
Alternatywnie, pola recepcyjne kapsułki można ułożyć bardziej jak jądra CNN z krokami, jak w późniejszych eksperymentach w tym artykule.
Co dokładnie powinno się dziać w jednostkach rozpoznających?
Jednostki rozpoznające są wewnętrzną reprezentacją każdej kapsułki. Każda kapsułka wykorzystuje tę wewnętrzną reprezentację do obliczenia p
, prawdopodobieństwa obecności funkcji kapsułki i xy
wywnioskowanych wartości translacji. Ryc. 2 w tym dokumencie to sprawdzenie, czy sieć uczy się obsługiwać xy
poprawnie (tak jest).
Jak należy go szkolić? Czy to tylko standardowa tylna podpora między każdym połączeniem?
W szczególności powinieneś trenować go jako autoencoder, używając straty, która wymusza podobieństwo między generowanym wyjściem a oryginałem. Średni błąd kwadratowy działa tutaj dobrze. Poza tym tak, musisz propagować opadanie gradientu za pomocą backprop.
loss = tf.losses.mean_squared_error(img_input_flat, gen_img)
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)