为了账号安全,请及时绑定邮箱和手机立即绑定

在 TensorFlow 中保存和恢复函数

在 TensorFlow 中保存和恢复函数

三国纷争 2021-08-14 17:20:22
我正在 TensorFlow 中开展一个 VAE 项目,其中编码器/解码器网络内置于函数中。这个想法是能够保存,然后加载训练好的模型并使用编码器功能进行采样。恢复模型后,我无法运行解码器功能并将恢复的训练变量返回给我,出现“未初始化值”错误。我认为这是因为该函数要么创建一个新函数,要么覆盖现有函数,要么以其他方式。但我无法弄清楚如何解决这个问题。这是一些代码:class VAE(object):        def __init__(self, restore=True):        self.session = tf.Session()        if restore:            self.restore_model()            self.build_decoder = tf.make_template('decoder', self._build_decoder)@staticmethoddef _build_decoder(z, output_size=768, hidden_size=200,                  hidden_activation=tf.nn.elu, output_activation=tf.nn.sigmoid):    x = tf.layers.dense(z, hidden_size, activation=hidden_activation)    x = tf.layers.dense(x, hidden_size, activation=hidden_activation)    logits = tf.layers.dense(x, output_size, activation=output_activation)    return distributions.Independent(distributions.Bernoulli(logits), 2)def sample_decoder(self, n_samples):    prior = self.build_prior(self.latent_dim)    samples = self.build_decoder(prior.sample(n_samples), self.input_size).mean()    return self.session.run([samples])def restore_model(self):    print("Restoring")    self.saver = tf.train.import_meta_graph(os.path.join(self.save_dir, "turbolearn.meta"))    self.saver.restore(self.sess, tf.train.latest_checkpoint(self.save_dir))    self._restored = True想跑 samples = vae.sample_decoder(5)在我的训练程序中,我运行:        if self.checkpoint:            self.saver.save(self.session, os.path.join(self.save_dir, "myvae"), write_meta_graph=True)更新根据下面的建议答案,我更改了恢复方法self.saver = tf.train.Saver()self.saver.restore(self.session, tf.train.latest_checkpoint(self.save_dir))但是现在在创建 Saver() 对象时出现值错误:ValueError: No variables to save
查看完整描述

1 回答

  • 1 回答
  • 0 关注
  • 325 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
微信客服

购课补贴
联系客服咨询优惠详情

帮助反馈 APP下载

慕课网APP
您的移动学习伙伴

公众号

扫描二维码
关注慕课网微信公众号