我正在 TensorFlow 中开展一个 VAE 项目,其中编码器/解码器网络内置于函数中。这个想法是能够保存,然后加载训练好的模型并使用编码器功能进行采样。恢复模型后,我无法运行解码器功能并将恢复的训练变量返回给我,出现“未初始化值”错误。我认为这是因为该函数要么创建一个新函数,要么覆盖现有函数,要么以其他方式。但我无法弄清楚如何解决这个问题。这是一些代码:class VAE(object): def __init__(self, restore=True): self.session = tf.Session() if restore: self.restore_model() self.build_decoder = tf.make_template('decoder', self._build_decoder)@staticmethoddef _build_decoder(z, output_size=768, hidden_size=200, hidden_activation=tf.nn.elu, output_activation=tf.nn.sigmoid): x = tf.layers.dense(z, hidden_size, activation=hidden_activation) x = tf.layers.dense(x, hidden_size, activation=hidden_activation) logits = tf.layers.dense(x, output_size, activation=output_activation) return distributions.Independent(distributions.Bernoulli(logits), 2)def sample_decoder(self, n_samples): prior = self.build_prior(self.latent_dim) samples = self.build_decoder(prior.sample(n_samples), self.input_size).mean() return self.session.run([samples])def restore_model(self): print("Restoring") self.saver = tf.train.import_meta_graph(os.path.join(self.save_dir, "turbolearn.meta")) self.saver.restore(self.sess, tf.train.latest_checkpoint(self.save_dir)) self._restored = True想跑 samples = vae.sample_decoder(5)在我的训练程序中,我运行: if self.checkpoint: self.saver.save(self.session, os.path.join(self.save_dir, "myvae"), write_meta_graph=True)更新根据下面的建议答案,我更改了恢复方法self.saver = tf.train.Saver()self.saver.restore(self.session, tf.train.latest_checkpoint(self.save_dir))但是现在在创建 Saver() 对象时出现值错误:ValueError: No variables to save
添加回答
举报
0/150
提交
取消