"""
Precision-Recall curves for a multi-class problem.

Loads the iris dataset, binarizes its three class labels into one indicator
column per class, appends random noise features, and fits a one-vs-rest
linear SVC on half of the samples.  The held-out half is then used to compute
a precision-recall curve and an average-precision score for every class, as
well as a micro-averaged curve over all classes, and the curves are plotted
with matplotlib.
"""
# Echo the module description when the script is run.
print(__doc__)
import matplotlib.pyplot as plt
import numpy as np
from sklearn import svm, datasets
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.model_selection import train_test_split
# Load the iris data and turn the integer class labels into a binary
# indicator matrix with one column per class.
iris = datasets.load_iris()
X, y = iris.data, iris.target
y = label_binarize(y, classes=[0, 1, 2])
n_classes = y.shape[1]
print(y)

# Add noisy features: for every original feature, 200 extra columns drawn
# from a seeded random generator are appended to the right of X.
random_state = np.random.RandomState(0)
n_samples, n_features = X.shape
X = np.hstack((X, random_state.randn(n_samples, 200 * n_features)))
# Hold out half of the samples for evaluation.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=random_state
)

# Wrap a linear-kernel SVC in a one-vs-rest classifier.
# NOTE(review): only decision_function is used in the visible code, so
# probability=True may be unnecessary -- confirm nothing else relies on
# predict_proba before dropping it.
classifier = OneVsRestClassifier(
    svm.SVC(kernel="linear", probability=True, random_state=random_state)
)
classifier.fit(X_train, y_train)

# Decision values for the held-out samples; these are the ranking scores the
# precision-recall computations consume.
y_score = classifier.decision_function(X_test)
# Results are keyed by class index (0 .. n_classes - 1) plus the string
# "micro" for the micro-averaged curve.
precision = {}
recall = {}
average_precision = {}

# Per-class curves: column i of the label matrix against column i of the
# score matrix.
for i in range(n_classes):
    class_truth = y_test[:, i]
    class_scores = y_score[:, i]
    precision[i], recall[i], _ = precision_recall_curve(class_truth, class_scores)
    average_precision[i] = average_precision_score(class_truth, class_scores)

# Micro-average: flatten all (sample, class) pairs into one binary problem.
precision["micro"], recall["micro"], _ = precision_recall_curve(
    y_test.ravel(), y_score.ravel()
)
average_precision["micro"] = average_precision_score(
    y_test, y_score, average="micro"
)
# Start from an empty figure.
plt.clf()

# Micro-averaged curve first, labelled with its average-precision score.
micro_label = 'micro-average Precision-recall curve (area = {0:0.2f})'.format(
    average_precision["micro"]
)
plt.plot(recall["micro"], precision["micro"], label=micro_label)

# Then one curve per class.
for i in range(n_classes):
    class_label = 'Precision-recall curve of class {0} (area = {1:0.2f})'.format(
        i, average_precision[i]
    )
    plt.plot(recall[i], precision[i], label=class_label)

# Axis limits: the y-axis runs slightly past 1.0.
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall', fontsize=16)
plt.ylabel('Precision', fontsize=16)
plt.title('Extension of Precision-Recall curve to multi-class', fontsize=16)
plt.legend(loc="lower right")
plt.show()