marynata Python implementuje protokoły binarne do serializacji i deserializacji obiektu Pythona.
Kiedy ty import torch
(lub gdy używasz PyTorch) będzie to import pickle
dla ciebie i nie musisz wywoływać pickle.dump()
i pickle.load()
bezpośrednio, które są metodami zapisywania i ładowania obiektu.
W rzeczywistości torch.save()
i torch.load()
zawinie pickle.dump()
i pickle.load()
dla Ciebie.
ZA state_dict
Druga odpowiedź wspomniano zasługuje tylko kilka dodatkowych uwag.
Co state_dict
mamy w PyTorch? Właściwie są dwastate_dict
.
Model PyTorch torch.nn.Module
ma model.parameters()
wywołanie, aby uzyskać parametry, których można się nauczyć (w i b). Te parametry, których można się nauczyć, raz ustawione losowo, będą aktualizowane w miarę upływu czasu. Parametry, których można się nauczyć, są pierwszymistate_dict
.
Drugi state_dict
to dyktowanie stanu optymalizatora. Przypominasz sobie, że optymalizator służy do poprawy parametrów, których można się nauczyć. Ale optymalizatorstate_dict
jest naprawiony. Nie ma się tam czego nauczyć.
Ponieważ state_dict
obiekty są słownikami Pythona, można je łatwo zapisywać, aktualizować, zmieniać i przywracać, dodając wiele modułowości do modeli i optymalizatorów PyTorch.
Stwórzmy super prosty model, aby to wyjaśnić:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Ten kod wyświetli następujące informacje:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
Zwróć uwagę, że jest to model minimalny. Możesz spróbować dodać stos sekwencyjny
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Należy zauważyć, że tylko warstwy z parametrami, których można się nauczyć (warstwy splotowe, warstwy liniowe itp.) I zarejestrowane bufory (warstwy normalne partii) mają wpisy w modelu state_dict
.
Rzeczy, których nie można się nauczyć, należą do obiektu optymalizatora state_dict
, który zawiera informacje o stanie optymalizatora, a także o zastosowanych hiperparametrach.
Reszta historii jest taka sama; w fazie wnioskowania (jest to faza, w której używamy modelu po treningu) do prognozowania; przewidujemy na podstawie parametrów, których się nauczyliśmy. Tak więc do wnioskowania wystarczy zapisać parametry model.state_dict()
.
torch.save(model.state_dict(), filepath)
I użyć później model.load_state_dict (torch.load (filepath)) model.eval ()
Uwaga: nie zapomnij o ostatniej linii, model.eval()
jest to kluczowe po załadowaniu modelu.
Nie próbuj też oszczędzać torch.save(model.parameters(), filepath)
. To model.parameters()
tylko obiekt generatora.
Z drugiej strony torch.save(model, filepath)
zapisuje sam obiekt modelu, ale pamiętaj, że model nie ma optymalizatora state_dict
. Sprawdź inną doskonałą odpowiedź autorstwa @Jadiel de Armas, aby zapisać dyktando stanu optymalizatora.