TensorFlow, dlaczego po zapisaniu modelu są 3 pliki?


113

Po przeczytaniu dokumentów zapisałem model w TensorFlow, oto mój kod demo:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

ale potem znalazłem 3 pliki

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

I nie mogę przywrócić modelu, przywracając plik model.ckpt plik, ponieważ nie ma takiego pliku. Oto mój kod

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Dlaczego więc są 3 pliki?


2
Czy wiesz, jak rozwiązać ten problem? Jak mogę ponownie załadować model (używając Keras)?
rajkiran

Odpowiedzi:


116

Spróbuj tego:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

Metoda zapisywania TensorFlow zapisuje trzy rodzaje plików, ponieważ przechowuje strukturę wykresu oddzielnie od wartości zmiennych . .metaPlik opisuje strukturę zapisany wykres, więc trzeba zaimportować go przed odtworzeniem punktu kontrolnego (w przeciwnym razie nie wie, co zmienne zapisane wartości odpowiadają punktów kontrolnych).

Alternatywnie możesz to zrobić:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

Nawet jeśli nie ma pliku o nazwie model.ckpt, nadal odwołujesz się do zapisanego punktu kontrolnego o tej nazwie podczas jego przywracania. Z saver.pykodu źródłowego :

Użytkownicy muszą tylko wchodzić w interakcje z przedrostkiem określonym przez użytkownika ... zamiast z jakąkolwiek fizyczną nazwą ścieżki.


1
więc .index i .data nie są używane? Kiedy są używane te 2 pliki?
ajfbiw.s

26
@ ajfbiw.s .meta przechowuje strukturę wykresu, .data przechowuje wartości każdej zmiennej na wykresie, .index identyfikuje checkpiont. W powyższym przykładzie: import_meta_graph używa .meta, a saver.restore używa .data i .index
TK Bartel

Rozumiem. Dzięki.
ajfbiw.s

1
Czy jest szansa, że ​​zapisałeś model w innej wersji TensorFlow niż ta, której używasz do załadowania? ( github.com/tensorflow/tensorflow/issues/5639 )
TK Bartel

5
Czy ktoś wie, co to 00000i 00001liczby oznaczają? w variables.data-?????-of-?????aktach
Ivan Talalaev

55
  • plik meta : opisuje zapisaną strukturę wykresu, zawiera GraphDef, SaverDef i tak dalej; następnie zastosuje tf.train.import_meta_graph('/tmp/model.ckpt.meta'), przywróci Saveri Graph.

  • plik indeksu : jest to niezmienna tabela typu string-string (tensorflow :: table :: Table). Każdy klucz jest nazwą tensora, a jego wartością jest serializowany BundleEntryProto. Każdy BundleEntryProto opisuje metadane tensora: który z plików "danych" zawiera zawartość tensora, przesunięcie w tym pliku, sumę kontrolną, niektóre dane pomocnicze itp.

  • plik danych : jest to kolekcja TensorBundle, zapisz wartości wszystkich zmiennych.


Mam plik pb, który mam do klasyfikacji obrazu. Czy mogę go używać do klasyfikacji wideo w czasie rzeczywistym?

Czy możesz mi dać znać, używając Keras 2, jak wczytać model, jeśli jest zapisany jako 3 pliki?
rajkiran

5

Przywracam wytrenowane osadzanie słów z samouczka tensorflow Word2Vec.

W przypadku, gdy utworzyłeś wiele punktów kontrolnych:

np. utworzone pliki wyglądają tak

model.ckpt-55695.data-00000-of-00001

model.ckpt-55695.index

model.ckpt-55695.meta

Spróbuj tego

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

podczas wywoływania restore_session ():

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")

Co to znaczy „00000-of-00001” w „model.ckpt-55695.data-00000-of-00001”?
hafiz031

0

Jeśli na przykład przeszkoliłeś CNN z porzuceniem, możesz zrobić to:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})
Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.