scikit-learn (sklearn)是Python環(huán)境下常見的機器學習庫,包含了常見的分類、回歸和聚類算法。在訓練模型之后,常見的操作是對模型進行可視化,則需要使用Matplotlib進行展示。
scikit-plot是一個基于sklearn和Matplotlib的庫,主要的功能是對訓練好的模型進行可視化,功能比較簡單易懂。
https://scikit-plot.readthedocs.io
pip?install?scikit-plot
功能1:評估指標可視化scikitplot.metrics.plot_confusion_matrix快速展示模型預測結(jié)果和標簽計算得到的混淆矩陣。
import?scikitplot?as?skplt
rf?=?RandomForestClassifier()
rf?=?rf.fit(X_train,?y_train)
y_pred?=?rf.predict(X_test)
skplt.metrics.plot_confusion_matrix(y_test,?y_pred,?normalize=True)
plt.show()
scikitplot.metrics.plot_roc快速展示模型預測的每個類別的ROC曲線。
import?scikitplot?as?skplt
nb?=?GaussianNB()
nb?=?nb.fit(X_train,?y_train)
y_probas?=?nb.predict_proba(X_test)
skplt.metrics.plot_roc(y_test,?y_probas)
plt.show()
scikitplot.metrics.plot_ks_statistic
import?scikitplot?as?skplt
lr?=?LogisticRegression()
lr?=?lr.fit(X_train,?y_train)
y_probas?=?lr.predict_proba(X_test)
skplt.metrics.plot_ks_statistic(y_test,?y_probas)
plt.show()
scikitplot.metrics.plot_precision_recall
import?scikitplot?as?skplt
nb?=?GaussianNB()
nb.fit(X_train,?y_train)
y_probas?=?nb.predict_proba(X_test)
skplt.metrics.plot_precision_recall(y_test,?y_probas)
plt.show()
scikitplot.metrics.plot_silhouette對聚類結(jié)果進行silhouette analysis分
import?scikitplot?as?skplt
kmeans?=?KMeans(n_clusters=4,?random_state=1)
cluster_labels?=?kmeans.fit_predict(X)
skplt.metrics.plot_silhouette(X,?cluster_labels)
plt.show()
scikitplot.metrics.plot_calibration_curve
import?scikitplot?as?skplt
rf?=?RandomForestClassifier()
lr?=?LogisticRegression()
nb?=?GaussianNB()
svm?=?LinearSVC()
rf_probas?=?rf.fit(X_train,?y_train).predict_proba(X_test)
lr_probas?=?lr.fit(X_train,?y_train).predict_proba(X_test)
nb_probas?=?nb.fit(X_train,?y_train).predict_proba(X_test)
svm_scores?=?svm.fit(X_train,?y_train).decision_function(X_test)
probas_list?=?[rf_probas,?lr_probas,?nb_probas,?svm_scores]
clf_names?=?['Random?Forest',?'Logistic?Regression',
'Gaussian?Naive?Bayes',?'Support?Vector?Machine']
skplt.metrics.plot_calibration_curve(y_test,
probas_list,
clf_names)
plt.show()
功能2:模型可視化scikitplot.estimators.plot_learning_curve
import?scikitplot?as?skplt
rf?=?RandomForestClassifier()
skplt.estimators.plot_learning_curve(rf,?X,?y)
plt.show()
scikitplot.estimators.plot_feature_importances可視化特征重要性。
import?scikitplot?as?skplt
rf?=?RandomForestClassifier()
rf.fit(X,?y)
skplt.estimators.plot_feature_importances(
rf,?feature_names=['petal?length',?'petal?width',
'sepal?length',?'sepal?width'])
plt.show()
功能3:聚類可視化scikitplot.cluster.plot_elbow_curve
import?scikitplot?as?skplt
kmeans?=?KMeans(random_state=1)
skplt.cluster.plot_elbow_curve(kmeans,?cluster_ranges=range(1,?30))
plt.show()
功能4:降維可視化scikitplot.decomposition.plot_pca_component_variance繪制 PCA 分量的解釋方差比。import?scikitplot?as?skplt
pca?=?PCA(random_state=1)
pca.fit(X)
skplt.decomposition.plot_pca_component_variance(pca)
>plt.show()
scikitplot.decomposition.plot_pca_2d_projectionimport?scikitplot?as?skplt
pca?=?PCA(random_state=1)
pca.fit(X)
skplt.decomposition.plot_pca_2d_projection(pca,?X,?y)
plt.show()

? 2025. All Rights Reserved. 滬ICP備2023009024號-1