
Keras AttributeError: 'NoneType' object in load_model


慕桂英3389331 2023-02-15 17:21:49
I'm working on a course assignment in which I have to save and load a model in Keras. My code for creating, training, and saving the model is:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.callbacks import ModelCheckpoint

def get_new_model(input_shape):
    """
    This function should build a Sequential model according to the above specification. Ensure the
    weights are initialised by providing the input_shape argument in the first layer, given by the
    function argument.
    Your function should also compile the model with the Adam optimiser, sparse categorical cross
    entropy loss function, and a single accuracy metric.
    """
    model = Sequential([
        Conv2D(16, kernel_size=(3, 3), activation='relu', padding='same', name='conv_1', input_shape=input_shape),
        Conv2D(8, kernel_size=(3, 3), activation='relu', padding='same', name='conv_2'),
        MaxPooling2D(pool_size=(8, 8), name='pool_1'),
        Flatten(name='flatten'),
        Dense(32, activation='relu', name='dense_1'),
        Dense(10, activation='softmax', name='dense_2')
    ])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['acc'])
    return model

model = get_new_model(x_train[0].shape)

def get_checkpoint_every_epoch():
    """
    This function should return a ModelCheckpoint object that:
    - saves the weights only at the end of every epoch
    - saves into a directory called 'checkpoints_every_epoch' inside the current working directory
    - generates filenames in that directory like 'checkpoint_XXX' where
      XXX is the epoch number formatted to have three digits, e.g. 001, 002, 003, etc.
    """
    # {epoch:03d} zero-pads the epoch number to three digits (001, 002, ...), as the docstring requires
    path = 'checkpoints_every_epoch/checkpoint_{epoch:03d}'
    checkpoint = ModelCheckpoint(filepath=path, save_weights_only=True, save_freq='epoch')
    return checkpoint
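The docstring asks for three-digit epoch numbers. As far as I know, ModelCheckpoint expands the `{epoch}` placeholder in `filepath` with Python's `str.format`, passing a 1-based epoch number, so the format spec alone decides the zero-padding. The padding can be checked without TensorFlow; `checkpoint_names` below is a hypothetical helper that only mimics that expansion, it is not part of Keras.

```python
def checkpoint_names(template, num_epochs):
    """Expand a ModelCheckpoint-style filepath template for epochs 1..num_epochs.

    Hypothetical helper. Assumption: like Keras, {epoch} is filled in via
    str.format with a 1-based epoch number.
    """
    return [template.format(epoch=epoch) for epoch in range(1, num_epochs + 1)]


two_digit = checkpoint_names('checkpoints_every_epoch/checkpoint_{epoch:02d}', 3)
three_digit = checkpoint_names('checkpoints_every_epoch/checkpoint_{epoch:03d}', 3)

print(two_digit[0])    # checkpoints_every_epoch/checkpoint_01
print(three_digit[0])  # checkpoints_every_epoch/checkpoint_001
```

Only the `{epoch:03d}` template produces the `checkpoint_XXX` names the docstring describes.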

1 Answer

跃然一笑


I got it. There was an error in the file path name, and it took me a long time to figure out. So the correct functions are:


import tensorflow as tf

def get_model_last_epoch(model):
    """
    This function should create a new instance of the CNN you created earlier,
    load on the weights from the last training epoch, and return this model.
    """
    # latest_checkpoint returns the path prefix of the newest checkpoint in the
    # directory (or None if it finds none); load_weights expects that prefix
    model.load_weights(tf.train.latest_checkpoint('checkpoints_every_epoch'))
    return model

def get_model_best_epoch(model):
    """
    This function should create a new instance of the CNN you created earlier, load
    on the weights leading to the highest validation accuracy, and return this model.
    """
    model.load_weights(tf.train.latest_checkpoint('checkpoints_best_only'))
    return model

    

This no longer raises an error, because tf.train.latest_checkpoint returns the correct checkpoint file name.
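One likely way a hand-written path goes wrong: in TensorFlow 2.x's tf.keras, weights saved with `save_weights_only=True` to a path that does not end in `.h5` are stored in TensorFlow format, as a prefix plus files such as `<prefix>.index` and `<prefix>.data-00000-of-00001`, and the newest prefix is recorded in a bookkeeping file named `checkpoint`. No file is named exactly like the prefix, yet `load_weights` wants the prefix, which `tf.train.latest_checkpoint` reads from that bookkeeping file. The TensorFlow-free sketch below uses a hypothetical helper, `latest_weights_prefix`, that scans for `.index` files instead (the real function does not work this way); the `checkpoint_NNN` naming is assumed from the question.

```python
import os
import re
import tempfile


def latest_weights_prefix(directory):
    """Hypothetical stand-in for tf.train.latest_checkpoint.

    Scans `directory` for files named checkpoint_<N>.index and returns the
    prefix (path without '.index') with the largest N, or None if there is
    none. The real TensorFlow function reads the 'checkpoint' state file.
    """
    if not os.path.isdir(directory):
        return None
    pattern = re.compile(r'^(checkpoint_(\d+))\.index$')
    best_epoch, best_prefix = -1, None
    for name in os.listdir(directory):
        match = pattern.match(name)
        if match and int(match.group(2)) > best_epoch:
            best_epoch, best_prefix = int(match.group(2)), match.group(1)
    if best_prefix is None:
        return None
    return os.path.join(directory, best_prefix)


with tempfile.TemporaryDirectory() as tmp:
    # Fake the files that three epochs of TF-format weight checkpoints leave on disk
    for epoch in (1, 2, 3):
        prefix = os.path.join(tmp, f'checkpoint_{epoch:03d}')
        for suffix in ('.index', '.data-00000-of-00001'):
            open(prefix + suffix, 'w').close()
    latest = latest_weights_prefix(tmp)
    latest_name = os.path.basename(latest)
    # The prefix itself is not a file on disk
    prefix_exists = os.path.exists(latest)
    print(latest_name, prefix_exists)  # checkpoint_003 False

with tempfile.TemporaryDirectory() as empty_dir:
    # With no checkpoints, the helper returns None, as latest_checkpoint does
    empty_result = latest_weights_prefix(empty_dir)
    print(empty_result)  # None
```

Passing that `None` straight into `load_weights` is worth guarding against when a checkpoint directory might be empty.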


Answered 2023-02-15
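`get_model_best_epoch` loads the latest checkpoint from `checkpoints_best_only` and treats it as the best one. That holds if the callback writing that directory (not shown in the thread, so this is an assumption) uses `save_best_only=True`: Keras then writes a checkpoint only when the monitored metric beats the best value so far. The rule can be simulated in plain Python; `best_only_saves` is a hypothetical helper, `mode='max'` is assumed (as for an accuracy metric), and the accuracy values are made up purely for illustration.

```python
def best_only_saves(val_accuracies):
    """Return the 1-based epochs at which a best-only checkpoint would be written.

    Hypothetical simulation of save_best_only=True with mode='max': a
    checkpoint is written only when the value strictly beats the best so far.
    """
    saved_epochs = []
    best = float('-inf')
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best:
            best = acc
            saved_epochs.append(epoch)
    return saved_epochs


history = [0.61, 0.68, 0.66, 0.72, 0.70]  # illustrative values, not real results
saves = best_only_saves(history)
latest_saved = saves[-1]
# Epoch with the highest validation accuracy, for comparison
best_epoch = max(range(1, len(history) + 1), key=lambda e: history[e - 1])
print(saves, latest_saved, best_epoch)  # [1, 2, 4] 4 4
```

Because every write is an improvement, the last write is always the best epoch, so `latest_checkpoint` on that directory is enough.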