Najlepszy sposób na zapisanie wytrenowanego modelu w PyTorch?


193

Szukałem alternatywnych sposobów zapisania wytrenowanego modelu w PyTorch. Jak dotąd znalazłem dwie alternatywy.

  1. torch.save (), aby zapisać model i torch.load (), aby załadować model.
  2. model.state_dict (), aby zapisać wytrenowany model i model.load_state_dict (), aby załadować zapisany model.

Natknąłem się na tę dyskusję, w której podejście 2 jest zalecane zamiast podejścia 1.

Moje pytanie brzmi: dlaczego preferowane jest drugie podejście? Czy to tylko dlatego, że moduły torch.nn mają te dwie funkcje i zachęcamy do ich używania?


2
Myślę, że dzieje się tak, ponieważ torch.save () zapisuje również wszystkie zmienne pośrednie, takie jak wyjścia pośrednie do użytku z propagacją wsteczną. Ale wystarczy zapisać parametry modelu, takie jak waga / odchylenie itp. Czasami ten pierwszy może być znacznie większy niż drugi.
Dawei Yang

2
Testowałem torch.save(model, f)i torch.save(model.state_dict(), f). Zapisane pliki mają ten sam rozmiar. Teraz jestem zmieszany. Zauważyłem również, że używanie pickle do zapisywania model.state_dict () jest bardzo wolne. Myślę, że najlepszym sposobem jest użycie, torch.save(model.state_dict(), f)ponieważ zajmujesz się tworzeniem modelu, a latarka obsługuje ładowanie ciężarów modelu, eliminując w ten sposób możliwe problemy. Źródła
Dawei Yang

Wygląda na to, że PyTorch zajął się tym nieco bardziej szczegółowo w sekcji samouczków - jest tam wiele dobrych informacji, których nie ma w odpowiedziach, w tym zapisywanie więcej niż jednego modelu na raz i ciepłe modele startowe.
whlteXbread

co jest złego w używaniu pickle?
Charlie Parker

1
@CharlieParker torch.save jest oparty na marynacie. Poniższy tekst pochodzi z samouczka, do którego link znajduje się powyżej: „[torch.save] zapisze cały moduł za pomocą modułu pikle w języku Python. Wadą tego podejścia jest to, że serializowane dane są powiązane z określonymi klasami i dokładną strukturą katalogów używaną w modelu jest zapisany. Powodem tego jest to, że pickle nie zapisuje samej klasy modelu. Zamiast tego zapisuje ścieżkę do pliku zawierającego klasę, która jest używana podczas ładowania. Z tego powodu kod może się zepsuć na różne sposoby, gdy używany w innych projektach lub po refaktorach. "
David Miller

Odpowiedzi:


215

Znalazłem tę stronę w ich repozytorium github, po prostu wkleję tutaj zawartość.


Zalecane podejście do zapisywania modelu

Istnieją dwa główne podejścia do serializacji i przywracania modelu.

Pierwsza (zalecana) zapisuje i wczytuje tylko parametry modelu:

torch.save(the_model.state_dict(), PATH)

Później:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

Drugi zapisuje i wczytuje cały model:

torch.save(the_model, PATH)

Później:

the_model = torch.load(PATH)

Jednak w tym przypadku serializowane dane są powiązane z określonymi klasami i dokładną używaną strukturą katalogów, więc mogą ulec uszkodzeniu na różne sposoby, gdy są używane w innych projektach lub po poważnych refaktorach.


8
Zgodnie z @smth discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/ ... model domyślnie ładuje się ponownie, aby trenować model. więc trzeba ręcznie wywołać the_model.eval () po załadowaniu, jeśli ładujesz go w celu wnioskowania, a nie wznawiania uczenia.
WillZ

druga metoda daje stackoverflow.com/questions/53798009/ ... błąd w systemie Windows 10. nie był w stanie go rozwiązać
Gulzar

Czy jest jakaś opcja zapisywania bez potrzeby dostępu do klasy modelu?
Michael D

Przy takim podejściu, jak śledzić * args i ** kwargs, które musisz przekazać dla przypadku obciążenia?
Mariano Kamp

co jest złego w używaniu pickle?
Charlie Parker

144

To zależy od tego, co chcesz robić.

Przypadek 1: Zapisz model, aby użyć go samodzielnie do wnioskowania : Zapisujesz model, przywracasz go, a następnie zmieniasz model w tryb oceny. Dzieje się tak, ponieważ zwykle masz warstwy BatchNormi, Dropoutktóre domyślnie są w trybie pociągu podczas budowy:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Przypadek 2: Zapisz model, aby wznowić uczenie później : Jeśli chcesz dalej trenować model, który zamierzasz zapisać, musisz zapisać więcej niż tylko model. Musisz także zapisać stan optymalizatora, epoki, wynik itp. Zrobisz to w następujący sposób:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Aby wznowić szkolenie, wykonaj następujące czynności:, state = torch.load(filepath)a następnie, aby przywrócić stan każdego pojedynczego obiektu, coś takiego:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Ponieważ wznawiasz szkolenie, NIE dzwońmodel.eval() po przywróceniu stanów podczas ładowania.

Przypadek # 3: Model, który ma być używany przez inną osobę bez dostępu do Twojego kodu : W Tensorflow możesz utworzyć .pbplik, który definiuje zarówno architekturę, jak i wagi modelu. Jest to bardzo przydatne, szczególnie podczas używania Tensorflow serve. Równoważny sposób na zrobienie tego w Pytorch to:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

W ten sposób nadal nie jest kuloodporny, a ponieważ pytorch wciąż przechodzi wiele zmian, nie polecałbym go.


1
Czy istnieje zalecane zakończenie pliku dla 3 przypadków? A może zawsze .pth?
Verena Haunschmid

1
W przypadku # 3 torch.loadzwraca tylko OrderedDict. Jak uzyskać model, aby móc przewidywać?
Alber8295,

Cześć, czy mogę wiedzieć, jak wykonać wspomniany „Przypadek 2: Zapisz model, aby później wznowić szkolenie”? Udało mi się załadować punkt kontrolny do modelu, a następnie nie mogę uruchomić lub wznowić trenowania modelu, takiego jak „model.to (urządzenie) model = train_model_epoch (model, kryterium, optymalizator, harmonogram, epoki)”
dnez

1
Cześć, w pierwszym przypadku, który jest do wnioskowania, w oficjalnym dokumencie pytorcha powiedz, że należy zapisać optymalizator state_dict dla wnioskowania lub zakończenia szkolenia. „Podczas zapisywania ogólnego punktu kontrolnego, który ma być użyty do wnioskowania lub wznawiania uczenia, należy zapisać więcej niż tylko atrybut state_dict modelu. Ważne jest również zapisanie parametru state_dict optymalizatora, ponieważ zawiera on bufory i parametry, które są aktualizowane w miarę uczenia się modelu . ”
Mohammed Awney

1
W przypadku # 3 należy gdzieś zdefiniować klasę modelu.
Michael D

12

marynata Python implementuje protokoły binarne do serializacji i deserializacji obiektu Pythona.

Kiedy ty import torch(lub gdy używasz PyTorch) będzie to import pickledla ciebie i nie musisz wywoływać pickle.dump()i pickle.load()bezpośrednio, które są metodami zapisywania i ładowania obiektu.

W rzeczywistości torch.save()i torch.load()zawinie pickle.dump()i pickle.load()dla Ciebie.

ZA state_dictDruga odpowiedź wspomniano zasługuje tylko kilka dodatkowych uwag.

Co state_dictmamy w PyTorch? Właściwie są dwastate_dict .

Model PyTorch torch.nn.Modulema model.parameters()wywołanie, aby uzyskać parametry, których można się nauczyć (w i b). Te parametry, których można się nauczyć, raz ustawione losowo, będą aktualizowane w miarę upływu czasu. Parametry, których można się nauczyć, są pierwszymistate_dict .

Drugi state_dictto dyktowanie stanu optymalizatora. Przypominasz sobie, że optymalizator służy do poprawy parametrów, których można się nauczyć. Ale optymalizatorstate_dict jest naprawiony. Nie ma się tam czego nauczyć.

Ponieważ state_dictobiekty są słownikami Pythona, można je łatwo zapisywać, aktualizować, zmieniać i przywracać, dodając wiele modułowości do modeli i optymalizatorów PyTorch.

Stwórzmy super prosty model, aby to wyjaśnić:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Ten kod wyświetli następujące informacje:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Zwróć uwagę, że jest to model minimalny. Możesz spróbować dodać stos sekwencyjny

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Należy zauważyć, że tylko warstwy z parametrami, których można się nauczyć (warstwy splotowe, warstwy liniowe itp.) I zarejestrowane bufory (warstwy normalne partii) mają wpisy w modelu state_dict.

Rzeczy, których nie można się nauczyć, należą do obiektu optymalizatora state_dict , który zawiera informacje o stanie optymalizatora, a także o zastosowanych hiperparametrach.

Reszta historii jest taka sama; w fazie wnioskowania (jest to faza, w której używamy modelu po treningu) do prognozowania; przewidujemy na podstawie parametrów, których się nauczyliśmy. Tak więc do wnioskowania wystarczy zapisać parametry model.state_dict().

torch.save(model.state_dict(), filepath)

I użyć później model.load_state_dict (torch.load (filepath)) model.eval ()

Uwaga: nie zapomnij o ostatniej linii, model.eval()jest to kluczowe po załadowaniu modelu.

Nie próbuj też oszczędzać torch.save(model.parameters(), filepath). To model.parameters()tylko obiekt generatora.

Z drugiej strony torch.save(model, filepath)zapisuje sam obiekt modelu, ale pamiętaj, że model nie ma optymalizatora state_dict. Sprawdź inną doskonałą odpowiedź autorstwa @Jadiel de Armas, aby zapisać dyktando stanu optymalizatora.


Chociaż nie jest to proste rozwiązanie, istota problemu jest dogłębnie przeanalizowana! Głosuj za.
Jason Young

7

Powszechną konwencją PyTorch jest zapisywanie modeli przy użyciu rozszerzenia pliku .pt lub .pth.

Zapisz / wczytaj cały model Zapisz:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Załaduj:

Klasa modelu musi być gdzieś zdefiniowana

model = torch.load(PATH)
model.eval()

4

Jeśli chcesz zapisać model i później wznowić trening:

Pojedynczy GPU: Zapisz:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Załaduj:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

Wiele GPU: Zapisz

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Załaduj:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.