为了账号安全,请及时绑定邮箱和手机立即绑定

Keras 模型未能减少损失

Keras 模型未能减少损失

Go
慕莱坞森 2022-06-14 17:43:39
我提出了一个tf.keras模型无法从非常简单的数据中学习的示例。我正在使用tensorflow-gpu==2.0.0和keras==2.3.0Python 3.7。在文章的最后,我给出了 Python 代码来重现我观察到的问题。数据样本是形状为 (6, 16, 16, 16, 3) 的 Numpy 数组。为了使事情变得非常简单,我只考虑充满 1 和 0 的数组。带有 1 的数组被赋予标签 1,带有 0 的数组被赋予标签 0。我可以使用以下n_samples = 240代码生成一些样本(在下面,):def generate_fake_data():    for j in range(1, 240 + 1):        if j < 120:            yield np.ones((6, 16, 16, 16, 3)), np.array([0., 1.])        else:            yield np.zeros((6, 16, 16, 16, 3)), np.array([1., 0.])为了在模型中输入这些数据,我使用下面的代码tf.keras创建了一个实例。tf.data.Dataset这将基本上创建洗牌批次的BATCH_SIZE = 12样本。def make_tfdataset(for_training=True):    dataset = tf.data.Dataset.from_generator(generator=lambda: generate_fake_data(),                                             output_types=(tf.float32,                                                           tf.float32),                                             output_shapes=(tf.TensorShape([6, 16, 16, 16, 3]),                                                            tf.TensorShape([2])))    dataset = dataset.repeat()    if for_training:        dataset = dataset.shuffle(buffer_size=1000)    dataset = dataset.batch(BATCH_SIZE)    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)    return dataset问题!在 500 个 epoch 中,模型损失保持在 0.69 左右,并且从未低于 0.69。如果我将学习率设置为1e-2而不是1e-3. 数据非常简单(只有 0 和 1)。天真地,我希望该模型具有比仅 0.6 更好的精度。事实上,我希望它能够迅速达到 100% 的准确率。我做错了什么?
查看完整描述

2 回答

?
拉丁的传说

TA贡献1789条经验 获得超8个赞

您的代码有一个关键问题:维度洗牌您永远不应该触及的一个维度是批处理维度- 因为根据定义,它包含数据的独立样本。在您的第一次重塑中,您将特征尺寸与批量尺寸混合:

Tensor("input_1:0", shape=(12, 6, 16, 16, 16, 3), dtype=float32)
Tensor("lambda/Reshape:0", shape=(72, 16, 16, 16, 3), dtype=float32)

这就像喂食 72 个独立的形状样本(16,16,16,3)。其他层也有类似的问题。



解决方案


  • 与其重塑每一步(你应该使用它Reshape),不如塑造你现有的 Conv 和池化层,让一切都直接进行。

  • 除了输入和输出层,最好给每一层命名简短而简单 - 不会失去清晰度,因为每一行都由层名称明确定义

  • GlobalAveragePooling旨在成为最后一层,因为它会折叠特征尺寸- 在您的情况下,如下所示(12,16,16,16,3) --> (12,3):之后的转换几乎没有用

  • 根据上述,我替换Conv1DConv3D

  • 除非您使用可变批量大小,否则请始终使用batch_shape=vs. shape=,因为您可以全面检查图层尺寸(非常有帮助)

  • 您的真实值batch_size是 6,从您的评论回复中推断出来

  • kernel_size=1并且(尤其是)filters=1是一个非常弱的卷积,我相应地替换了它 - 如果你愿意,你可以恢复

  • 如果您的预期应用程序中只有 2 个类,我建议您Dense(1, 'sigmoid')使用binary_crossentropy损失

最后一点:除了维度改组建议之外,您可以将上述所有内容都扔掉,仍然可以获得完美的训练集性能;这是问题的根源。

def create_model(batch_size, input_shape):


    ipt = Input(batch_shape=(batch_size, *input_shape))

    x   = Conv3D(filters=64, kernel_size=8, strides=(2, 2, 2),

                             activation='relu', padding='same')(ipt)

    x   = Conv3D(filters=8,  kernel_size=4, strides=(2, 2, 2),

                             activation='relu', padding='same')(x)

    x   = GlobalAveragePooling3D()(x)

    out = Dense(units=2, activation='softmax')(x)


    return Model(inputs=ipt, outputs=out)

BATCH_SIZE = 6

INPUT_SHAPE = (16, 16, 16, 3)

BATCH_SHAPE = (BATCH_SIZE, *INPUT_SHAPE)


def generate_fake_data():

    for j in range(1, 240 + 1):

        if j < 120:

            yield np.ones(INPUT_SHAPE), np.array([0., 1.])

        else:

            yield np.zeros(INPUT_SHAPE), np.array([1., 0.])



def make_tfdataset(for_training=True):

    dataset = tf.data.Dataset.from_generator(generator=lambda: generate_fake_data(),

                                 output_types=(tf.float32,

                                               tf.float32),

                                 output_shapes=(tf.TensorShape(INPUT_SHAPE),

                                                tf.TensorShape([2])))

    dataset = dataset.repeat()

    if for_training:

        dataset = dataset.shuffle(buffer_size=1000)

    dataset = dataset.batch(BATCH_SIZE)

    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    return dataset

结果:


Epoch 28/500

40/40 [==============================] - 0s 3ms/step - loss: 0.0808 - acc: 1.0000


查看完整回答
反对 回复 2022-06-14
?
小怪兽爱吃肉

TA贡献1852条经验 获得超1个赞

由于您的标签可以是 0 或 1,我建议将激活函数更改为softmax,将输出神经元的数量更改为 2。现在,最后一层(输出)将如下所示:

out = Dense(units=2, activation='softmax')(reshaped_conv_features)

我之前也遇到过同样的问题,并发现由于 1 或 0 的概率是相关的,从某种意义上说,它不是一个多标签分类问题,Softmax 是一个更好的选择。Sigmoid 分配概率而不考虑其他可能的输出标签。


查看完整回答
反对 回复 2022-06-14
  • 2 回答
  • 0 关注
  • 137 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信