采用CNN(卷积神经网络)模型进行mnist分类任务。
本文的完整代码托管在我的Github PnYuan - Practice-of-Machine-Learning - MNIST_tensorflow_demo,欢迎交流。
1.任务背景
这里,我们拟通过搭建卷积神经网络(CNN)来完成MNIST手写数字识别任务,关于MNIST任务的相关内容可参考前文深度学习基础 - MNIST实验(tensorflow+Softmax)或深度学习基础 - MNIST实验(tensorflow+MLP)。
2.实验过程
实验参考代码:python + tensorflow: cnn_demo.py & cnn_demo_self_test.py
实验分三步进行:
- 参考LeNet-5,搭建适用于该任务的CNN模型,开发实现基于tensorflow;
- 加载MNIST数据集,配置超参数,进行训练与测试,分析效果;
- 加载自制手写图片,采用训练好的CNN进行识别,分析效果;
2.1.CNN建模
LeNet-5是Y.LeCun等人早期所设计的一种CNN,是经典的神经网络架构之一,如下图所示:(参考原文献)
本实验采用python-tensorflow实现LeNet-5,其建模代码样例如下:
1 | '''construction of leNet-5 model''' |
2.2.训练与测试
设置优化策略及相关超参数(如learning_rate
、num_epochs
、mini-batch size
等),进行训练,经过一段时间的训练,得出的accuracy
结果如下:
Train Accuracy: 0.9920
Valid Accuracy: 0.9896
Test Accuracy: 0.9881
同时该训练期间,指标accuracy
和cost
的变化过程如下图示:
可以看出,此处CNN(LeNet-5)已经取得了不错的结果(≈99%的测试准确率)。而通过观察训练曲线变化趋势,猜测随着迭代的继续,模型效果还可继续提升。
2.3.实测
接下来验证该CNN模型在生活场景下的泛化效果,笔者此处在实验室即兴写了若干待识别数字,示意如下:
采用之前所训练的CNN,得出预测结果示意如下:
结果中出现了一些识别错误,初步猜测是由数据分布的差异所引起。可以考虑在图像训练和测试时,先采用更多的预处理手段(如灰度归一化、对比度增强、阈值分割…),从而使分布接近。降低模型迁移难度。
3.实验小结
本文采用CNN模型进行mnist手写数字识别任务,取得了很好的效果(99%的测试准确率)。同时采用训练好的模型识别了实际场景中的数字,体现了一定的识别效果。
4.参考资料
官方参考:
CNN模型:
开发辅助: