为了账号安全,请及时绑定邮箱和手机立即绑定

Tensorflow 保存多个会话之一

Tensorflow 保存多个会话之一

拉莫斯之舞 2021-09-25 09:37:56
有一个 Python 脚本,我在其中实例化了神经网络类的两个对象。每个对象定义自己的会话并提供保存图形的方法。import tensorflow as tfimport os, shutilclass TestNetwork:    def __init__(self, id):        self.id = id        tf.reset_default_graph()        self.s = tf.placeholder(tf.float32, [None, 2], name='s')        w_initializer, b_initializer = tf.random_normal_initializer(0., 1.0), tf.constant_initializer(0.1)        self.k = tf.layers.dense(self.s, 2, kernel_initializer=w_initializer,                    bias_initializer=b_initializer, name= 'k')        '''Defines self.session and initialize the variables'''        session_conf = tf.ConfigProto(            allow_soft_placement = True,            log_device_placement = False)        self.session = tf.Session(config = session_conf)        self.session.run(tf.global_variables_initializer())    def save_model(self, output_dir):        '''Save the network graph and weights to disk'''        if os.path.exists(output_dir):            # if provided output_dir already exists, remove it            shutil.rmtree(output_dir)        builder = tf.saved_model.builder.SavedModelBuilder(output_dir)        builder.add_meta_graph_and_variables(            self.session,            [tf.saved_model.tag_constants.SERVING],            clear_devices=True)        # create a new directory output_dir and store the saved model in it        builder.save()t1 = TestNetwork(1)t2 = TestNetwork(2)t1.save_model("t1_model")t2.save_model("t2_model")我得到的错误是类型错误:无法将 feed_dict 键解释为张量:名称“save/Const:0”指的是不存在的张量。图中不存在“save/Const”操作。我读到一些说这个错误是由于tf.train.Saver.因此,我在__init__方法的末尾添加了以下行:self.saver = tf.train.Saver(tf.global_variables(), max_to_keep = 5)但是我仍然收到错误。
查看完整描述

1 回答

  • 1 回答
  • 0 关注
  • 192 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信