为了账号安全,请及时绑定邮箱和手机立即绑定

1.5 神经网络入门-神经元实现

标签:
深度学习

1.5 神经元实现

  • 分拆数据集

    def load_data(filename):
        """read data from data file."""
        with open(filename, 'rb') as f:
            data = pickle.load(f, encoding='bytes')        return data[b'data'], data[b'labels']# trensorflow.DataSetclass CifarData:
        def __init__(self, filenames, need_shuffle):
            all_data = []
            all_labels = []        for filename in filenames:
                data,labels = load_data(filename)            for item,label in zip(data,labels):                if label in [0,1]:
                        all_data.append(item)
                        all_labels.append(label)
            self._data = np.vstack(all_data)        # 归一化,将0-255的数归一成0-1直接的数
            self._data = self._data / 127.5 - 1 
            self._labels = np.hstack(all_labels)
            self._num_examples = self._data.shape[0]
            self._need_shuffle = need_shuffle
            self._indicator = 0
            if self._need_shuffle:
                self._shuffle_data()        
        def _shuffle_data(self):
            # 混排 [0,1,2,3,4,5] -> [2,1,4,0,3,5]
            p = np.random.permutation(self._num_examples)
            self._data = self._data[p]
            self._labels = self._labels[p]    
        def next_batch(self, batch_size):
            """return batch_size examples as a batch."""
            end_indicator = self._indicator + batch_size        if end_indicator > self._num_examples:            if self._need_shuffle:
                    self._shuffle_data()
                    self._indicator = 0
                    end_indicator = batch_size            else:                raise Exception("have no more examples")        if end_indicator > self._num_examples:            raise Exception("batch size is lager then all examples")
            batch_data = self._data[self._indicator:end_indicator]
            batch_labels = self._labels[self._indicator:end_indicator]
            self._indicator = end_indicator        return batch_data, batch_labels
            
    train_filename = [os.path.join(CIFAR_DIR,'data_batch_%d' % i) for i in range(1,6)]
    test_filenames = [os.path.join(CIFAR_DIR, 'test_batch')]
    
    train_data = CifarData(train_filename, True)
    test_data = CifarData(test_filenames, False)
    
    batch_data,batch_labels = train_data.next_batch(10)
  • 测试算法准确率

    init = tf.global_variables_initializer()
    batch_size = 20train_steps = 100000test_steps = 100with tf.Session() as sess:
        sess.run(init)    for i in range(train_steps):
            batch_data, batch_labels = train_data.next_batch(batch_size)
            loss_val, acc_val, _ = sess.run(
                [loss, accuracy, train_op],
                feed_dict={
                    x: batch_data,
                    y: batch_labels})        if (i+1) % 500 == 0:            print ('[Train] Step: %d, loss: %4.5f, acc: %4.5f' \
                    % (i+1, loss_val, acc_val))                
            if (i+1) % 5000 == 0:
                test_data = CifarData(test_filenames, False)
                all_test_acc_val = []            for j in range(test_steps):
                    test_batch_data, test_batch_labels \
                        = test_data.next_batch(batch_size)
                    test_acc_val = sess.run(
                        [accuracy],
                        feed_dict = {
                            x: test_batch_data, 
                            y: test_batch_labels
                        })
                    all_test_acc_val.append(test_acc_val)
                test_acc = np.mean(all_test_acc_val)
                print('[Test ] Step: %d, acc: %4.5f' % (i+1, test_acc))



作者:Meet相识_bfa5
链接:https://www.jianshu.com/p/fed88fcd3428


点击查看更多内容
TA 点赞

若觉得本文不错,就分享一下吧!

评论

作者其他优质文章

正在加载中
JAVA开发工程师
手记
粉丝
205
获赞与收藏
1008

关注作者,订阅最新文章

阅读免费教程

  • 推荐
  • 评论
  • 收藏
  • 共同学习,写下你的评论
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦
今天注册有机会得

100积分直接送

付费专栏免费学

大额优惠券免费领

立即参与 放弃机会
意见反馈 帮助中心 APP下载
官方微信

举报

0/150
提交
取消