我有 3 类课程Tree, Stump, Ground。我为这些类别列出了一个清单:CATEGORIES = ["Tree", "Stump", "Ground"]当我打印我的预测时,它给了我输出[[0. 1. 0.]]我已经阅读了 numpy 的 Argmax,但我不完全确定在这种情况下如何使用它。我试过使用print(np.argmax(prediction))但这给了我1. 太好了,但我想找出索引是什么1,然后打印出类别而不是最高值。import cv2import tensorflow as tfimport numpy as npCATEGORIES = ["Tree", "Stump", "Ground"]def prepare(filepath): IMG_SIZE = 150 # This value must be the same as the value in Part1 img_array = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE) new_array = cv2.resize(img_array, (IMG_SIZE, IMG_SIZE)) return new_array.reshape(-1, IMG_SIZE, IMG_SIZE, 1)# Able to load a .model, .h3, .chibai and even .dogmodel = tf.keras.models.load_model("models/test.model")prediction = model.predict([prepare('image.jpg')])print("Predictions:")print(prediction)print(np.argmax(prediction))我希望我的预测告诉我:Predictions:[[0. 1. 0.]]Stump
2 回答
慕无忌1623718
TA贡献1744条经验 获得超4个赞
您只需要使用以下结果对类别进行索引np.argmax:
pred_name = CATEGORIES[np.argmax(prediction)]
print(pred_name)
添加回答
举报
0/150
提交
取消