为了账号安全,请及时绑定邮箱和手机立即绑定

LSTM 模型的问题

LSTM 模型的问题

慕尼黑的夜晚无繁华 2021-09-24 14:45:52
我尝试在 PyTorch 中实现 LSTM 模型并遇到这样的问题:损失不减少。我的任务是这样的:我有不同功能的会话。会话长度是固定的,等于 20。我的目标是预测最后一个会话是否被跳过。我试图缩放输入特征,我试图传递target给特征(也许提供的特征绝对没有信息,我认为这应该导致过度拟合并且损失应该接近 0),但我的损失减少总是这样的: print(X.shape)#(82770, 20, 31) where 82770 is count of sessions, 20 is seq_len, 31 is count of featuresprint(y.shape)#(82770, 20)我也定义了get_batches函数。是的,我知道这个生成器中最后一批的问题def get_batches(X, y, batch_size):'''Create a generator that returns batches of size   batch_size x seq_length from arr.'''assert X.shape[0] == y.shape[0]assert X.shape[1] == y.shape[1]assert len(X.shape) == 3assert len(y.shape) == 2seq_len = X.shape[1]n_batches = X.shape[0]//seq_lenfor batch_number in range(n_batches):    #print(batch_number*batch_size, )    batch_x = X[batch_number*batch_size:(batch_number+1)*batch_size, :, :]    batch_y = y[batch_number*batch_size:(batch_number+1)*batch_size, :]    if batch_x.shape[0] == batch_size:        yield batch_x, batch_y    else:        print('batch_x shape: {}'.format(batch_x.shape))        break
查看完整描述

1 回答

?
紫衣仙女

TA贡献1839条经验 获得超15个赞

我的失败,忘记缩放输入功能,现在工作正常。


查看完整回答
反对 回复 2021-09-24
  • 1 回答
  • 0 关注
  • 220 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信