为了账号安全,请及时绑定邮箱和手机立即绑定

为什么tensorflow SimpleRNNCell的输入是3D的?

为什么tensorflow SimpleRNNCell的输入是3D的?

临摹微笑 2021-12-21 17:43:39
我正在测试 tensorflow tf.keras.layers.SimpleRNNCell。我觉得这太奇怪了。我认为 RNN 单元是接收先前状态a^{<t-1>}和当前数据输入的单元x^{<t>}。它将输出一个新的状态a^{<t>}和当前的 predict \hat{y}^{<t>}。因此,SimpleRNNCell如果设置了 batch_size,则输入应该是 2d。我认为输入应该是[batch_size,feature_size]. 但是,如果输入是 2D,则会引发错误。而之前的状态也需要3D。正确的代码如下:batch_data = tf.ones((batch_size, time_steps, label_num))    simple_rnn_cell = tf.keras.layers.SimpleRNNCell(units)initial_state = tf.zeros((batch_size, time_steps, units))output, rnn_cell_state = simple_rnn_cell(batch_data, initial_state)但是,我认为以下代码是正确的。但我错了batch_data = tf.ones((batch_size, label_num))    simple_rnn_cell = tf.keras.layers.SimpleRNNCell(units)initial_state = tf.zeros((batch_size, units))output, rnn_cell_state = simple_rnn_cell(batch_data, initial_state)所以我的问题是为什么输入SimpleRNNCell是3D?
查看完整描述

2 回答

?
侃侃无极

TA贡献2051条经验 获得超10个赞

第三维是许多特征,用于多元时间序列。在您的情况下,对于特征数使用 1。例如,您可以认为张量 [1,2,3] 是 1D,[[1,2,3]] 是形状为 (1,3) 的 2D, [[[1,2,3]]] 是具有形状 (1,1,3) 等的 3D。

因此,如果我们取一个输入样本,一个变量时间序列将是 [[1,2,3]],但两个变量时间序列可能看起来像 [[1,2,3], [7,8,9]]。


查看完整回答
反对 回复 2021-12-21
?
心有法竹

TA贡献1866条经验 获得超5个赞

RNN(或 LSTM)的输入应该具有 [batch_size, timesteps, nbr_features] 的形状


查看完整回答
反对 回复 2021-12-21
  • 2 回答
  • 0 关注
  • 227 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信