Zapisz klasyfikator na dysku w scikit-learn


192

Jak zapisać wyszkolonego klasyfikatora Naive Bayes na dysk i użyć go do przewidywania danych?

Mam następujący przykładowy program ze strony scikit-learn:

from sklearn import datasets
iris = datasets.load_iris()
from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()
y_pred = gnb.fit(iris.data, iris.target).predict(iris.data)
print "Number of mislabeled points : %d" % (iris.target != y_pred).sum()

Odpowiedzi:


201

Klasyfikatory to po prostu obiekty, które można marynować i zrzucić jak każdy inny. Aby kontynuować przykład:

import cPickle
# save the classifier
with open('my_dumped_classifier.pkl', 'wb') as fid:
    cPickle.dump(gnb, fid)    

# load it again
with open('my_dumped_classifier.pkl', 'rb') as fid:
    gnb_loaded = cPickle.load(fid)

1
Działa jak marzenie! Próbowałem użyć np.savez i ponownie go załadować i to nigdy nie pomogło. Wielkie dzięki.
Kartos

7
w python3 użyj modułu marynowanego, który działa dokładnie tak.
MCSH,

213

Możesz także użyć joblib.dump i joblib.load, który jest znacznie bardziej wydajny w obsłudze tablic numerycznych niż domyślny program wybierający Pythona.

Joblib jest zawarty w scikit-learn:

>>> import joblib
>>> from sklearn.datasets import load_digits
>>> from sklearn.linear_model import SGDClassifier

>>> digits = load_digits()
>>> clf = SGDClassifier().fit(digits.data, digits.target)
>>> clf.score(digits.data, digits.target)  # evaluate training error
0.9526989426822482

>>> filename = '/tmp/digits_classifier.joblib.pkl'
>>> _ = joblib.dump(clf, filename, compress=9)

>>> clf2 = joblib.load(filename)
>>> clf2
SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0,
       fit_intercept=True, learning_rate='optimal', loss='hinge', n_iter=5,
       n_jobs=1, penalty='l2', power_t=0.5, rho=0.85, seed=0,
       shuffle=False, verbose=0, warm_start=False)
>>> clf2.score(digits.data, digits.target)
0.9526989426822482

Edycja: w Pythonie 3.8+ można teraz używać piklowania do efektywnego trawienia obiektów z dużymi tablicami numerycznymi jako atrybutami, jeśli używasz protokołu piklowania 5 (co nie jest domyślne).


1
Ale z mojego zrozumienia, potokowanie działa, jeśli jest częścią jednego przepływu pracy. Jeśli chcę zbudować model, zapisz go na dysku i zatrzymaj tam wykonanie. Potem wracam tydzień później i próbuję załadować model z dysku, co powoduje błąd:
venuktan

2
Nie ma sposobu, aby zatrzymać i wznowić wykonanie fitmetody, jeśli tego właśnie szukasz. To powiedziawszy, joblib.loadnie powinno zgłaszać wyjątku po pomyślnym joblib.dumpwywołaniu go z Pythona z tą samą wersją biblioteki scikit-learn.
ogrisel

10
Jeśli używasz IPython, nie używaj --pylabflagi wiersza poleceń ani %pylabmagii, ponieważ wiadomo, że niejawne przeciążenie przestrzeni nazw przerywa proces wytrawiania. %matplotlib inlineZamiast tego użyj jawnego importu i magii.
ogrisel

2
zobacz dokumentację scikit-learn w celach informacyjnych: scikit-learn.org/stable/tutorial/basic/…
user1448319

1
Czy można ponownie przeszkolić wcześniej zapisany model? W szczególności modele SVC?
Uday Sawant

108

To, czego szukasz, nazywa się trwałością modelu w sklearnach i jest udokumentowane we wstępie oraz w sekcjach trwałości modelu .

Więc zainicjalizowałeś swój klasyfikator i trenowałeś go przez długi czas

clf = some.classifier()
clf.fit(X, y)

Po tym masz dwie opcje:

1) Za pomocą marynaty

import pickle
# now you can save it to a file
with open('filename.pkl', 'wb') as f:
    pickle.dump(clf, f)

# and later you can load it
with open('filename.pkl', 'rb') as f:
    clf = pickle.load(f)

2) Korzystanie z Joblib

from sklearn.externals import joblib
# now you can save it to a file
joblib.dump(clf, 'filename.pkl') 
# and later you can load it
clf = joblib.load('filename.pkl')

Jeszcze raz pomocne jest przeczytanie wyżej wymienionych linków


30

W wielu przypadkach, szczególnie w przypadku klasyfikacji tekstowej, nie wystarczy po prostu przechowywać klasyfikator, ale trzeba również przechowywać wektoryzator, aby wektoryzować dane wejściowe w przyszłości.

import pickle
with open('model.pkl', 'wb') as fout:
  pickle.dump((vectorizer, clf), fout)

przyszły przypadek użycia:

with open('model.pkl', 'rb') as fin:
  vectorizer, clf = pickle.load(fin)

X_new = vectorizer.transform(new_samples)
X_new_preds = clf.predict(X_new)

Przed zrzuceniem wektoryzatora można usunąć właściwość stop_words_ wektoryzatora poprzez:

vectorizer.stop_words_ = None

aby dumping był bardziej wydajny. Również jeśli parametry klasyfikatora są rzadkie (jak w większości przykładów klasyfikacji tekstu), możesz przekonwertować parametry z gęstego na rzadkie, co spowoduje ogromną różnicę pod względem zużycia pamięci, ładowania i zrzutu. Sparsify model przez:

clf.sparsify()

Który automatycznie będzie działał dla SGDClassifier, ale jeśli wiesz, że twój model jest rzadki (dużo zer w pliku clf.coef_), możesz ręcznie przekonwertować plik clf.coef_ na csr scipy rzadką macierz poprzez:

clf.coef_ = scipy.sparse.csr_matrix(clf.coef_)

a następnie możesz przechowywać go bardziej wydajnie.


Wnikliwa odpowiedź! Chciałem tylko dodać w przypadku SVC, zwraca rzadki parametr modelu.
Shayan Amani

5

sklearnestymatory wdrażają metody ułatwiające zapisywanie odpowiednich przeszkolonych właściwości estymatora. Niektóre estymatory implementują __getstate__same metody, ale inne, jak na przykład GMMużycie podstawowej implementacji, która po prostu zapisuje wewnętrzny słownik obiektów:

def __getstate__(self):
    try:
        state = super(BaseEstimator, self).__getstate__()
    except AttributeError:
        state = self.__dict__.copy()

    if type(self).__module__.startswith('sklearn.'):
        return dict(state.items(), _sklearn_version=__version__)
    else:
        return state

Zalecaną metodą zapisania modelu na dysku jest użycie picklemodułu:

from sklearn import datasets
from sklearn.svm import SVC
iris = datasets.load_iris()
X = iris.data[:100, :2]
y = iris.target[:100]
model = SVC()
model.fit(X,y)
import pickle
with open('mymodel','wb') as f:
    pickle.dump(model,f)

Powinieneś jednak zapisać dodatkowe dane, abyś mógł ponownie przeszkolić swój model w przyszłości lub ponieść straszne konsekwencje (takie jak zamknięcie się w starej wersji sklearn) .

Z dokumentacji :

Aby odbudować podobny model w przyszłych wersjach scikit-learn, dodatkowe metadane powinny zostać zapisane wzdłuż marynowanego modelu:

Dane treningowe, np. Odniesienie do niezmiennej migawki

Kod źródłowy Pythona użyty do wygenerowania modelu

Wersje scikit-learn i jego zależności

Wynik weryfikacji krzyżowej uzyskany na podstawie danych treningowych

Jest to szczególnie prawdziwe w przypadku estymatorówtree.pyx zestawów, które opierają się na module napisanym w Cython (np. IsolationForest), Ponieważ tworzy ono sprzężenie z implementacją, co nie gwarantuje stabilności między wersjami sklearn. W przeszłości widział niezgodne wstecz wstecz zmiany.

Jeśli twoje modele stają się bardzo duże, a ładowanie staje się uciążliwe, możesz również użyć bardziej wydajnych joblib. Z dokumentacji:

W konkretnym przypadku scikit może być bardziej interesujące użycie zamiennika pickle( joblib.dump& joblib.load) Joblib , który jest bardziej wydajny w przypadku obiektów, które niosą duże tablice numpy wewnętrznie, jak to często bywa w przypadku dopasowanych estymatorów uczenia się, ale może tylko zalewać na dysk, a nie na ciąg:


1
but can only pickle to the disk and not to a stringAle możesz zalać to w StringIO z joblib. To właśnie robię cały czas.
Matthew

Mój obecny projekt robi coś podobnego, wiesz co The training data, e.g. a reference to a immutable snapshottutaj? TIA!
Daisy Qin

1

sklearn.externals.joblibzostała zaniechana , ponieważ 0.21i zostaną usunięte w v0.23:

/usr/local/lib/python3.7/site-packages/sklearn/externals/joblib/ init .py: 15: FutureWarning: sklearn.externals.joblib jest przestarzałe w 0.21 i zostanie usunięte w 0.23. Zaimportuj tę funkcjonalność bezpośrednio z joblib, którą można zainstalować za pomocą: pip install joblib. Jeśli to ostrzeżenie zostanie wyświetlone podczas ładowania modeli marynowanych, może być konieczna ponowna serializacja tych modeli za pomocą scikit-learn 0.21+.
warnings.warn (msg, category = FutureWarning)


Dlatego musisz zainstalować joblib:

pip install joblib

i na koniec wypisz model na dysk:

import joblib
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier


digits = load_digits()
clf = SGDClassifier().fit(digits.data, digits.target)

with open('myClassifier.joblib.pkl', 'wb') as f:
    joblib.dump(clf, f, compress=9)

Teraz, aby odczytać zrzucony plik, wystarczy uruchomić:

with open('myClassifier.joblib.pkl', 'rb') as f:
    my_clf = joblib.load(f)
Korzystając z naszej strony potwierdzasz, że przeczytałeś(-aś) i rozumiesz nasze zasady używania plików cookie i zasady ochrony prywatności.
Licensed under cc by-sa 3.0 with attribution required.