为了账号安全,请及时绑定邮箱和手机立即绑定

手写神经网络

标签:
人工智能

模式识别课上老师对神经网络做了详细的数学推导,回来自己又推了两遍,不难,也就是链式法则和求偏导,觉得没什么问题。但是总感觉缺了点什么。根据以前学习数字图像处理的经验:有些算法原理看着简单,实则只有自己造个轮子,才能真正掌握。于是决定手写一个简单的两层神经网络,实现二分类。

%author:harry
%date:2017.03.21

%将数据用矩阵的形式进行运算
% 使用多层感知器进行二分类
clear all
clc

X = zeros(200,2); 
randn('seed',0);
t1 = linspace(1*4,(1+1)*4,100) + rand(1,100)*0.2 ;
for i =  1:100
    X(i,:) = [i*sin(t1(i))/20,i*cos(t1(i))/20]; 
end
t2 = linspace(2*4,(2+1)*4,100) + rand(1,100)*0.2 ;
for j =  1:100
    X(100+j,:) = [j*sin(t2(j))/20,j*cos(t2(j))/20]; 
end

D = [ones(100,1);zeros(100,1)];%标签
data = X ;


% 构建神经网络
% 初始化权重参数及各层偏导
w1 = 2*rand(2,90);
w2 = 2*rand(90,1);
%w1 = 0.5*ones(2,90);
%w2 = 0.5*ones(90,1);

delta_w1 = zeros(90,2);
delta_w2 = zeros(90,1);

ecoh = 0 ; % 迭代次数
while(1)
    ecoh = ecoh + 1 ;
    fprintf('第%d次迭代\n',ecoh);
  %前向传播
  l0 = data ;
  l1 = sigmoid(l0*w1) ;
  l2 = sigmoid(l1*w2) ;
  l2_error = l2 - D ;
  if(rem(ecoh,100))
     disp(mean(abs(l2_error)));
  end
  l2_delta = l2_error .* desigmoid(l2) ;
  l1_error = l2_delta * w2' ;
  l1_delta = l1_error .* desigmoid(l1) ;
  w2 = w2 - l1'*l2_delta ;
  w1 = w1 - l0'*l1_delta ;

  if ecoh >= 10000  
    break ;
  end
end

figure(1)  
%plot(X1,Y1,'ro',X2,Y2,'bo');%画出两类样本点 
plot(X(1:100,1),X(1:100,2),'ro',X(101:200,1),X(101:200,2),'bo'); 
hold on;grid;   
x = (-10:0.01:10);
y = (-10:0.01:10);

[xx,yy] = meshgrid(x,y);

[row,col] = size(xx);
z = zeros(row,col);
for i = 1:row
    for j = 1:col
        l0 = [xx(i,j),yy(i,j)] ;
        l1 = sigmoid(l0*w1) ;
        l2 = sigmoid(l1*w2) ;
        if l2>=0.5
            z(i,j) = 1 ;
        else 
            z(i,j) = 0 ;
        end
    end
end
contour(xx,yy,z);

% 以下为测试部分      
for i=1:10  
  [x,y]=ginput(1);  
   plot(x,y,'m*');  
   sample=[x,y];  
    hold all  
    l0 = [x,y];
    l1 = sigmoid(l0*w1) ;
    l2 = sigmoid(l1*w2) ;
    if(l2 > 0.5)  
        disp('此点属于第一类');  
    else  
        disp('此点属于第二类');  
    end 
end

两类不同的颜色表示不同的样本。下图是神经网络经过1000次迭代后生成的分界线。
这里写图片描述
然后利用学到的模型(其实就是一堆参数),对任意输入的十个点做分类:
这里写图片描述

这里只是简单的介绍了神经网络的基本功能,还有很多深入的坑有待以后讲解。比如训练过程中是否出现梯度消失?网络是否过拟合?… …

原文出处

点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消