我有以下代码,它绘制了KNN算法的嵌套与非嵌套交叉验证。# Number of random trialsNUM_TRIALS = 30# Load the datasetX_iris = X.valuesy_iris = y# Set up possible values of parameters to optimize overp_grid = {"n_neighbors": [1, 5, 10]}# We will use a Support Vector Classifier with "rbf" kernelsvm = KNeighborsClassifier()# Arrays to store scoresnon_nested_scores = np.zeros(NUM_TRIALS)nested_scores = np.zeros(NUM_TRIALS)# Loop for each trialfor i in range(NUM_TRIALS): # Choose cross-validation techniques for the inner and outer loops, # independently of the dataset. # E.g "GroupKFold", "LeaveOneOut", "LeaveOneGroupOut", etc. inner_cv = KFold(n_splits=4, shuffle=True, random_state=i) outer_cv = KFold(n_splits=4, shuffle=True, random_state=i) # Non_nested parameter search and scoring clf = GridSearchCV(estimator=svm, param_grid=p_grid, cv=inner_cv) clf.fit(X_iris, y_iris) non_nested_scores[i] = clf.best_score_ # Nested CV with parameter optimization nested_score = cross_val_score(clf, X=X_iris, y=y_iris, cv=outer_cv) nested_scores[i] = nested_score.mean()score_difference = non_nested_scores - nested_scorespreds=clf.best_estimator_.predict(X_test)from sklearn.metrics import confusion_matrixcm = confusion_matrix(y_test, preds)one, two, three, four,five,six,seven,eight,nine = confusion_matrix(y_test, preds).ravel()我遇到的问题是混淆矩阵绘图,我遇到了以下错误:ValueError Traceback (most recent call last)<ipython-input-22-13536688e18b> in <module>() 45 from sklearn.metrics import confusion_matrix 46 cm = confusion_matrix(y_test, preds)---> 47 one, two, three, four,five,six,seven,eight,nine = confusion_matrix(y_test, preds).ravel() 48 cm = [[one,two],[three,four],[five,six],[seven,eight],[nine,eight]] 49 ax= plt.subplot()ValueError: too many values to unpack (expected 9)我不知道如何解决这个问题。我的数据集中有 9 个目标变量,存储在 y 中。[11 11 11 ... 33 33 33] #the target variables being : 11,12,13,21,22,23,31,32,33
1 回答
慕妹3146593
TA贡献1820条经验 获得超9个赞
混淆矩阵由“cm = confusion_matrix(y_test,preds)”构建,其中cm是9x9矩阵(因为目标变量中有9个不同的标签)。如果要绘制它,可以使用plot_confusion_matrix函数。没有必要把它弄得乱七八糟。如果对其进行处理,则 9x9 矩阵将转换为 81 个值,并且您将它解压缩为赋值左侧的 9 个变量。这就是您收到“太多值无法解压缩(预期9)”错误的原因。
添加回答
举报
0/150
提交
取消