Tensorflow: jak zapisać / przywrócić model?


552

Po wytrenowaniu modelu w Tensorflow:

  1. Jak zapisać wyszkolony model?
  2. Jak później przywrócić ten zapisany model?

Czy udało Ci się przywrócić zmienne użyte w modelu początkowym? Próbuję również dokładnie tego samego problemu, ale nie jestem w stanie napisać zestawu zmiennych, które były używane podczas szkolenia modelu początkowego (którego mam plik ckpt)
exAres

Nie próbowałem z modelem początkowym. Czy masz strukturę sieci modelu z jego nazwami? Musisz zreplikować sieć, a następnie załadować wagi i obciążenia (plik ckpt), jak wyjaśnia Ryan. Może coś się zmieniło od 15 listopada i jest teraz prostsze podejście, nie jestem pewien
mathetes

Och w porządku. Wcześniej załadowałem inne wstępnie wyszkolone modele przepływu tensorowego, ale szukałem zmiennych specyfikacji modelu początkowego. Dzięki.
exAres

1
Jeśli przywrócisz, aby kontynuować trening, użyj punktów kontrolnych Oszczędzanie. Jeśli zapiszesz model w celu wykonania odwołania, tylko interfejsy API SavedModel tensorflow.
HY G

Również jeśli używasz LSTM, będziesz miał mapę od ciągu do listy znaków, pamiętaj, aby zapisać i załadować tę listę w tej samej kolejności! Nie jest to objęte zapisywaniem wag modelu i sieci wykresów modelu i sprawi, że będzie wyglądać, jakby Twój model nie został załadowany podczas zmiany sesji lub zmian danych.
devssh

Odpowiedzi:


119

Dokumenty

Z dokumentów:

Zapisać

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)

Przywracać

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

Przepływ Tensor 2

To wciąż wersja beta, więc odradzam na razie. Jeśli nadal chcesz iść tą drogą, tutaj jesttf.saved_model instrukcją użytkowania

Przepływ Tensor <2

simple_save

Wiele dobrych odpowiedzi, dla kompletności dodam moje 2 centy: simple_save . Również przykład samodzielnego kodu przy użyciutf.data.Dataset API.

Python 3; Przepływ Tensor 1.14

import tensorflow as tf
from tensorflow.saved_model import tag_constants

with tf.Graph().as_default():
    with tf.Session() as sess:
        ...

        # Saving
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )

Przywracanie:

graph = tf.Graph()
with restored_graph.as_default():
    with tf.Session() as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )
        batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
        prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })

Przykład samodzielny

Oryginalny post na blogu

Poniższy kod generuje losowe dane na potrzeby demonstracji.

  1. Zaczynamy od utworzenia symboli zastępczych. Będą przechowywać dane w czasie wykonywania. Z nich tworzymy, Dataseta następnie jego Iterator. Otrzymujemy wygenerowany tensor iteratora, tzwinput_tensor który posłuży jako dane wejściowe do naszego modelu.
  2. Sam model jest zbudowany z input_tensor : dwukierunkowego RNN opartego na GRU, a następnie gęstego klasyfikatora. Bo czemu nie.
  3. Strata jest softmax_cross_entropy_with_logitszoptymalizowana Adam. Po 2 epokach (po 2 partie każda) zapisujemy „wytrenowany” model za pomocą tf.saved_model.simple_save. Jeśli uruchomisz kod w obecnej postaci, model zostanie zapisany w folderze o nazwiesimple/ w obecnej w bieżącym katalogu roboczym.
  4. Na nowym wykresie przywracamy zapisany model za pomocą tf.saved_model.loader.load. Chwytamy symbole zastępcze i logi za pomocą graph.get_tensor_by_namei Iteratorinicjujemy operację za pomocągraph.get_operation_by_name .
  5. Na koniec przeprowadzamy wnioskowanie dla obu partii w zestawie danych i sprawdzamy, czy zarówno zapisany, jak i przywrócony model dają te same wartości. Oni robią!

Kod:

import os
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants


def model(graph, input_tensor):
    """Create the model which consists of
    a bidirectional rnn (GRU(10)) followed by a dense classifier

    Args:
        graph (tf.Graph): Tensors' graph
        input_tensor (tf.Tensor): Tensor fed as input to the model

    Returns:
        tf.Tensor: the model's output layer Tensor
    """
    cell = tf.nn.rnn_cell.GRUCell(10)
    with graph.as_default():
        ((fw_outputs, bw_outputs), (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell,
            cell_bw=cell,
            inputs=input_tensor,
            sequence_length=[10] * 32,
            dtype=tf.float32,
            swap_memory=True,
            scope=None)
        outputs = tf.concat((fw_outputs, bw_outputs), 2)
        mean = tf.reduce_mean(outputs, axis=1)
        dense = tf.layers.dense(mean, 5, activation=None)

        return dense


def get_opt_op(graph, logits, labels_tensor):
    """Create optimization operation from model's logits and labels

    Args:
        graph (tf.Graph): Tensors' graph
        logits (tf.Tensor): The model's output without activation
        labels_tensor (tf.Tensor): Target labels

    Returns:
        tf.Operation: the operation performing a stem of Adam optimizer
    """
    with graph.as_default():
        with tf.variable_scope('loss'):
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=labels_tensor, name='xent'),
                    name="mean-xent"
                    )
        with tf.variable_scope('optimizer'):
            opt_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
        return opt_op


if __name__ == '__main__':
    # Set random seed for reproducibility
    # and create synthetic data
    np.random.seed(0)
    features = np.random.randn(64, 10, 30)
    labels = np.eye(5)[np.random.randint(0, 5, (64,))]

    graph1 = tf.Graph()
    with graph1.as_default():
        # Random seed for reproducibility
        tf.set_random_seed(0)
        # Placeholders
        batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
        features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph')
        labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')
        # Dataset
        dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
        dataset = dataset.batch(batch_size_ph)
        iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
        dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
        input_tensor, labels_tensor = iterator.get_next()

        # Model
        logits = model(graph1, input_tensor)
        # Optimization
        opt_op = get_opt_op(graph1, logits, labels_tensor)

        with tf.Session(graph=graph1) as sess:
            # Initialize variables
            tf.global_variables_initializer().run(session=sess)
            for epoch in range(3):
                batch = 0
                # Initialize dataset (could feed epochs in Dataset.repeat(epochs))
                sess.run(
                    dataset_init_op,
                    feed_dict={
                        features_data_ph: features,
                        labels_data_ph: labels,
                        batch_size_ph: 32
                    })
                values = []
                while True:
                    try:
                        if epoch < 2:
                            # Training
                            _, value = sess.run([opt_op, logits])
                            print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                            batch += 1
                        else:
                            # Final inference
                            values.append(sess.run(logits))
                            print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                            batch += 1
                    except tf.errors.OutOfRangeError:
                        break
            # Save model state
            print('\nSaving...')
            cwd = os.getcwd()
            path = os.path.join(cwd, 'simple')
            shutil.rmtree(path, ignore_errors=True)
            inputs_dict = {
                "batch_size_ph": batch_size_ph,
                "features_data_ph": features_data_ph,
                "labels_data_ph": labels_data_ph
            }
            outputs_dict = {
                "logits": logits
            }
            tf.saved_model.simple_save(
                sess, path, inputs_dict, outputs_dict
            )
            print('Ok')
    # Restoring
    graph2 = tf.Graph()
    with graph2.as_default():
        with tf.Session(graph=graph2) as sess:
            # Restore saved values
            print('\nRestoring...')
            tf.saved_model.loader.load(
                sess,
                [tag_constants.SERVING],
                path
            )
            print('Ok')
            # Get restored placeholders
            labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
            features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
            batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
            # Get restored model output
            restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
            # Get dataset initializing operation
            dataset_init_op = graph2.get_operation_by_name('dataset_init')

            # Initialize restored dataset
            sess.run(
                dataset_init_op,
                feed_dict={
                    features_data_ph: features,
                    labels_data_ph: labels,
                    batch_size_ph: 32
                }

            )
            # Compute inference for both batches in dataset
            restored_values = []
            for i in range(2):
                restored_values.append(sess.run(restored_logits))
                print('Restored values: ', restored_values[i][0])

    # Check if original inference and restored inference are equal
    valid = all((v == rv).all() for v, rv in zip(values, restored_values))
    print('\nInferences match: ', valid)

Spowoduje to wydrukowanie:

$ python3 save_and_restore.py

Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595   0.12804556  0.20013677 -0.08229901]
Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045  -0.00107776]
Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792  -0.00602257  0.07465433  0.11674127]
Epoch 1, batch 1 | Sample value: [-0.05275984  0.05981954 -0.15913513 -0.3244143   0.10673307]
Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Saving...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
Ok

Restoring...
INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
Ok
Restored values:  [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Restored values:  [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Inferences match:  True

1
Jestem początkującym i potrzebuję więcej wyjaśnień ...: Jeśli mam model CNN, czy powinienem przechowywać tylko 1. input_placeholder 2. labels_placeholder i 3. output_of_cnn? Czy cały półprodukt tf.contrib.layers?
Pada

2
Wykres został całkowicie przywrócony. Możesz to sprawdzić [n.name for n in graph2.as_graph_def().node]. Jak głosi dokumentacja, proste zapisywanie ma na celu uproszczenie interakcji z obsługą tensorflow, o to właśnie chodzi w argumentach; inne zmienne są jednak nadal przywracane, w przeciwnym razie wnioskowanie nie nastąpiłoby. Po prostu chwyć swoje zmienne zainteresowania, tak jak w przykładzie. Sprawdź dokumentację
ted

@ted kiedy miałbym użyć tf.saved_model.simple_save vs. tf.train.Saver ()? Z mojej intuicji korzystałbym z tf.train.Saver () podczas treningu i do przechowywania różnych momentów w czasie. Używałbym tf.saved_model.simple_save, gdy szkolenie jest gotowe do użycia w produkcji. (To samo zapytałem również w komentarzu tutaj )
loco.loop,

1
Chyba fajnie, ale czy działa również z modelami w trybie Eager i tfe.Saver?
Geoffrey Anderson

1
bez global_stepargumentu, jeśli przestaniesz, a następnie spróbujesz ponownie rozpocząć trening, to pomyślisz, że jesteś krok do przodu. Przynajmniej spieszy twoje wizualizacje tensorboardów
Monica Heddneck

252

Poprawiam swoją odpowiedź, aby dodać więcej szczegółów dotyczących zapisywania i przywracania modeli.

W (i po) wersji Tensorflow 0.11 :

Zapisz model:

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

Przywróć model:

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 

To i niektóre bardziej zaawansowane przypadki użycia zostały tutaj bardzo dobrze wyjaśnione.

Szybki kompletny samouczek do zapisywania i przywracania modeli Tensorflow


3
+1 za to # Dostęp do zapisanych zmiennych Drukowane bezpośrednio (sess.run ('bias: 0')) # Spowoduje to wydrukowanie 2, czyli wartości uprzednio zapisanego. Bardzo pomaga w debugowaniu, aby sprawdzić, czy model jest poprawnie załadowany. zmienne można uzyskać za pomocą „All_varaibles = tf.get_collection (tf.GraphKeys.GLOBAL_VARIABLES”. Również „sess.run (tf.global_variables_initializer ())” musi być przed przywróceniem.
LGG

1
Czy na pewno musimy ponownie uruchomić global_variables_initializer? Przywróciłem mój wykres za pomocą global_variable_initialization i daje mi to za każdym razem inne dane wyjściowe dla tych samych danych. Skomentowałem więc inicjalizację i właśnie przywróciłem wykres, zmienną wejściową i operacje, a teraz działa dobrze.
Aditya Shinde

@AdityaShinde Nie rozumiem, dlaczego zawsze otrzymuję różne wartości za każdym razem. I nie uwzględniłem kroku inicjalizacji zmiennej w celu przywrócenia. Używam własnego kodu btw.
Chaine

@AdityaShinde: nie potrzebujesz init op, ponieważ wartości są już inicjalizowane przez funkcję przywracania, więc ją usunąłeś. Nie jestem jednak pewien, dlaczego otrzymałeś różne dane wyjściowe za pomocą init op.
sankit

5
@sankit Kiedy przywracasz tensory, dlaczego dodajesz :0do nazw?
Sahar Rabinoviz

177

W (i późniejszej) wersji TensorFlow 0.11.0RC1 możesz zapisać i przywrócić swój model bezpośrednio, dzwoniąc tf.train.export_meta_graphi tf.train.import_meta_graphzgodnie z https://www.tensorflow.org/programmers_guide/meta_graph .

Zapisz model

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta

Przywróć model

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

4
jak załadować zmienne z zapisanego modelu? Jak skopiować wartości do innej zmiennej?
neel

9
Nie mogę uruchomić tego kodu. Model zostanie zapisany, ale nie mogę go przywrócić. Daje mi ten błąd. <built-in function TF_Run> returned a result with an error set
Saad Qureshi

2
Kiedy po przywróceniu uzyskuję dostęp do zmiennych, jak pokazano powyżej, działa. Ale nie mogę uzyskać zmiennych bardziej bezpośrednio przy użyciu, tf.get_variable_scope().reuse_variables()a następnie var = tf.get_variable("varname"). Daje mi to błąd: „Błąd wartości: Zmienna nazwa zmiennej nie istnieje lub nie została utworzona za pomocą tf.get_variable ().” Dlaczego? Czy nie powinno to być możliwe?
Johann Petrak

4
Działa to dobrze tylko w przypadku zmiennych, ale w jaki sposób można uzyskać dostęp do symbolu zastępczego i wartości do niego po przywróceniu wykresu?
kbrose

11
To pokazuje tylko, jak przywrócić zmienne. Jak przywrócić cały model i przetestować go na nowych danych bez ponownego definiowania sieci?
Chaine

127

Dla wersji TensorFlow <0.11.0RC1:

Zapisane punkty kontrolne zawierają wartości dla Variable s w modelu, a nie sam model / wykres, co oznacza, że ​​wykres powinien być taki sam podczas przywracania punktu kontrolnego.

Oto przykład regresji liniowej, w której istnieje pętla treningowa, która zapisuje zmienne punkty kontrolne, oraz sekcja oceny, która przywróci zmienne zapisane w poprzednim przebiegu i obliczy prognozy. Oczywiście możesz także przywrócić zmienne i kontynuować trening, jeśli chcesz.

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))

...more setup for optimization and what not...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in xrange(FLAGS.training_steps):
            ...training loop...
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...no checkpoint found...

        # Now you can run the model to get predictions
        batch_x = ...load some data...
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

Oto dokumenty dotyczące Variables, które obejmują zapisywanie i przywracanie. A oto dokumenty dla Saver.


1
FLAGI są zdefiniowane przez użytkownika. Oto przykład ich zdefiniowania: github.com/tensorflow/tensorflow/blob/master/tensorflow/...
Ryan Sepassi

w jakim formacie batch_xmusi być? Dwójkowy? Tablica Numpy?
pepe,

@pepe Numpy Arrary powinno być w porządku. A typ elementu powinien odpowiadać typowi symbolu zastępczego. [link] tensorflow.org/versions/r0.9/api_docs/python/…
Donny

FLAGI daje błąd undefined. Czy możesz mi powiedzieć, która z definicji FLAGS dla tego kodu. @RyanSepassi
Muhammad Hannan,

Żeby było wyraźne: Najnowsze wersje Tensorflow nie pozwalają przechowywać modelu / wykres. [Nie było dla mnie jasne, które aspekty odpowiedzi dotyczą ograniczenia <0,11. Biorąc pod uwagę dużą liczbę pozytywnych opinii, kusiło mnie, aby uwierzyć, że to ogólne stwierdzenie jest nadal prawdziwe w przypadku najnowszych wersji.]
bluenote10 18.04.17

78

Moje środowisko: Python 3.6, Tensorflow 1.3.0

Chociaż istnieje wiele rozwiązań, większość z nich jest oparta tf.train.Saver. Kiedy załadować .ckptzapisany przez Savermusimy albo przedefiniować sieć tensorflow lub użyć trochę dziwne i ciężko pamiętał nazwę, na przykład 'placehold_0:0', 'dense/Adam/Weight:0'. Tutaj polecam skorzystać z tf.saved_modeljednego najprostszego przykładu podanego poniżej, aby dowiedzieć się więcej na temat obsługi modelu TensorFlow :

Zapisz model:

import tensorflow as tf

# define the tensorflow network and do some trains
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")

h = tf.multiply(x, w)
y = tf.add(h, b, name="y")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# save the model
export_path =  './savedmodel'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
  tf.saved_model.signature_def_utils.build_signature_def(
      inputs={'x_input': tensor_info_x},
      outputs={'y_output': tensor_info_y},
      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

builder.add_meta_graph_and_variables(
  sess, [tf.saved_model.tag_constants.SERVING],
  signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          prediction_signature 
  },
  )
builder.save()

Załaduj model:

import tensorflow as tf
sess=tf.Session() 
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'x_input'
output_key = 'y_output'

export_path =  './savedmodel'
meta_graph_def = tf.saved_model.loader.load(
           sess,
          [tf.saved_model.tag_constants.SERVING],
          export_path)
signature = meta_graph_def.signature_def

x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name

x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)

y_out = sess.run(y, {x: 3.0})

4
+1 za świetny przykład interfejsu API SavedModel. Chciałbym jednak, aby sekcja Zapisz model pokazała pętlę treningową, taką jak odpowiedź Ryana Sepassi! Zdaję sobie sprawę, że to stare pytanie, ale ta odpowiedź jest jednym z niewielu (i cennych) przykładów SavedModel, które znalazłem w Google.
Dylan F

@Tom To świetna odpowiedź - tylko jedna skierowana na nowy model SavedModel. Czy mógłbyś rzucić okiem na to pytanie SavedModel? stackoverflow.com/questions/48540744/…
bluesummers

Teraz spraw, by wszystko działało poprawnie z modelami TF Eager. Google doradzało w prezentacji 2018, aby każdy mógł uciec od kodu wykresu TF.
Geoffrey Anderson

55

Model składa się z dwóch części, definicji modelu, zapisanej przez Supervisorjak graph.pbtxtw katalogu modelu oraz wartości liczbowych tensorów, zapisanych w plikach punktów kontrolnych, takich jakmodel.ckpt-1003418 .

Definicję modelu można przywrócić za pomocą tf.import_graph_def, a wagi przywrócić za pomocąSaver .

Jednakże Saverwykorzystuje specjalną kolekcję listę zmiennych, które jest dołączone do modelu Graph gospodarstwa, a ta kolekcja nie jest inicjowany za pomocą import_graph_def, więc nie można korzystać z dwóch razem w tej chwili (jest na naszej mapie drogowej do poprawki). Na razie musisz użyć podejścia Ryana Sepassi - ręcznie skonstruuj wykres z identycznymi nazwami węzłów i użyj, Saveraby załadować do niego wagi.

(Alternatywnie możesz zhakować go, używając import_graph_def, tworząc ręcznie zmienne i używając tf.add_to_collection(tf.GraphKeys.VARIABLES, variable)dla każdej zmiennej, a następnie używając Saver)


W przykładzie classify_image.py, który używa inceptionv3, ładowany jest tylko graphdef. Czy to oznacza, że ​​teraz GraphDef zawiera również zmienną?
jrabary

1
@jrabary Model prawdopodobnie został zamrożony .
Eric Platon

1
Hej, jestem nowy w tensorflow i mam problem z zapisaniem mojego modelu. Byłbym bardzo wdzięczny, gdybyś mógł mi pomóc stackoverflow.com/questions/48083474/...
Ruchir Baronia

39

Możesz także wybrać ten łatwiejszy sposób.

Krok 1: zainicjuj wszystkie zmienne

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, .....

Krok 2: zapisz sesję w modelu Saveri zapisz ją

model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

Krok 3: przywróć model

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

Krok 4: sprawdź swoją zmienną

W1 = session.run(W1)
print(W1)

Korzystając z innej instancji Pythona, użyj

with tf.Session() as sess:
    # Restore latest checkpoint
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # Initalize the variables
    sess.run(tf.global_variables_initializer())

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = session.run(W1)

Cześć. Jak mogę zapisać model po przypuszczeniu 3000 iteracji, podobnych do Caffe. Dowiedziałem się, że tensorflow zapisuje tylko ostatnie modele, mimo że łączę liczbę iteracji z modelem, aby rozróżnić ją między wszystkimi iteracjami. Mam na myśli model_3000.ckpt, model_6000.ckpt, --- model_100000.ckpt. Czy możesz uprzejmie wyjaśnić, dlaczego nie zapisuje wszystkich, a zapisuje tylko 3 ostatnie iteracje.
khan


3
Czy istnieje metoda na zapisanie wszystkich zmiennych / nazw operacji na wykresie?
Moondra,

21

W większości przypadków tf.train.Savernajlepszym rozwiązaniem jest zapisywanie i przywracanie z dysku za pomocą :

... # build your model
saver = tf.train.Saver()

with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

Możesz także zapisać / przywrócić samą strukturę wykresu (szczegóły w dokumentacji MetaGraph ). Domyślnie Saverzapisuje strukturę wykresu w .metapliku. Możesz zadzwonić, import_meta_graph()aby go przywrócić. Przywraca strukturę wykresu i zwraca wartość Saver, której można użyć do przywrócenia stanu modelu:

saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

Są jednak przypadki, w których potrzebujesz czegoś znacznie szybciej. Na przykład, jeśli wdrożysz wczesne zatrzymywanie, chcesz zapisywać punkty kontrolne za każdym razem, gdy model poprawia się podczas treningu (mierzony na podstawie zestawu sprawdzania poprawności), a następnie, jeśli nie ma postępu przez pewien czas, chcesz przywrócić najlepszy model. Jeśli zapiszesz model na dysku za każdym razem, gdy poprawi się, ogromnie spowolni trening. Sztuką jest zapisanie stanów zmiennych w pamięci , a następnie przywrócenie ich później:

... # build your model

# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]

with tf.Session() as sess:
    ... # train the model

    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)

    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)

Szybkie wyjaśnienie: po utworzeniu zmiennej XTensorFlow automatycznie tworzy operację przypisania w X/Assigncelu ustawienia wartości początkowej zmiennej. Zamiast tworzyć symbole zastępcze i dodatkowe operacje przypisania (co spowodowałoby bałagan na wykresie), po prostu używamy tych istniejących operacji przypisania. Pierwsze wejście każdej operacji przypisania jest odwołaniem do zmiennej, którą ma zainicjować, a drugie wejście ( assign_op.inputs[1]) jest wartością początkową. Aby więc ustawić dowolną wartość (zamiast wartości początkowej), musimy użyć a feed_dicti zastąpić wartość początkową. Tak, TensorFlow pozwala podać wartość dla dowolnej operacji, nie tylko dla symboli zastępczych, więc to działa dobrze.


Dziękuję za odpowiedź. Mam podobne pytanie, jak przekonwertować pojedynczy plik .ckpt na dwa pliki .index i .data (na przykład w przypadku wstępnie przeszkolonych modeli początkowych dostępnych na tf.slim). Moje pytanie jest tutaj: stackoverflow.com/questions/47762114/…
Amir

Hej, jestem nowy w tensorflow i mam problem z zapisaniem mojego modelu. Byłbym bardzo wdzięczny, gdybyś mógł mi pomóc stackoverflow.com/questions/48083474/...
Ruchir Baronia

17

Jak powiedział Jarosław, możesz zhakować przywracanie z graph_def i punktu kontrolnego, importując wykres, ręcznie tworząc zmienne, a następnie używając wygaszacza.

Zaimplementowałem to na własny użytek, więc pomyślałem, że podzielę się tutaj kodem.

Link: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(Jest to oczywiście włamanie i nie ma gwarancji, że zapisane w ten sposób modele pozostaną czytelne w przyszłych wersjach TensorFlow.)


14

Jeśli jest to model zapisany wewnętrznie, po prostu określasz restauratora dla wszystkich zmiennych jako

restorer = tf.train.Saver(tf.all_variables())

i użyj go do przywrócenia zmiennych w bieżącej sesji:

restorer.restore(self._sess, model_file)

W przypadku modelu zewnętrznego musisz określić odwzorowanie jego nazw zmiennych na nazwy zmiennych. Możesz wyświetlić nazwy zmiennych modelu za pomocą polecenia

python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt

Skrypt inspect_checkpoint.py można znaleźć w folderze „./tensorflow/python/tools” źródła Tensorflow.

Aby określić mapowanie, możesz użyć mojego Tensorflow-Worklab , który zawiera zestaw klas i skryptów do trenowania i przekwalifikowywania różnych modeli. Zawiera przykład przekwalifikowania modeli ResNet, który znajduje się tutaj


all_variables()jest teraz przestarzałe
MiniQuark

Hej, jestem nowy w tensorflow i mam problem z zapisaniem mojego modelu. Byłbym bardzo wdzięczny, gdybyś mógł mi pomóc stackoverflow.com/questions/48083474/...
Ruchir Baronia

12

Oto moje proste rozwiązanie dwóch podstawowych przypadków różniących się od tego, czy chcesz załadować wykres z pliku, czy skompilować go w czasie wykonywania.

Ta odpowiedź dotyczy Tensorflow 0.12+ (w tym 1.0).

Przebudowa wykresu w kodzie

Oszczędność

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Ładowanie

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever

Ładowanie również wykresu z pliku

Korzystając z tej techniki, upewnij się, że wszystkie twoje warstwy / zmienne mają jawnie ustawione unikalne nazwy.W przeciwnym razie Tensorflow sprawi, że nazwy będą unikalne, a zatem będą różne od nazw przechowywanych w pliku. W poprzedniej technice nie stanowi to problemu, ponieważ nazwy są „zniekształcane” w ten sam sposób zarówno podczas ładowania, jak i zapisywania.

Oszczędność

graph = ... # build the graph

for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Ładowanie

with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection

-1 Rozpoczęcie odpowiedzi od odrzucenia „wszystkich innych odpowiedzi tutaj” jest nieco trudne. Powiedziałem jednak, że przegłosowałem z innych powodów: zdecydowanie powinieneś zapisać wszystkie zmienne globalne, a nie tylko zmienne możliwe do wyuczenia. Na przykład global_stepzmienna i średnie kroczące normalizacji partii są zmiennymi, których nie da się wyćwiczyć, ale zdecydowanie warto je zapisać. Ponadto należy wyraźniej odróżnić budowę wykresu od uruchomienia sesji, na przykład Saver(...).save()będzie tworzyć nowe węzły za każdym razem, gdy go uruchomisz. Prawdopodobnie nie to, czego chcesz. I jest więcej ...: /
MiniQuark

@MiniQuark ok, dziękuję za twoją opinię, zmienię odpowiedź zgodnie z twoimi sugestiami;)
Martin Pecka

10

Możesz także sprawdzić przykłady w TensorFlow / skflow , który oferuje metody savei restoremetody, które pomogą ci łatwo zarządzać swoimi modelami. Ma parametry, które możesz kontrolować, jak często chcesz tworzyć kopię zapasową modelu.


9

Jeśli używasz tf.train.MonitoredTrainingSession jako sesji domyślnej, nie musisz dodawać dodatkowego kodu, aby zapisywać / przywracać rzeczy. Po prostu przekaż nazwę kontrolną punktu kontrolnego konstruktorowi MonitoredTrainingSession, użyje haków sesji do ich obsługi.


użycie tf.train.Supervisor zajmie się tworzeniem dla ciebie takiej sesji i zapewni bardziej kompletne rozwiązanie.
Mark

1
@Mark tf.train.Supervisor jest wycofany
Changming Sun

Czy masz link do twierdzenia, że ​​Inspektor jest nieaktualny? Nie widziałem niczego, co by to wskazywało na to.
Mark


Dzięki za adres URL - sprawdziłem oryginalne źródło informacji i powiedziano mi, że prawdopodobnie będzie on dostępny do końca serii TF 1.x, ale potem nie ma żadnych gwarancji.
Mark

8

Wszystkie odpowiedzi tutaj są świetne, ale chcę dodać dwie rzeczy.

Po pierwsze, aby rozwinąć odpowiedź na temat @ user7505159, „./” może być ważne, aby dodać na początku przywracanej nazwy pliku.

Na przykład możesz zapisać wykres bez nazwy „./” w nazwie pliku w następujący sposób:

# Some graph defined up here with specific names

saver = tf.train.Saver()
save_file = 'model.ckpt'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

Ale w celu przywrócenia wykresu może być konieczne dodanie „./” do nazwy pliku:

# Same graph defined up here

saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_file)

Nie zawsze będziesz potrzebować „./”, ale może to powodować problemy w zależności od środowiska i wersji TensorFlow.

Warto również wspomnieć, że sess.run(tf.global_variables_initializer())może to być ważne przed przywróceniem sesji.

Jeśli pojawia się błąd dotyczący niezainicjowanych zmiennych podczas próby przywrócenia zapisanej sesji, upewnij się, że podałeś ją sess.run(tf.global_variables_initializer())przed saver.restore(sess, save_file)wierszem. Może zaoszczędzić ci bólu głowy.


7

Jak opisano w numerze 6255 :

use '**./**model_name.ckpt'
saver.restore(sess,'./my_model_final.ckpt')

zamiast

saver.restore('my_model_final.ckpt')

7

Według nowej wersji Tensorflow tf.train.Checkpointpreferowanym sposobem zapisywania i przywracania modelu jest:

Checkpoint.saveoraz Checkpoint.restorepisać i odczytywać punkty kontrolne oparte na obiektach, w przeciwieństwie do tf.train.Saver, który zapisuje i odczytuje punkty kontrolne oparte na zmiennej.name. Obiektowe punkty kontrolne zapisują wykres zależności między obiektami Pythona (Warstwy, Optymalizatory, Zmienne itp.) Z nazwanymi krawędziami, a ten wykres służy do dopasowania zmiennych podczas przywracania punktu kontrolnego. Może być bardziej odporny na zmiany w programie Python i pomaga wspierać przywracanie przy tworzeniu zmiennych podczas wykonywania z niecierpliwością. Wolę tf.train.Checkpointponad tf.train.Saverdla nowego kodu .

Oto przykład:

import tensorflow as tf
import os

tf.enable_eager_execution()

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

Więcej informacji i przykład tutaj.


7

W przypadku tensorflow 2.0 jest to tak proste

# Save the model
model.save('path_to_my_model.h5')

Przywrócić:

new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')

Co ze wszystkimi niestandardowymi operacjami tf i zmiennymi, które nie są częścią obiektu modelu? Czy zostaną w jakiś sposób zapisane, gdy wywołasz save () w modelu? Mam różne niestandardowe wyrażenia utraty i prawdopodobieństwa tensorflow, które są używane w sieci wnioskowania i generacji, ale nie są częścią mojego modelu. Obiekt modelu mojego keras zawiera tylko gęste i konwekcyjne warstwy. W TF 1 właśnie wywołałem metodę save i mogłem być pewien, że wszystkie operacje i tensory zastosowane na moim wykresie zostaną zapisane. W TF2 nie widzę, jak zostaną zapisane operacje, które nie zostały w jakiś sposób dodane do modelu Keras.
Kristof

Czy są jakieś dodatkowe informacje na temat przywracania modeli w TF 2.0? Nie mogę przywrócić wag z plików punktów kontrolnych wygenerowanych za pomocą interfejsu API C, patrz: stackoverflow.com/questions/57944786/...
jregalad 17.09.19


5

tf.keras Zapisywanie modelu za pomocą TF2.0

Widzę świetne odpowiedzi na temat zapisywania modeli za pomocą TF1.x. Chcę podać kilka dodatkowych wskazówek w zapisywaniu tensorflow.kerasmodeli, co jest nieco skomplikowane, ponieważ istnieje wiele sposobów zapisywania modelu.

Podaję przykład zapisywania tensorflow.kerasmodelu w model_pathfolderze w bieżącym katalogu. Działa to dobrze z najnowszym tensorflow (TF2.0). Zaktualizuję ten opis, jeśli w najbliższej przyszłości nastąpi jakakolwiek zmiana.

Zapisywanie i ładowanie całego modelu

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

#import data
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# create a model
def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
# compile the model
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
  return model

# Create a basic model instance
model=create_model()

model.fit(x_train, y_train, epochs=1)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save entire model to a HDF5 file
model.save('./model_path/my_model.h5')

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('./model_path/my_model.h5')
loss, acc = new_model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Zapisywanie i ładowanie modelu Tylko masy

Jeśli chcesz zapisać tylko masy modelu, a następnie załaduj wagi, aby przywrócić model, to

model.fit(x_train, y_train, epochs=5)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss,acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Zapisywanie i przywracanie za pomocą oddzwaniania punktu kontrolnego keras

# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights, every 5-epochs.
    period=5)

model = create_model()
model.save_weights(checkpoint_path.format(epoch=0))
model.fit(train_images, train_labels,
          epochs = 50, callbacks = [cp_callback],
          validation_data = (test_images,test_labels),
          verbose=0)

latest = tf.train.latest_checkpoint(checkpoint_dir)

new_model = create_model()
new_model.load_weights(latest)
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

zapisywanie modelu z niestandardowymi danymi

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Custom Loss1 (for example) 
@tf.function() 
def customLoss1(yTrue,yPred):
  return tf.reduce_mean(yTrue-yPred) 

# Custom Loss2 (for example) 
@tf.function() 
def customLoss2(yTrue, yPred):
  return tf.reduce_mean(tf.square(tf.subtract(yTrue,yPred))) 

def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),  
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy', customLoss1, customLoss2])
  return model

# Create a basic model instance
model=create_model()

# Fit and evaluate model 
model.fit(x_train, y_train, epochs=1)
loss, acc,loss1, loss2 = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

model.save("./model.h5")

new_model=tf.keras.models.load_model("./model.h5",custom_objects={'customLoss1':customLoss1,'customLoss2':customLoss2})

Zapisywanie modelu Keras z niestandardowymi operacjami

Kiedy mamy niestandardowe operacje, jak w poniższym przypadku ( tf.tile), musimy utworzyć funkcję i owinąć ją warstwą Lambda. W przeciwnym razie model nie zostanie zapisany.

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras import Model

def my_fun(a):
  out = tf.tile(a, (1, tf.shape(a)[0]))
  return out

a = Input(shape=(10,))
#out = tf.tile(a, (1, tf.shape(a)[0]))
out = Lambda(lambda x : my_fun(x))(a)
model = Model(a, out)

x = np.zeros((50,10), dtype=np.float32)
print(model(x).numpy())

model.save('my_model.h5')

#load the model
new_model=tf.keras.models.load_model("my_model.h5")

Myślę, że omówiłem kilka z wielu sposobów zapisywania modelu tf.keras. Istnieje jednak wiele innych sposobów. Skomentuj poniżej, jeśli widzisz, że Twój przypadek użycia nie jest uwzględniony powyżej. Dzięki!


3

Aby zapisać model, użyj tf.train.Saver, pamiętaj, że jeśli chcesz zmniejszyć rozmiar modelu, musisz podać listę var_list. Val_list może być tf.trainable_variables lub tf.global_variables.


3

Możesz zapisać zmienne w sieci za pomocą

saver = tf.train.Saver() 
saver.save(sess, 'path of save/fileName.ckpt')

Aby przywrócić sieć do ponownego użycia później lub w innym skrypcie, użyj:

saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('path of save/')
sess.run(....) 

Ważne punkty:

  1. sess muszą być takie same między pierwszym a późniejszym przebiegiem (spójna struktura).
  2. saver.restore potrzebuje ścieżki do folderu zapisanych plików, a nie pojedynczej ścieżki pliku.

2

Gdziekolwiek chcesz zapisać model,

self.saver = tf.train.Saver()
with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ...
            self.saver.save(sess, filename)

Upewnij się, że wszystkie tf.Variablemają nazwy, ponieważ możesz je później przywrócić, używając ich nazw. I gdzie chcesz przewidzieć,

saver = tf.train.import_meta_graph(filename)
name = 'name given when you saved the file' 
with tf.Session() as sess:
      saver.restore(sess, name)
      print(sess.run('W1:0')) #example to retrieve by variable name

Upewnij się, że wygaszacz działa w odpowiedniej sesji. Pamiętaj, że jeśli użyjesz, zostanie użyty tf.train.latest_checkpoint('./')tylko najnowszy punkt kontrolny.


2

Jestem w wersji:

tensorflow (1.13.1)
tensorflow-gpu (1.13.1)

Prosty sposób to

Zapisać:

model.save("model.h5")

Przywracać:

model = tf.keras.models.load_model("model.h5")

2

Dla tensorflow-2.0

to jest bardzo proste.

import tensorflow as tf

ZAPISAĆ

model.save("model_name")

PRZYWRACAĆ

model = tf.keras.models.load_model('model_name')

1

Po odpowiedzi @Vishnuvardhan Janapati, oto kolejny sposób na zapisanie i ponowne załadowanie modelu z niestandardową warstwą / metryką / utratą w TensorFlow 2.0.0

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils.generic_utils import get_custom_objects

# custom loss (for example)  
def custom_loss(y_true,y_pred):
  return tf.reduce_mean(y_true - y_pred)
get_custom_objects().update({'custom_loss': custom_loss}) 

# custom loss (for example) 
class CustomLayer(Layer):
  def __init__(self, ...):
      ...
  # define custom layer and all necessary custom operations inside custom layer

get_custom_objects().update({'CustomLayer': CustomLayer})  

W ten sposób, kiedy już wykonywane takie kody, a zapisany model z tf.keras.models.save_modellub model.savelub ModelCheckpointoddzwanianie, można ponownie załadować model bez konieczności precyzyjnego niestandardowych obiektów, tak proste, jak

new_model = tf.keras.models.load_model("./model.h5"})

0

W nowej wersji tensorflow 2.0 proces zapisywania / ładowania modelu jest znacznie łatwiejszy. Ze względu na implementację API Keras, API wysokiego poziomu dla TensorFlow.

Aby zapisać model: sprawdź dokumentację w celach informacyjnych: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model

tf.keras.models.save_model(model_name, filepath, save_format)

Aby załadować model:

https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model

model = tf.keras.models.load_model(filepath)

0

Oto prosty przykład za pomocą Tensorflow 2.0 SavedModel formatu (który jest zalecany format Według docs ) za pomocą prostego zestawu danych MNIST klasyfikatora, wykorzystując Keras API funkcjonalny bez zbyt fantazyjne dzieje:

# Imports
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt

# Load data
mnist = tf.keras.datasets.mnist # 28 x 28
(x_train,y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixels [0,255] -> [0,1]
x_train = tf.keras.utils.normalize(x_train,axis=1)
x_test = tf.keras.utils.normalize(x_test,axis=1)

# Create model
input = Input(shape=(28,28), dtype='float64', name='graph_input')
x = Flatten()(input)
x = Dense(128, activation='relu')(x)
x = Dense(128, activation='relu')(x)
output = Dense(10, activation='softmax', name='graph_output', dtype='float64')(x)
model = Model(inputs=input, outputs=output)

model.compile(optimizer='adam',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])

# Train
model.fit(x_train, y_train, epochs=3)

# Save model in SavedModel format (Tensorflow 2.0)
export_path = 'model'
tf.saved_model.save(model, export_path)

# ... possibly another python program 

# Reload model
loaded_model = tf.keras.models.load_model(export_path) 

# Get image sample for testing
index = 0
img = x_test[index] # I normalized the image on a previous step

# Predict using the signature definition (Tensorflow 2.0)
predict = loaded_model.signatures["serving_default"]
prediction = predict(tf.constant(img))

# Show results
print(np.argmax(prediction['graph_output']))  # prints the class number
plt.imshow(x_test[index], cmap=plt.cm.binary)  # prints the image

Co to jest serving_default?

Jest to nazwa def podpisu wybranego znacznika (w tym przypadku servewybrano domyślny znacznik). Również tutaj wyjaśnia, jak znaleźć-tych tagów i podpisów z wykorzystaniem modelu saved_model_cli.

Zastrzeżenia

To tylko podstawowy przykład, jeśli chcesz go uruchomić, ale w żadnym wypadku nie jest to kompletna odpowiedź - być może uda mi się go zaktualizować w przyszłości. Chciałem tylko podać prosty przykład z wykorzystaniem SavedModelTF 2.0, ponieważ nigdzie nie widziałem takiego, nawet takiego prostego.

Odpowiedź @ Toma to przykład SavedModel, ale nie będzie działać na Tensorflow 2.0, ponieważ niestety są pewne przełomowe zmiany.

@ Odpowiedź Vishnuvardhan Janapati mówi TF 2.0, ale nie dotyczy formatu SavedModel.

Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.