0%

scikit-learn 学习笔记(二)——模型保存

参考:

训练完成的模型需要保存在本地,以便之后的使用。在 scikit-learn 中,保存模型有两种方法,分别为 picklejoblib

使用 pickle 模组存储训练模型

pickle 模组是 Python 语言内置的模型保存包,可以用于保存 scikit-learn 的模型。

其实现方法如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import pickle  # Import the pickle module for model saving.

from sklearn import svm
from sklearn import datasets

# Train a classification model.
clf = svm.SVC(gamma='scale')
iris = datasets.load_iris()
X, y = iris.data, iris.target
clf.fit(X, y)

# Use pickle module to save the model.
s = pickle.dumps(clf)

# Load the model.
clf2 = pickle.loads(s)
print(clf2.predict(X[0:1])) # array([0])
print(y[0]) # 0

使用 joblib 模组存储训练模型

joblib 模组也可以用于保存训练:

1
2
3
4
from joblib import dump, load
dump(clf, 'filename.joblib')

clf = load('filename.joblib')