为了账号安全,请及时绑定邮箱和手机立即绑定

在 keras 中为多类分类生成混淆矩阵

在 keras 中为多类分类生成混淆矩阵

慕妹3146593 2021-09-25 21:33:11
通过训练模型获得高达 98% 的准确率,但混淆矩阵显示出非常高的误分类。我正在使用 keras 在预训练的 VGG16 模型上使用迁移学习方法进行多类分类。问题是使用 CNN 将图像分类为 5 种番茄病害。有 5 个疾病类别,6970 张训练图像和 70 张测试图像。训练模型显示 98.65% 的准确率,而测试显示 94% 的准确率。但问题是当我生成混淆矩阵时,它显示出非常高的误分类。有人请帮助我,是我的代码错误还是模型错误?我很困惑我的模型是否给了我正确的结果。如果有人可以向我解释 keras 如何使用 model.fit_generator 函数实际计算准确度,因为在混淆矩阵上应用准确度的一般公式并没有给我与 keras 计算出的相同的结果。用于测试数据集的代码是:test_generator = test_datagen.flow_from_directory(test_dir,target_size=(150, 150),batch_size=20,class_mode='categorical')test_loss, test_acc = model.evaluate_generator(test_generator, steps=50)print('test acc:', test_acc)我从论坛之一找到了生成混淆矩阵的代码;代码是:import numpy as npfrom sklearn.metrics import confusion_matrix,classification_reportbatch_size = 20num_of_test_samples = 70predictions = model.predict_generator(test_generator,  num_of_test_samples // batch_size+1)y_pred = np.argmax(predictions, axis=1)true_classes = test_generator.classesclass_labels = list(test_generator.class_indices.keys())   print(class_labels)print(confusion_matrix(test_generator.classes, y_pred))report = classification_report(true_classes, y_pred, target_names=class_labels)print(report)以下是我得到的结果:测试精度:Found 70 images belonging to 5 classes.test acc: 0.9420454461466182混淆矩阵的结果:['TEB', 'TH', 'TLB', 'TLM', 'TSL'][[2 3 2 4 3] [4 2 3 0 5] [3 3 3 2 3] [3 3 2 4 2] [2 2 4 4 2]]]              precision    recall  f1-score   support         TEB       0.14      0.14      0.14        14          TH       0.15      0.14      0.15        14         TLB       0.21      0.21      0.21        14         TLM       0.29      0.29      0.29        14         TSL       0.13      0.14      0.14        14   micro avg       0.19      0.19      0.19        70   macro avg       0.19      0.19      0.19        70weighted avg       0.19      0.19      0.19        70
查看完整描述

6 回答

?
慕虎7371278

TA贡献1802条经验 获得超4个赞

测试标签应该是 class_indices 而不是 classes

true_classes = test_generator.class_indices


查看完整回答
反对 回复 2021-09-25
?
狐的传说

TA贡献1804条经验 获得超3个赞

亲爱的总是对任何分类性能参数做以下事情:

  1. 首先重置您在预测中使用的生成器

  2. 将 shuffle 等于 false 在 flow_from_directory()


查看完整回答
反对 回复 2021-09-25
  • 6 回答
  • 0 关注
  • 647 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号