为了账号安全,请及时绑定邮箱和手机立即绑定

训练深度学习模型时出错

训练深度学习模型时出错

HUX布斯 2023-03-22 10:54:01
所以我设计了一个 CNN 并使用以下参数进行编译,training_file_loc = "8-SignLanguageMNIST/sign_mnist_train.csv"testing_file_loc = "8-SignLanguageMNIST/sign_mnist_test.csv"def getData(filename):    images = []    labels = []    with open(filename) as csv_file:        file = csv.reader(csv_file, delimiter = ",")        next(file, None)                for row in file:            label = row[0]            data = row[1:]            img = np.array(data).reshape(28,28)                        images.append(img)            labels.append(label)                images = np.array(images).astype("float64")        labels = np.array(labels).astype("float64")            return images, labelstraining_images, training_labels = getData(training_file_loc)testing_images, testing_labels = getData(testing_file_loc)print(training_images.shape, training_labels.shape)print(testing_images.shape, testing_labels.shape)training_images = np.expand_dims(training_images, axis = 3)testing_images = np.expand_dims(testing_images, axis = 3)training_datagen = ImageDataGenerator(    rescale = 1/255,    rotation_range = 45,    width_shift_range = 0.2,    height_shift_range = 0.2,    shear_range = 0.2,    zoom_range = 0.2,    horizontal_flip = True,    fill_mode = "nearest")training_generator = training_datagen.flow(    training_images,    training_labels,    batch_size = 64,)validation_datagen = ImageDataGenerator(    rescale = 1/255,    rotation_range = 45,    width_shift_range = 0.2,    height_shift_range = 0.2,    shear_range = 0.2,    zoom_range = 0.2,    horizontal_flip = True,    fill_mode = "nearest")validation_generator = training_datagen.flow(    testing_images,    testing_labels,    batch_size = 64,])但是,当我运行 model.fit() 时,出现以下错误,ValueError: Shapes (None, 1) and (None, 24) are incompatible将损失函数更改为 后sparse_categorical_crossentropy,程序运行良好。我不明白为什么会这样。谁能解释这一点以及这些损失函数之间的区别?
查看完整描述

2 回答

?
largeQ

TA贡献2039条经验 获得超7个赞

问题是,categorical_crossentropy期望单热编码标签,这意味着,对于每个样本,它期望一个长度张量,num_classes其中label第 th 个元素设置为 1,其他所有元素都为 0。


另一方面,sparse_categorical_crossentropy直接使用整数标签(因为这里的用例是大量的类,所以单热编码标签会浪费大量零的内存)。我相信,但我无法证实这一点,它categorical_crossentropy比它的稀疏对应物运行得更快。


对于您的情况,对于 26 个类,我建议使用非稀疏版本并将您的标签转换为单热编码,如下所示:


def getData(filename):

    images = []

    labels = []

    with open(filename) as csv_file:

        file = csv.reader(csv_file, delimiter = ",")

        next(file, None)

        

        for row in file:

            label = row[0]

            data = row[1:]

            img = np.array(data).reshape(28,28)

            

            images.append(img)

            labels.append(label)

        

        images = np.array(images).astype("float64")

        labels = np.array(labels).astype("float64")

        

    return images, tf.keras.utils.to_categorical(labels, num_classes=26) # you can omit num_classes to have it computed from the data

旁注:除非你有理由使用float64图像,否则我会切换到float32(它将数据集所需的内存减半,并且模型可能会将它们转换为float32第一个操作)


查看完整回答
反对 回复 2023-03-22
?
BIG阳

TA贡献1859条经验 获得超6个赞

很简单,对于输出类为整数的分类问题,使用 sparse_categorical_crosentropy,对于标签在一个热编码标签中转换的问题,我们使用 categorical_crosentropy。



查看完整回答
反对 回复 2023-03-22
  • 2 回答
  • 0 关注
  • 136 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信