为了账号安全,请及时绑定邮箱和手机立即绑定

将 keras/tensorflow h5/json 转换为 tensorflow pb 时遇到问题

将 keras/tensorflow h5/json 转换为 tensorflow pb 时遇到问题

红颜莎娜 2023-09-05 20:32:08
我使用 keras(张量流后端)训练了一个网络,并将模型保存为 json,权重保存为 h5。我现在尝试将其转换为单个张量流 pb 文件,它抱怨输出节点的名称。系统信息:Tensorflow 2.3.0 Keras 2.4.3 Cuda 10.1 Cudnn 7转换脚本非常简单:import jsonfrom tensorflow import kerasfrom keras import backend as Kimport tensorflow as tfjson_file = "my-trained-model.json"h5_file = "my-trained-model.h5"Output_Path = "./trained_models/"Frozen_pb_File = "my-trained-model.pb"def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):       from tensorflow.python.framework.graph_util import convert_variables_to_constants    graph = session.graph    with graph.as_default():        freeze_var_names = list(set(v.op.name for v in tf.compat.v1.global_variables()).difference(keep_var_names or []))        output_names = output_names or []        output_names += [v.op.name for v in tf.compat.v1.global_variables()]        # Graph -> GraphDef ProtoBuf        input_graph_def = graph.as_graph_def()        if clear_devices:            for node in input_graph_def.node:                node.device = ""        frozen_graph = convert_variables_to_constants(session, input_graph_def,                                                      output_names, freeze_var_names)        return frozen_graphwith open(json_file, 'r') as json_file:    model = keras.models.model_from_json(json_file.read())model.load_weights(h5_file)model.summary()# get output node namesOutputNames = [out.op.name for out in model.outputs]print("\nOutput Names:\n", OutputNames)  # this prints "concatenate/concat" as the only output node name# freeze the modelfrozen_graph = freeze_session(tf.compat.v1.keras.backend.get_session(), output_names=OutputNames)# save the output files# this is the .pb file (a binary file)tf.io.write_graph(frozen_graph, Output_Path, Frozen_pb_File, as_text=False)当我运行这个时,AssertionError: concatenate/concat is not in graph因此,由于某种原因,它正在读取“concatenate/concat”的输出节点名称。下面给出模型总结,可以看到输出节点是“concatenate”;但是,即使我将输出节点名称硬编码为“连接”,我也会收到类似的断言错误:AssertionError: concatenate is not in graph
查看完整描述

1 回答

?
繁华开满天机

TA贡献1816条经验 获得超4个赞

看起来这都是由于尝试冻结 TensorFlow 2.3 模型而引起的。显然,Tensorflow 2.0+ 已弃用“冻结”概念,转而采用“保存模型”概念。一旦发现这一点,我就能够立即将 h5/json 保存到已保存的模型 pb 中。


我仍然不确定这种格式是否针对推理进行了优化,所以我将对此进行一些跟进,但由于我的问题是关于我看到的错误,我想我会发布导致问题的原因。


作为参考,这是我的 python 脚本,用于将 keras h5/json 文件转换为 Tensorflow 保存的模型格式。


import os

from keras.models import model_from_json

import tensorflow as tf

import genericpath

from genericpath import *


def splitext(p):

    p = os.fspath(p)

    if isinstance(p, bytes):

        sep = b'/'

        extsep = b'.'

    else:

        sep = '/'

        extsep = '.'

    return genericpath._splitext(p, sep, None, extsep)


def load_model(path,custom_objects={},verbose=0):

    from keras.models import model_from_json


    path = splitext(path)[0]

    with open('%s.json' % path,'r') as json_file:

        model_json = json_file.read()

    model = model_from_json(model_json, custom_objects=custom_objects)

    model.load_weights('%s.h5' % path)

    # if verbose: print 'Loaded from %s' % path

    return model



json_file = "model.json"  # the h5 file should be "model.h5"


model = load_model(json_file) # load the json/h5 pair

model.save('my_saved_model') # this is a directory name to store the saved model


查看完整回答
反对 回复 2023-09-05
  • 1 回答
  • 0 关注
  • 191 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信