TensorFlow zapisuje do / ładuje wykres z pliku


98

Z tego, co do tej pory zebrałem, istnieje kilka różnych sposobów zrzucenia wykresu TensorFlow do pliku, a następnie załadowania go do innego programu, ale nie udało mi się znaleźć jasnych przykładów / informacji na temat ich działania. To, co już wiem, to:

  1. Zapisz zmienne modelu w pliku punktu kontrolnego (.ckpt) za pomocą a tf.train.Saver()i przywróć je później ( źródło )
  2. Zapisz model do pliku .pb i załaduj go z powrotem za pomocą tf.train.write_graph()i tf.import_graph_def()( źródło )
  3. Załaduj model z pliku .pb, przetrenuj go i wrzuć do nowego pliku .pb za pomocą Bazel ( źródło )
  4. Zablokuj wykres, aby zapisać wykres i wagi razem ( źródło )
  5. Służy as_graph_def()do zapisywania modelu, a dla wag / zmiennych mapowania ich na stałe ( źródło )

Nie udało mi się jednak wyjaśnić kilku pytań dotyczących tych różnych metod:

  1. Jeśli chodzi o pliki punktów kontrolnych, czy zapisują one tylko wyuczone wagi modelu? Czy pliki punktów kontrolnych mogą zostać załadowane do nowego programu i użyte do uruchomienia modelu, czy mogą po prostu służyć jako sposób na zapisanie wag w modelu w określonym czasie / etapie?
  2. W związku z tym tf.train.write_graph(), czy wagi / zmienne również są zapisane?
  3. Jeśli chodzi o Bazel, czy może on zapisywać do / ładować z plików .pb tylko w celu ponownego przeszkolenia? Czy istnieje proste polecenie Bazel, które służy tylko do zrzucenia wykresu do pliku .pb?
  4. Jeśli chodzi o zamrażanie, czy zamrożony wykres można załadować za pomocą tf.import_graph_def()?
  5. Wersja demonstracyjna systemu Android dla TensorFlow ładuje się w modelu Inception Google z pliku .pb. Gdybym chciał zastąpić własny plik .pb, jak bym to zrobił? Czy musiałbym zmienić kod / metody natywne?
  6. Jaka jest właściwie różnica między wszystkimi tymi metodami? Albo szerzej, jaka jest różnica między as_graph_def()/.ckpt/.pb?

Krótko mówiąc, szukam metody zapisywania zarówno wykresu (jak w przypadku różnych operacji itp.), Jak i jego wag / zmiennych do pliku, którego można następnie użyć do załadowania wykresu i wag do innego programu do użytku (niekoniecznie kontynuowanie / przekwalifikowanie).

Dokumentacja na ten temat nie jest prosta, więc wszelkie odpowiedzi / informacje byłyby bardzo mile widziane.


2
Najnowszym / najbardziej kompletnym API jest metagraf, który daje możliwość zapisania wszystkich trzech naraz - 1) wykres 2) wartości parametrów 3) kolekcje: tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html
Yaroslav Bulatov

Odpowiedzi:


80

Istnieje wiele sposobów podejścia do problemu zapisywania modelu w TensorFlow, co może być nieco zagmatwane. Biorąc po kolei każde z pytań podrzędnych:

  1. Pliki punktów kontrolnych (np produkowane przez wywołanie saver.save()na tf.train.Saverobiekcie) zawierają tylko ciężary oraz wszelkie inne zmienne zdefiniowane w tym samym programie. Aby użyć ich w innym programie, należy odtworzyć powiązaną strukturę wykresu (np. Uruchamiając kod w celu jego ponownego zbudowania lub wywołując tf.import_graph_def()), która powie TensorFlow, co ma zrobić z tymi wagami. Zwróć uwagę, że wywołanie saver.save()również tworzy plik zawierający a MetaGraphDef, który zawiera wykres i szczegóły dotyczące powiązania wag z punktu kontrolnego z tym wykresem. Więcej informacji znajdziesz w samouczku .

  2. tf.train.write_graph()zapisuje tylko strukturę grafu; nie ciężary.

  3. Bazel nie jest związany z czytaniem ani pisaniem wykresów TensorFlow. (Być może źle zrozumiałem twoje pytanie: możesz to wyjaśnić w komentarzu.)

  4. Zamrożony wykres można załadować za pomocą tf.import_graph_def(). W takim przypadku wagi są (zazwyczaj) osadzone na wykresie, więc nie ma potrzeby ładowania osobnego punktu kontrolnego.

  5. Główną zmianą byłoby zaktualizowanie nazw tensorów, które są wprowadzane do modelu, oraz nazw tensorów, które są pobierane z modelu. W wersji demonstracyjnej TensorFlow dla systemu Android odpowiadałoby to ciągom znaków inputNamei, outputNamektóre są przekazywane do TensorFlowClassifier.initializeTensorFlow().

  6. Jest GraphDefto struktura programu, która zazwyczaj nie zmienia się w trakcie procesu szkolenia. Punkt kontrolny to migawka stanu procesu szkolenia, który zwykle zmienia się na każdym etapie procesu szkolenia. W rezultacie TensorFlow używa różnych formatów przechowywania tych typów danych, a niskopoziomowy interfejs API zapewnia różne sposoby ich zapisywania i ładowania. Biblioteki wyższego poziomu, takie jak MetaGraphDefbiblioteki, Keras i skflow, opierają się na tych mechanizmach, aby zapewnić wygodniejsze sposoby zapisywania i przywracania całego modelu.


Czy to oznacza, że dokumentacja API C ++ kłamie, kiedy mówi, że można załadować zapisany wykres, tf.train.write_graph()a następnie go wykonać?
mnicky

2
Dokumentacja API C ++ nie kłamie, ale brakuje w niej kilku szczegółów. Najważniejszym szczegółem jest to, że oprócz GraphDefzapisanych przez tf.train.write_graph(), musisz również zapamiętać nazwy tensorów, które chcesz zasilić i pobrać podczas wykonywania wykresu (pozycja 5 powyżej).
pan

@mrry: Próbowałem użyć przykładu tensorflows DeepDream. ale wygląda na to, że potrzebuje wstępnie wytrenowanych modeli w formacie pb! Uruchomiłem przykład Cifar10, ale tworzy on tylko punkty kontrolne! Nie mogłem znaleźć żadnych plików PB ani czegokolwiek! jak mogę przekonwertować moje punkty kontrolne na format pb, którego używa przykład deepdream?
Rika

2
@ Coderx7 Naprawdę myślę, że nie można przekonwertować .ckpt na .pb, ponieważ punkt kontrolny zawiera tylko wagi i zmienne i nie wie nic o strukturze wykresu
davidivad

1
czy istnieje prosty kod do załadowania pliku .pb, a następnie uruchomienia go?
Kong

1

Możesz wypróbować następujący kod:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.