Tensorflow入门——卷积神经网络MNIST手写数字识别

作者 : 开心源码 本文共3509个字,预计阅读时间需要9分钟 发布时间: 2022-05-12 共123人阅读

Image source: unsplash.com by Pawe? Czerwiński
之前的文章我们详情了如何用单层和多的全连接层神经网络识别手写数字,尽管识别率能够达到98%,但是因为全链接神经网络本身的局限性,其识别率已经很难再往上提升了。我们需要改进神经网络的结构,采用卷积神经网络(CNN)的结构来进一步提高的识别率。

关于CNN的原理,我在之前的文章中已经详情,这篇文章就不过多赘述,我们直接进入实战阶段。

同样的,为了方便与读者交流,所有的代码都放在了这里:

Repository:

zht007/tensorflow-practice

1. 初始化W和B

卷积神经网络中权重W的shape尤其重要,CNN中的W实际上就是一个四维的filter,这个四维的filter由n个三维filter堆叠而成,n的大小等于输出channel的深度。当然三维filter又是由m个二维filter堆叠的,m的大小等于输入Channel的深度。

动画效果可以参见这里。

W的shape为[filter[0], filter[1], input_channel_depth, output_channel_depth]

例如W[6,6,3,4] 表示:二维的filter的size是(6,6), 输入的图片有3个Channel, 输出的图片有4个Channel

偏置B的Shape与output_channel保持一致就可,tensorflow会自动broadcast成正确的维度,B在这里与多层神经网络的的初始化相同。

神经网络的结构一共5层,3层CNN,2层全链接,最后一层与单层神经网络一样,10个神经元输出识别结果:数字是是0-9的概率。

# three convolutional layers with their channel counts, and a# fully connected layer (the last layer has 10 softmax neurons)K = 12  # first convolutional layer output depthL = 24  # second convolutional layer output depthM = 48  # third convolutional layerN = 200  # fully connected layerW1 = tf.Variable(tf.truncated_normal([6,6,1,K], stddev=0.1)) B1 = tf.Variable(tf.ones([K])/10)W2 = tf.Variable(tf.truncated_normal([5,5,K,L], stddev=0.1))B2 = tf.Variable(tf.ones([L])/10)W3 = tf.Variable(tf.truncated_normal([4,4,L,M], stddev=0.1))B3 = tf.Variable(tf.ones([M])/10)W4 = tf.Variable(tf.truncated_normal([7*7*M,N], stddev=0.1))B4 = tf.Variable(tf.ones([N])/10)W5 = tf.Variable(tf.truncated_normal([N, 10], stddev=0.1))B5 = tf.Variable(tf.zeros([10]))

该部分代码部分参考[2][3] with Apache License 2.0

2. 神经网络搭建

CNN的部分,我们用tensorflow自带的tf.nn.conv2d()方法:

tf.nn.conv2d(    input,    filter,    strides,    padding,    use_cudnn_on_gpu=True,    data_format='NHWC',    dilations=[1, 1, 1, 1],    name=None)

用Tensorflow搭建神经网络的时候注意以下几点:

  1. Padding 这里使用的是’SAME’,也就是步长(stride)为1的时候输入与输出图片的shape保持一致。
  2. 这里没有使用Max-Pooling层来”压缩”图片,而是添加stride(第二层和第三层Stride 为2)的方式,效果是一样的。28×28的图片经过两层CNN之后,压缩成了14×14和7×7的图片。
  3. CNN与全连接神经网络连接之前,需要将CNN输出的图片拆开拼接成一维的向量(Flatten or Reshape)。
Y1 = tf.nn.relu(tf.nn.conv2d(X, W1, strides = [1,1,1,1], padding='SAME') + B1)Y2 = tf.nn.relu(tf.nn.conv2d(Y1,W2, strides = [1,2,2,1], padding='SAME') + B2)Y3 = tf.nn.relu(tf.nn.conv2d(Y2,W3, strides = [1,2,2,1], padding='SAME') + B3)#flat the inputs for the fully connected nnYY3 = tf.reshape(Y3, shape = (-1,7*7*M))                Y4 = tf.nn.relu(tf.matmul(YY3, W4) + B4)Y4d = tf.nn.dropout(Y4,rate = drop_rate)Ylogits = tf.matmul(Y4d, W5) + B5Y = tf.nn.softmax(Ylogits)

该部分代码部分参考[2][3] with Apache License 2.0

3. 识别效果

在其余参数都没改变的情况下,仅仅改变了神经网络的结构,可以看出识别率已经超出99%了。

image-20190402143453530

目前我通过CNN的神经网络训练出来的分类器参与Kaggle的比赛,最好成绩是识别率99.3,全球排名第792名。

image-20190402144054498

4. CNN结构的Keras实现

假如用Keras这个高级的API搭建CNN就更加简单了,无需初始化W和B,只要要关心神经网络的结构本身就行了。

使用Keras的layers.Conv2D()方法,注意其中的参数filters 是输出Channel的depth,Kernel_size 是二维filter的shape,实现相同结构的代码如下:

model = models.Sequential()model.add(layers.Conv2D(filters = 12, kernel_size=(6,6), strides=(1,1),                       padding = 'same', activation = 'relu',                       input_shape = (28,28,1)))          model.add(layers.Conv2D(filters = 24,kernel_size=(5,5),strides=(2,2),                       padding = 'same', activation = 'relu'))model.add(layers.Conv2D(filters = 48,kernel_size=(4,4),strides=(2,2),                       padding = 'same', activation = 'relu'))          model.add(layers.Flatten())                    model.add(layers.Dense(units=200, activation='relu'))model.add(layers.Dropout(0.25))model.add(layers.Dense(units=10, activation='softmax'))

参考资料

[1]https://www.kaggle.com/c/digit-recognizer/data

[2]https://codelabs.developers.google.com/codelabs/cloud-tensorflow-mnist/#0

[3] GoogleCloudPlatform/tensorflow-without-a-phd.git

[4]https://www.tensorflow.org/api_docs/


相关文章

Tensorflow入门——单层神经网络识别MNIST手写数字

Tensorflow入门——多层神经网络MNIST手写数字识别

AI学习笔记——Tensorflow中的Optimizer

Tensorflow入门——分类问题cross_entropy的选择

AI学习笔记——Tensorflow入门

Tensorflow入门——Keras简介和上手


同步到我的Steemit

https://steemit.com/@hongtao

说明
1. 本站所有资源来源于用户上传和网络,如有侵权请邮件联系站长!
2. 分享目的仅供大家学习和交流,您必须在下载后24小时内删除!
3. 不得使用于非法商业用途,不得违反国家法律。否则后果自负!
4. 本站提供的源码、模板、插件等等其他资源,都不包含技术服务请大家谅解!
5. 如有链接无法下载、失效或广告,请联系管理员处理!
6. 本站资源售价只是摆设,本站源码仅提供给会员学习使用!
7. 如遇到加密压缩包,请使用360解压,如遇到无法解压的请联系管理员
开心源码网 » Tensorflow入门——卷积神经网络MNIST手写数字识别

发表回复