我正在测试 tensorflow tf.keras.layers.SimpleRNNCell。我觉得这太奇怪了。我认为 RNN 单元是接收先前状态a^{<t-1>}和当前数据输入的单元x^{<t>}。它将输出一个新的状态a^{<t>}和当前的 predict \hat{y}^{<t>}。因此,SimpleRNNCell如果设置了 batch_size,则输入应该是 2d。我认为输入应该是[batch_size,feature_size]. 但是,如果输入是 2D,则会引发错误。而之前的状态也需要3D。正确的代码如下:batch_data = tf.ones((batch_size, time_steps, label_num)) simple_rnn_cell = tf.keras.layers.SimpleRNNCell(units)initial_state = tf.zeros((batch_size, time_steps, units))output, rnn_cell_state = simple_rnn_cell(batch_data, initial_state)但是,我认为以下代码是正确的。但我错了batch_data = tf.ones((batch_size, label_num)) simple_rnn_cell = tf.keras.layers.SimpleRNNCell(units)initial_state = tf.zeros((batch_size, units))output, rnn_cell_state = simple_rnn_cell(batch_data, initial_state)所以我的问题是为什么输入SimpleRNNCell是3D?
添加回答
举报
0/150
提交
取消