为了账号安全,请及时绑定邮箱和手机立即绑定

使用Keras进行整数系列预测

使用Keras进行整数系列预测

慕虎7371278 2021-05-07 14:41:39
我正在尝试编写一个RNN模型,该模型将预测整数序列中的下一个数字。模型损失在每个时期都会变小,但是预测永远不会变得非常准确。我已经尝试了许多火车的大小和时期,但是我的预测值总是与期望值相差几位数。您能否给我一些提示,以改善或我做错了什么?这是代码:from keras.models import Sequentialfrom keras.layers import Dense, Dropout, LSTMfrom keras.callbacks import ModelCheckpointfrom keras.utils import np_utilsfrom keras import metricsimport numpy as nptraining_length = 10000rnn_size = 512hm_epochs = 30def generate_sequence(length=10):    step = np.random.randint(0,50)    first_element = np.random.randint(0,10)    first_element = 0    l_ist = [(first_element + (step*i)) for i in range(length)]    return l_isttraining_set = []for _ in range(training_length):    training_set.append(generate_sequence(10))feature_set = [i[:-1] for i in training_set]label_set = [i[-1:] for i in training_set]X = np.reshape(feature_set,(training_length, 9, 1))y = np.array(label_set)model = Sequential()model.add(LSTM(rnn_size, input_shape = (X.shape[1], X.shape[2]), return_sequences = True))model.add(Dropout(0.2))model.add(LSTM(rnn_size))model.add(Dropout(0.2))model.add(Dense(y.shape[1], activation='linear'))model.compile(loss='mse', optimizer='rmsprop', metrics=['accuracy'])filepath="checkpoint_folder/weights-improvement.hdf5"checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=1, save_best_only=True, mode='min')callbacks_list = [checkpoint]model.fit(X,y,epochs=hm_epochs, callbacks=callbacks_list)效果:30个纪元后(亏损:66.39):1顺序:[0,20,40,60,80,100,120,140,160]预期:[180] || 得到了:[181.86118]2顺序:[0,11,22,33,44,55,66,77,88]预期:[99] || 得到了:[102.17369]3顺序:[0,47,94,141,188,235,282,329,376]预计:[423] || 得到了:[419.1763]4顺序:[0,47,94,141,188,235,282,329,376]预期:[423] || 得到了:[419.1763]5序列:[0,4,8,12,16,20,24,28,32]预期:[36] || 得到了:[37.506496]6序列:[0,48,96,144,192,240,288,336,384]预期:[432] || 得到了:[425.0569]7顺序:[0、28、56、84、112、140、168、196、224]预期:[252] || 得到了:[253.60233]8顺序:[0、18、36、54、72、90、108、126、144]预期:[162] || 得到了:[163.538]9顺序:[0,19,38,57,76,95,114,133,152]预期:[171] || 得到了:[173.77933]10序列:[0,1,2,3,4,5,6,7,8]预期:[9] || 得到了:[9.577981]...
查看完整描述

2 回答

?
回首忆惘然

TA贡献1847条经验 获得超11个赞

您是否尝试了更长的顺序?不需要LSTM,因为依赖性不是很长。您可以尝试使用RNN的另一个变体。


查看完整回答
反对 回复 2021-05-11
  • 2 回答
  • 0 关注
  • 199 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信