scikit-learn에서 분류자를 디스크에 저장


192

훈련 된 Naive Bayes 분류기디스크에 저장하고 이를 사용하여 데이터 를 예측하는 방법은 무엇입니까?

scikit-learn 웹 사이트의 다음 샘플 프로그램이 있습니다.

from sklearn import datasets
iris = datasets.load_iris()
from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()
y_pred = gnb.fit(iris.data, iris.target).predict(iris.data)
print "Number of mislabeled points : %d" % (iris.target != y_pred).sum()

답변:


201

분류기는 다른 객체처럼 피클 및 덤프 할 수있는 객체 일뿐입니다. 예제를 계속하려면 :

import cPickle
# save the classifier
with open('my_dumped_classifier.pkl', 'wb') as fid:
    cPickle.dump(gnb, fid)    

# load it again
with open('my_dumped_classifier.pkl', 'rb') as fid:
    gnb_loaded = cPickle.load(fid)

1
매력처럼 작동합니다! np.savez를 사용하여 다시로드하려고했지만 결코 도움이되지 않았습니다. 고마워
Kartos

7
python3에서는 pickle 모듈을 사용하십시오.
MCSH

213

jobpy.dumpjoblib.load 를 사용 하면 기본 파이썬 선택기보다 숫자 배열을 처리하는 데 훨씬 효율적입니다.

Joblib는 scikit-learn에 포함되어 있습니다.

>>> import joblib
>>> from sklearn.datasets import load_digits
>>> from sklearn.linear_model import SGDClassifier

>>> digits = load_digits()
>>> clf = SGDClassifier().fit(digits.data, digits.target)
>>> clf.score(digits.data, digits.target)  # evaluate training error
0.9526989426822482

>>> filename = '/tmp/digits_classifier.joblib.pkl'
>>> _ = joblib.dump(clf, filename, compress=9)

>>> clf2 = joblib.load(filename)
>>> clf2
SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0,
       fit_intercept=True, learning_rate='optimal', loss='hinge', n_iter=5,
       n_jobs=1, penalty='l2', power_t=0.5, rho=0.85, seed=0,
       shuffle=False, verbose=0, warm_start=False)
>>> clf2.score(digits.data, digits.target)
0.9526989426822482

편집 : Python 3.8 이상에서는 pickle protocol 5 (기본값이 아님)를 사용하는 경우 큰 숫자 배열을 가진 객체를 속성으로 효율적으로 산세하기 위해 pickle을 사용할 수 있습니다.


1
그러나 단일 작업 흐름의 일부 인 파이프 라인 작업은 이해합니다. 모델을 빌드하려면 디스크에 저장하고 실행을 중지하십시오. 그런 다음 일주일 후 돌아와 디스크에서 모델을로드하려고하면 오류가 발생합니다.
venuktan

2
원하는 fit방법 인 경우 메소드 실행을 중지했다가 다시 시작할 수있는 방법이 없습니다 . 즉 , 동일한 버전의 scikit-learn 라이브러리가있는 Python에서 호출하면 joblib.load성공한 후에 예외를 발생 joblib.dump시키지 않아야합니다.
ogrisel

10
IPython을 사용하는 경우 암시 적 네임 스페이스 오버로드가 산세 프로세스를 중단하는 것으로 알려져 있으므로 --pylab명령 행 플래그 또는 %pylab매직을 사용하지 마십시오 . %matplotlib inline대신 명시적인 가져 오기와 마법을 사용하십시오.
ogrisel

2
참조 : scikit-learn 문서를 참조하십시오 : scikit-learn.org/stable/tutorial/basic/…
user1448319

1
이전에 저장된 모델을 재교육 할 수 있습니까? 특히 SVC 모델?
Uday Sawant

108

당신이 찾고있는 것은 sklearn 단어에서 모델 지속성 이라고 하며 소개모델 지속성 섹션 에 문서화되어 있습니다.

분류기를 초기화하고 오랫동안 훈련했습니다.

clf = some.classifier()
clf.fit(X, y)

이 후 두 가지 옵션이 있습니다.

1) 피클 사용

import pickle
# now you can save it to a file
with open('filename.pkl', 'wb') as f:
    pickle.dump(clf, f)

# and later you can load it
with open('filename.pkl', 'rb') as f:
    clf = pickle.load(f)

2) Joblib 사용

from sklearn.externals import joblib
# now you can save it to a file
joblib.dump(clf, 'filename.pkl') 
# and later you can load it
clf = joblib.load('filename.pkl')

한 번 더 위에서 언급 한 링크를 읽는 것이 도움이됩니다.


30

많은 경우에, 특히 텍스트 분류의 경우 분류기를 저장하는 것만으로는 충분하지 않지만 향후 입력을 벡터화 할 수 있도록 벡터 라이저도 저장해야합니다.

import pickle
with open('model.pkl', 'wb') as fout:
  pickle.dump((vectorizer, clf), fout)

향후 사용 사례 :

with open('model.pkl', 'rb') as fin:
  vectorizer, clf = pickle.load(fin)

X_new = vectorizer.transform(new_samples)
X_new_preds = clf.predict(X_new)

벡터 라이저를 덤프하기 전에 다음과 같이 벡터 라이저의 stop_words_ 속성을 삭제할 수 있습니다.

vectorizer.stop_words_ = None

보다 효율적으로 덤핑 할 수 있습니다. 또한 분류 자 ​​매개 변수가 희소 인 경우 (대부분의 텍스트 분류 예제에서와 같이) 매개 변수를 조밀에서 희소로 변환하여 메모리 소비,로드 및 덤프 측면에서 큰 차이를 만들 수 있습니다. 다음을 통해 모델을 희소 화 합니다.

clf.sparsify()

SGDClassifier에 대해 자동으로 작동 하지만 모델이 드문 경우 (clf.coef_의 많은 0) 다음을 통해 clf.coef_를 csr scipy sparse matrix 로 수동으로 변환 할 수 있습니다 .

clf.coef_ = scipy.sparse.csr_matrix(clf.coef_)

더 효율적으로 저장할 수 있습니다.


통찰력있는 답변! SVC의 경우 추가하고 싶을 때 희소 모델 매개 변수를 반환합니다.
Shayan Amani 14

5

sklearn추정기는 추정 된 훈련 된 특성을 쉽게 저장할 수있는 방법을 구현합니다. 어떤 추정기는 __getstate__메소드 자체를 구현 하지만, 다른 것들은 단순히 객체 내부 사전을 저장하는 기본 구현GMM사용합니다 .

def __getstate__(self):
    try:
        state = super(BaseEstimator, self).__getstate__()
    except AttributeError:
        state = self.__dict__.copy()

    if type(self).__module__.startswith('sklearn.'):
        return dict(state.items(), _sklearn_version=__version__)
    else:
        return state

모델을 디스크에 저장하는 권장 방법은 pickle모듈 을 사용하는 것입니다 .

from sklearn import datasets
from sklearn.svm import SVC
iris = datasets.load_iris()
X = iris.data[:100, :2]
y = iris.target[:100]
model = SVC()
model.fit(X,y)
import pickle
with open('mymodel','wb') as f:
    pickle.dump(model,f)

그러나 향후에 모델을 재교육하거나 이전 버전의 sklearn에 고정되는 등 심각한 결과를 겪을 수 있도록 추가 데이터를 저장해야합니다 .

로부터 문서 :

이후 버전의 scikit-learn을 사용하여 유사한 모델을 재 구축하려면 추가 된 메타 데이터를 절인 모델과 함께 저장해야합니다.

훈련 데이터 (예 : 불변 스냅 샷에 대한 참조)

모델을 생성하는 데 사용되는 파이썬 소스 코드

scikit-learn의 버전 및 해당 종속성

교육 데이터에서 얻은 교차 검증 점수

tree.pyxCython으로 작성된 모듈에 의존하는 Ensemble 추정기 (예 IsolationForest:)는 구현에 커플 링을 생성하므로 sklearn 버전간에 안정적으로 보장되지 않기 때문에 특히 그렇습니다. 과거에는 호환되지 않는 변경 사항이 있습니다.

모델이 매우 커지고 로딩이 성가신 경우 더 효율적으로 사용할 수도 있습니다 joblib. 설명서에서 :

scikit의 특정 경우에는 joblib의 pickle( joblib.dump& joblib.load) 대체를 사용 하는 것이 더 흥미로울 수 있습니다. 문자열이 아닌 디스크에 :


1
but can only pickle to the disk and not to a string그러나 joblib에서 이것을 StringIO로 피클 할 수 있습니다. 이것이 제가 항상하는 일입니다.
Matthew

내 현재 프로젝트가 비슷한 것을하고 The training data, e.g. a reference to a immutable snapshot있습니다. 여기서 무엇을 알고 있습니까? 티아!
데이지 진

1

sklearn.externals.joblib이후 더 이상 사용되지 않으며0.21 다음에서 제거됩니다 v0.23.

/usr/local/lib/python3.7/site-packages/sklearn/externals/joblib/ init .py : 15 : FutureWarning : sklearn.externals.joblib는 0.21에서 더 이상 사용되지 않으며 0.23에서 제거됩니다. 이 기능을 joblib에서 직접 가져 오십시오. pip install joblib을 사용하여 설치할 수 있습니다. 절인 모델을로드 할 때이 경고가 발생하면 scikit-learn 0.21+를 사용하여 해당 모델을 다시 직렬화해야합니다.
warnings.warn (msg, category = FutureWarning)


따라서 다음을 설치해야합니다 joblib.

pip install joblib

마지막으로 모델을 디스크에 씁니다.

import joblib
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier


digits = load_digits()
clf = SGDClassifier().fit(digits.data, digits.target)

with open('myClassifier.joblib.pkl', 'wb') as f:
    joblib.dump(clf, f, compress=9)

덤프 된 파일을 읽으려면 다음을 수행해야합니다.

with open('myClassifier.joblib.pkl', 'rb') as f:
    my_clf = joblib.load(f)
당사 사이트를 사용함과 동시에 당사의 쿠키 정책개인정보 보호정책을 읽고 이해하였음을 인정하는 것으로 간주합니다.
Licensed under cc by-sa 3.0 with attribution required.