这个小节呢就是见证或者说学习神经网络是如何通过简单的形式将一群数据用一条线来表示,也就是说,在一堆数据中如何找到他们之间的关系,然后用神经网络模型来建立一个可以代表他们关系的线条。至于机器学习中分类器和回归的概念呢,这里就不做阐述,可以参见CSDN链接链接描述。
- torch中的变量Variable
在开始之前,有必要说明一下什么是variable。在torch中的variable就是一个存放会变化的值的地理位置,里面的值会不停的变化,好比一个书架存放一些书,书架不会动,但是书架上的书的数目是会一直变化的。如果用一个 Variable 进行计算, 那返回的也是一个同类型的 Variable。下面举例说明一下:
import torch
from torch.autograd import Variable # torch 中 Variable 模块
# 先买书
tensor = torch.FloatTensor([[1,2],[3,4]])
# 把书放到书架里
variable = Variable(tensor)
print(variable) # Variable 形式
"""
Variable containing:
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
print(variable.data) # tensor 形式
"""
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
print(variable.data.numpy()) # numpy 形式
"""
[[ 1. 2.]
[ 3. 4.]]
"""
- 关系拟合(回归)
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import torch.nn.functional as tf
#建立数据集
#创建假数据来模拟真实的情况,这里是一个一元二次函数:y = a*x^2 + b,同时给y数据加一点噪声来更真实的展示
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())
#用variable来修饰这些tensor数据
x, y = torch.autograd.Variable(x), Variable(y)
#可以先画图瞄一眼
#plt.scatter(x.data.numpy(), y.data.numpy())
#plt.show()
#建立神经网络
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
# 定义每层用什么样的形式
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
def forward(self, x):
# 正向传播输入值, 神经网络分析出输出值
x = tf.relu(self.hidden(x)) # 激励函数(隐藏层的线性值)
x = self.predict(x) # 输出值
return x
net = Net(n_feature=1, n_hidden=10, n_output=1)
print (net)
#训练网络
optimizer = torch.optim.SGD(net.parameters(), lr = 0.5) # optimizer 是训练的工具
loss_func = torch.nn.MSELoss() # 预测值和真实值的误差计算公式 (均方差)
#可视化训练过程
plt.ion()
for t in range(100):
prediction = net(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if t % 5 == 0:
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw = 5)
plt.text(0.5, 0, 'loss = %.4f' % loss.data[0], fontdict={'size':20, 'color': 'red'})
plt.pause(0.1)
plt.ioff()
plt.show()
创建数据
回归结果
点击查看更多内容
为 TA 点赞
评论
共同学习,写下你的评论
评论加载中...
作者其他优质文章
正在加载中
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦