循环神经网络(RNN)实现股票预测-深度学习100例 | 第9天
发布日期:2021-07-01 04:21:01
浏览次数:2
分类:技术文章
本文共 7612 字,大约阅读时间需要 25 分钟。
文章目录
一、前言
今天是第9天,我们将开始RNN系列,完成股票开盘价格的预测,最后的R2可达到0.72
,CNN系列后续我也会穿插更新
我的环境:
- 语言环境:Python3.6.5
- 编译器:jupyter notebook
- 深度学习环境:TensorFlow2.4.1
往期精彩内容:
来自专栏:
转载请通过左侧联系方式(电脑端可看)联系我,备注:CSDN转载
二、RNN是什么
传统神经网络的结构比较简单:输入层 – 隐藏层 – 输出层
RNN 跟传统神经网络最大的区别在于每次都会将前一次的输出结果,带到下一次的隐藏层中,一起训练。如下图所示:
这里用一个具体的案例来看看 RNN 是如何工作的:
用户说了一句“what time is it?”,我们的神经网络会先将这句话分为五个基本单元(四个单词+一个问号)
然后,按照顺序将五个基本单元输入RNN网络,先将 “what”作为RNN的输入,得到输出01
随后,按照顺序将“time”输入到RNN网络,得到输出02
。
这个过程我们可以看到,输入 “time” 的时候,前面“what” 的输出也会对02
的输出产生了影响(隐藏层中有一半是黑色的)。
以此类推,我们可以看到,前面所有的输入产生的结果都对后续的输出产生了影响(可以看到圆形中包含了前面所有的颜色)
当神经网络判断意图的时候,只需要最后一层的输出05
,如下图所示:
三、准备工作
1.设置GPU
如果使用的是CPU可以注释掉这部分的代码。
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus: tf.config.experimental.set_memory_growth(gpus[0], True) #设置GPU显存用量按需使用 tf.config.set_visible_devices([gpus[0]],"GPU")
2.加载数据
import os,mathfrom tensorflow.keras.layers import Dropout, Dense, SimpleRNNfrom sklearn.preprocessing import MinMaxScalerfrom sklearn import metricsimport numpy as npimport pandas as pdimport tensorflow as tfimport matplotlib.pyplot as plt# 支持中文plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
data = pd.read_csv('./datasets/SH600519.csv') # 读取股票文件data
Unnamed: 0 | date | open | close | high | low | volume | code | |
---|---|---|---|---|---|---|---|---|
0 | 74 | 2010-04-26 | 88.702 | 87.381 | 89.072 | 87.362 | 107036.13 | 600519 |
1 | 75 | 2010-04-27 | 87.355 | 84.841 | 87.355 | 84.681 | 58234.48 | 600519 |
2 | 76 | 2010-04-28 | 84.235 | 84.318 | 85.128 | 83.597 | 26287.43 | 600519 |
3 | 77 | 2010-04-29 | 84.592 | 85.671 | 86.315 | 84.592 | 34501.20 | 600519 |
4 | 78 | 2010-04-30 | 83.871 | 82.340 | 83.871 | 81.523 | 85566.70 | 600519 |
... | ... | ... | ... | ... | ... | ... | ... | ... |
2421 | 2495 | 2020-04-20 | 1221.000 | 1227.300 | 1231.500 | 1216.800 | 24239.00 | 600519 |
2422 | 2496 | 2020-04-21 | 1221.020 | 1200.000 | 1223.990 | 1193.000 | 29224.00 | 600519 |
2423 | 2497 | 2020-04-22 | 1206.000 | 1244.500 | 1249.500 | 1202.220 | 44035.00 | 600519 |
2424 | 2498 | 2020-04-23 | 1250.000 | 1252.260 | 1265.680 | 1247.770 | 26899.00 | 600519 |
2425 | 2499 | 2020-04-24 | 1248.000 | 1250.560 | 1259.890 | 1235.180 | 19122.00 | 600519 |
2426 rows × 8 columns
"""前(2426-300=2126)天的开盘价作为训练集,表格从0开始计数,2:3 是提取[2:3)列,前闭后开,故提取出C列开盘价后300天的开盘价作为测试集"""training_set = data.iloc[0:2426 - 300, 2:3].values test_set = data.iloc[2426 - 300:, 2:3].values
四、数据预处理
1.归一化
sc = MinMaxScaler(feature_range=(0, 1))training_set = sc.fit_transform(training_set)test_set = sc.transform(test_set)
2.设置测试集训练集
x_train = []y_train = []x_test = []y_test = []"""使用前60天的开盘价作为输入特征x_train 第61天的开盘价作为输入标签y_train for循环共构建2426-300-60=2066组训练数据。 共构建300-60=260组测试数据"""for i in range(60, len(training_set)): x_train.append(training_set[i - 60:i, 0]) y_train.append(training_set[i, 0]) for i in range(60, len(test_set)): x_test.append(test_set[i - 60:i, 0]) y_test.append(test_set[i, 0]) # 对训练集进行打乱np.random.seed(7)np.random.shuffle(x_train)np.random.seed(7)np.random.shuffle(y_train)tf.random.set_seed(7)
"""将训练数据调整为数组(array)调整后的形状:x_train:(2066, 60, 1)y_train:(2066,)x_test :(240, 60, 1)y_test :(240,)"""x_train, y_train = np.array(x_train), np.array(y_train) # x_train形状为:(2066, 60, 1)x_test, y_test = np.array(x_test), np.array(y_test)"""输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]"""x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))x_test = np.reshape(x_test, (x_test.shape[0], 60, 1))
五、构建模型
model = tf.keras.Sequential([ SimpleRNN(100, return_sequences=True), #布尔值。是返回输出序列中的最后一个输出,还是全部序列。 Dropout(0.1), #防止过拟合 SimpleRNN(100), Dropout(0.1), Dense(1)])
六、激活模型
# 该应用只观测loss数值,不观测准确率,所以删去metrics选项,一会在每个epoch迭代显示时只显示loss值model.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss='mean_squared_error') # 损失函数用均方误差
七、训练模型
history = model.fit(x_train, y_train, batch_size=64, epochs=20, validation_data=(x_test, y_test), validation_freq=1) #测试的epoch间隔数model.summary()
Epoch 1/2033/33 [==============================] - 6s 123ms/step - loss: 0.1809 - val_loss: 0.0310Epoch 2/2033/33 [==============================] - 3s 105ms/step - loss: 0.0257 - val_loss: 0.0721Epoch 3/2033/33 [==============================] - 3s 85ms/step - loss: 0.0165 - val_loss: 0.0059Epoch 4/2033/33 [==============================] - 3s 85ms/step - loss: 0.0097 - val_loss: 0.0111Epoch 5/2033/33 [==============================] - 3s 90ms/step - loss: 0.0099 - val_loss: 0.0139Epoch 6/2033/33 [==============================] - 3s 105ms/step - loss: 0.0067 - val_loss: 0.0167 ...................Epoch 16/2033/33 [==============================] - 3s 95ms/step - loss: 0.0035 - val_loss: 0.0149Epoch 17/2033/33 [==============================] - 4s 111ms/step - loss: 0.0028 - val_loss: 0.0111Epoch 18/2033/33 [==============================] - 4s 110ms/step - loss: 0.0029 - val_loss: 0.0061Epoch 19/2033/33 [==============================] - 3s 104ms/step - loss: 0.0027 - val_loss: 0.0110Epoch 20/2033/33 [==============================] - 3s 90ms/step - loss: 0.0028 - val_loss: 0.0037Model: "sequential"_________________________________________________________________Layer (type) Output Shape Param # =================================================================simple_rnn (SimpleRNN) (None, 60, 80) 6560 _________________________________________________________________dropout (Dropout) (None, 60, 80) 0 _________________________________________________________________simple_rnn_1 (SimpleRNN) (None, 80) 12880 _________________________________________________________________dropout_1 (Dropout) (None, 80) 0 _________________________________________________________________dense (Dense) (None, 1) 81 =================================================================Total params: 19,521Trainable params: 19,521Non-trainable params: 0_________________________________________________________________
八、结果可视化
1.绘制loss图
plt.plot(history.history['loss'] , label='Training Loss')plt.plot(history.history['val_loss'], label='Validation Loss')plt.title('Training and Validation Loss by K同学啊')plt.legend()plt.show()
2.预测
predicted_stock_price = model.predict(x_test) # 测试集输入模型进行预测predicted_stock_price = sc.inverse_transform(predicted_stock_price) # 对预测数据还原---从(0,1)反归一化到原始范围real_stock_price = sc.inverse_transform(test_set[60:]) # 对真实数据还原---从(0,1)反归一化到原始范围# 画出真实数据和预测数据的对比曲线plt.plot(real_stock_price, color='red', label='Stock Price')plt.plot(predicted_stock_price, color='blue', label='Predicted Stock Price')plt.title('Stock Price Prediction by K同学啊')plt.xlabel('Time')plt.ylabel('Stock Price')plt.legend()plt.show()
3.评估
"""MSE :均方误差 -----> 预测值减真实值求平方后求均值RMSE :均方根误差 -----> 对均方误差开方MAE :平均绝对误差-----> 预测值减真实值求绝对值后求均值R2 :决定系数,可以简单理解为反映模型拟合优度的重要的统计量详细介绍可以参考文章:https://blog.csdn.net/qq_38251616/article/details/107997435"""MSE = metrics.mean_squared_error(predicted_stock_price, real_stock_price)RMSE = metrics.mean_squared_error(predicted_stock_price, real_stock_price)**0.5MAE = metrics.mean_absolute_error(predicted_stock_price, real_stock_price)R2 = metrics.r2_score(predicted_stock_price, real_stock_price)print('均方误差: %.5f' % MSE)print('均方根误差: %.5f' % RMSE)print('平均绝对误差: %.5f' % MAE)print('R2: %.5f' % R2)
均方误差: 1833.92534均方根误差: 42.82435平均绝对误差: 36.23424R2: 0.72347
往期精彩内容:
本文部分代码参考北京大学曹健教授的【人工智能实践:Tensorflow笔记】中的相关代码
来自专栏:
- ✨微信交流群:加我微信(mtyjkh_)拉你进群,不懂就问,备注:CSDN+来意
- ✨微信众号(K同学啊)后台回复【DL+天数】可获取《深度学习100例》的数据
转载地址:https://mtyjkh.blog.csdn.net/article/details/117752046 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
路过按个爪印,很不错,赞一个!
[***.219.124.196]2024年04月23日 09时04分14秒
关于作者
喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
使用WebDriver完成web页面切换操作(附源码)
2019-05-03
WebDriver自定义显示等待条件
2019-05-03
在eclipse中显示GC情况
2019-05-03
webDriver自定义浏览器打开的等待时间
2019-05-03
webDriver中的几种timeout
2019-05-03
使用webdriver中的JavascriptExecutor执行js改变DOM属性
2019-05-03
WebDriver中close()与quit()的不同
2019-05-03
解决webdriver(Element not found in the cache - perhaps the page has changed since it was looked up )
2019-05-03
LR 杂记--loadrunner录制回放常见问题总结
2019-05-03
tcp 高性能服务, netty,mqtt
2019-05-03
排期模板
2019-05-03
物联网安全设计
2019-05-03
调研 中央空调 地暖 水暖
2019-05-03
谈创新和效率,如何总结分享. 归类,几大类
2019-05-03
架构图(拓扑图)画图工具分析整理(静态,动态,可交互图.层级tu)
2019-05-03
test 博客园功能 和 搜索 seo 能力
2019-05-03
待学习
2019-05-03
山东科技大学2015-2016学年第一学期程序设计基础期末考试第一场 题解
2019-05-03
Python教程-----引用模块
2019-05-03