为了账号安全,请及时绑定邮箱和手机立即绑定

如何在 numpy 数组中配置数据标签以训练 Keras 模型?

如何在 numpy 数组中配置数据标签以训练 Keras 模型?

白板的微信 2021-09-23 09:10:12
我正在尝试第一次实施 Keras(对于这个愚蠢的问题很抱歉)作为一个更广泛的项目的一部分,以制作一个学习连接 4 的 AI。作为其中的一部分,我通过一个 NN 一个 6*7 的网格它输出一个包含 7 个值的数组,给出游戏中每一列的选择概率。以下是 Model.summary() 方法的输出以获得更多详细信息:______________________________________________________________Layer (type)                 Output Shape              Param #   =================================================================flatten (Flatten)            (None, 42)                0         _________________________________________________________________dense (Dense)                (None, 20)                860       _________________________________________________________________dense_1 (Dense)              (None, 20)                420       _________________________________________________________________dense_2 (Dense)              (None, 7)                 147       =================================================================Total params: 1,427Trainable params: 1,427Non-trainable params: 0__________________________________________________________________________________________________________________________________当我传递形状为 (1, 6, 7) 的 numpy 数组时,该模型将给出(目前是随机的)预测,但是,当我尝试使用形状为 (221, 6, 7) 的数组训练模型时,数据和标签的形状数组 (221, 7) 我收到此错误:ValueError:检查目标时出错:预期dense_2具有形状(1,)但得到形状为(7,)的数组这是我用来训练模型的代码(输出 (221, 6, 7) 和 (221, 7)):board_tensor = np.array(full_board_list)print(board_tensor.shape)label_tensor = np.array(full_label_list)print(label_tensor.shape)self.model.fit(board_tensor, label_tensor)(该模型是 AI 对象的一部分,因此可以与其他类型的 AI 对象进行比较)这是成功预测一批大小为 1 的代码,由表示板的二维列表生成(它输出 (1 , 6, 7) 和 (1, 7)):input_tensor = np.array(board.board)input_tensor = np.expand_dims(input_tensor, 0)print(input_tensor.shape)probability_distribution = self.model.predict(input_tensor)print(probability_distribution.shape)我意识到该错误可能是由于我对 Keras 中期望提供的方法缺乏了解;所以作为一个小小的旁注,有没有人有任何好的,彻底的学习资源,真正让你了解每种方法在做什么(即,不仅仅是告诉你输入哪些代码来制作图像识别器),这将是可以理解的对于像我这样的 Keras 和 Tensorflow 新手?
查看完整描述

1 回答

?
SMILET

TA贡献1796条经验 获得超4个赞

您正在使用sparse_categorical_crossentropy损失,它采用整数标签(不是单热编码的),而您的标签是单热编码的。这就是您收到错误的原因。

修复它的最简单方法是将 loss 更改为categorical_crossentropy.



查看完整回答
反对 回复 2021-09-23
  • 1 回答
  • 0 关注
  • 199 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号