我以某种模式创建了一个人工点集合来运行 2D 分类器。因此,我插入点,例如 (x1,x2) 并命名它们的正确类(标签 1 或标签 2)。点 x_train 和 y_train 都放入 Keras 层模型,最后,我运行 Model.fit 方法。# Assign returned datax_train, y_train = separate_dots_from_dict(dots)y_train = to_categorical(y_train, NUM_CLASSES)print("Shapes (x, y):", x_train.shape, ",", y_train.shape)# Classificationmodel = Sequential()model.add(Dense(NUM_CLASSES * 8, input_shape = (2, 1, 1), activation = 'relu'))model.add(Dense(NUM_CLASSES * 4, activation = 'relu'))model.add(Dense(NUM_CLASSES, activation = 'softmax'))model.compile(loss = 'categorical_crossentropy', optimizer = 'sgd', metrics = ['accuracy'])model.fit(x_train, y_train, epochs = 4, batch_size = 2)之前,我已经打印了我的点转换结果,这些结果是从我的 separator_dots_from_dict() 函数成功输出的,并且我已经使用 Keras 包中的 to_categorical() 方法进行了转换。我的功能以return np.array(x_train).reshape(len(x_train), 2, 1, 1), np.array(y_train).reshape(len(y_train))在下面,我将向您展示在分类开始之前最终生成的 5 个虚构点:X[[[[ 0.5]] [[ 0.8]]] [[[ 0.3]] [[ 0.6]]] [[[ 0.1]] [[-0.3]]][[[ 1.1]] [[-1.1]]] [[[-1.4]] [[-1.5]]]]是[[1. 0.] [1. 0.] [1. 0.] [0. 1.] [0. 1.]]Y 是 y_train 所以它是训练目标,例如标签。x_train (X) 的格式可能看起来很笨拙,但考虑到我刚刚在这里类似地实现的 MNIST 图像的重塑,这正是著名的格式。不幸的是,我收到以下错误:Using TensorFlow backend.Shapes (x, y): (34, 2, 1, 1) , (34, 2)Traceback (most recent call last): File "main.py", line 88, in <module> model.fit(x_train, y_train, epochs = 4, batch_size = 2) File "/home/scud3r1a/Conda/envs/numtenpy/lib/python3.6/site-packages/keras/engine/training.py", line 950, in fit batch_size=batch_size) File "/home/scud3r1a/Conda/envs/numtenpy/lib/python3.6/site-packages/keras/engine/training.py", line 787, in _standardize_user_data我能找到的所有解决方案都有解决方案,只需更改最后一个 Dense 层中的单位即可。但首先,这不会影响任何事情,其次,我认为这是真的。尺寸误差与 x_train 形状成比例。在这里做什么?
2 回答
data:image/s3,"s3://crabby-images/ed21a/ed21a8404de8ccae7f5f0c8138e89bdda46ba15c" alt="?"
喵喔喔
TA贡献1735条经验 获得超5个赞
该Dense层需要 dims 的输入(input_dims, None),您要发送的 dims 3,这应该只是1预期的(正确格式)。None代表batch_size不需要定义的 。
在您的模型中尝试此更改:
x_train = x_train.reshape(2,-1)
model = Sequential()
model.add(Dense(NUM_CLASSES * 8, input_dim=(2,), activation = 'relu')
它会解决你的问题。
添加回答
举报
0/150
提交
取消