博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
卷积神经网络补充—GoogleNet
阅读量:3938 次
发布时间:2019-05-23

本文共 3553 字,大约阅读时间需要 11 分钟。

复杂神经网络的问题

简单神经网络

我们注意这样的一个问题,我们在之前的学习当中使用的都是简单的一条龙走下来的方式进行学习,这种是比较基础的,没有分叉,没有循环,就是一条路走完。可以看到之前学的都是特别简单的串行结构。

在这里插入图片描述

GoogleNet

是一种基础架构,我们一般拿这个网络做一个主干网络,之后再在主干网络上进行修改之后作为我们实际应用的场景。

在这里插入图片描述

减少代码冗余:Inception Module

减少代码冗余看起来比较陌生其实我们可能已经默默应用了很久了,例如:在面向过程的编程语言当中使用函数,在面向过程的编程语言当中使用类。都是对代码进行减少冗余。

可以看到这种网络结构是十分的复杂,那么我们如果全部一点一点的写完,那么这个代码的复杂度就太大了。所以我们必须选择减少代码的冗余:
在这里插入图片描述
可以看到其实代码当中可复用的部分还是很多的,如上图:
这个块就被叫做:Inception(和电影盗梦空间重名)具体如下:
在这里插入图片描述
这里的设计灵感来自于,在模式识别的过程当中,主要是一些超参数的选择不好处理,例如我们在进行卷积的时候,我们该选择卷积核的维数为多少的问题。这样我们设计为多个不同的路径达到结果,这样子在不断的学习过程中,适合的卷积核所在的路径就会被逐渐的凸显出来。自动帮我们选择一个超参数。
那么这里

1×1的卷积到底是为了什么?

也就是这个图里的1×1卷积到底是在做什么?

在这里插入图片描述
如果我们这个channel数是一个1那么什么也别说了,这不就是一个矩阵数乘吗?
在这里插入图片描述
但是我们大部分实际应用的过程中channel并不是1而是多个,例如三个:
在这里插入图片描述
通过这个图我们可以看到,其实卷积是一个对不同通道数据的一个融合,融合这个说法比较空,不如举一个例子来理解一下:我们上高中的时候经常有个模考,或者月考。如果有六个科目,那么就是一个六维空间,在一个六维空间很难比较两个同学谁学习的情况更好,所以学校就有一个信息融合的算法,就是加总分。

当然这里是融合为一个通道,如果我们多几组卷积核,我们可以融合出目标个通道数的数据形式。这个过程中一个显著的特点就是改变了channel的个数。

可是改变一个channel的个数又有什么作用呢?我们可以详细看下图:
在这里插入图片描述
从图里我们可以看到在经过1×1卷积channel下降之后再进行5×5卷积可以有效地节约运算次数,这就是在节约钱啊。

Concatenate是什么?

也就是这个东西是什么?

在这里插入图片描述
这个东西其实是将之前的内容拆分,再在这一层结束之后从新合并在一起,我们注意一下这里的合并是在什么情况下合并的,这里的合并是依据channel合并的,也就是channel之下的全部内容都必须全部相等,才能顺利合并。
所以我们在中间过程的卷积过程当中要注意保证weight和height的稳定不变。

实现Googlenet

代码实现

import torchimport numpyfrom torchvision import functional as F#这里我们注意不论我们输入的通道数是多少,我们输出的通道数都是88.class Inception_Model(torch.nn.Module):    def __init__(self,input_channel):        super(Inception_Model,self).__init__()        self.poolingbranch= torch.nn.Conv2d(input_channel,24,kernel_size=1)        self.1m1branch =torch.nn.Conv2d(input_channel,16,kernel_size=1)        self.5m5branch1 =torch.nn.Conv2d(input_channel,16,kernel_size=1)        self.5m5branch2=torch.nn.Conv2d(16,24,kernel_size=5,padding=2)#注意这里设置一个2的padding是为了保证输出的图形的形状一致        self.3m3branch1=torch.nn.Conv2d(input_channel,16,kernel_size=1)        self.3m3branch2=torch.nn.Conv2d(16,24,kernel_size=3,padding=1)#同样这里的padding也是为了保证输出的图像一致        self.3m3branch3 =torch.nn.Conv2d(24,24,kernel_size=3,padding=1)    def forward(self,x):        branchpool=F.avg_pool2d(x,kernel_size=3,stride=1,padding=1)#这里是使用一个函数,因为这个求平均没有参数,所以直接可以当成函数使用        branchpool=poolingbranch(branchpool)        branch1m1=self.1m1branch(x)        branch5m5=self.5m5branch1(x)        branch5m5=self.5m5branch2(branch5m5)        branch3m3=self.3m3branch1(x)        branch3m3=self.3m3branch2(branch3m3)        branch5m5=self.3m3branch3(branch3m3)        outputs = [branchpool,branch1m1,branch5m5,branch3m3]#注意这里我们只是将其放在一个元组当中,并没有拼起来        return torch.cat(outputs,dim=1)#正着开始的第一号维度其实是第二个嘛,不就是channel吗。只有当只有channel不一样的时候,才可以顺利合并起来。class Net(torch.nn.Module):    def __init__(self):        super(Net,self).__init__()        self.conv1=torch.nn.Conv2d(1,10,kernel_size=5)        self.conv2=torch.nn.Conv2d(88,20,kernel_size=5)        self.incip1=Inception_Model(10)        self.incip2=Inception_Model(20)        self.mp=torch.nn.MaxPool2d(2)        self.fullconnect=nn=(1408,10)#这里我们注意,我们的1408是可以算出来的。                                     #但是实际上并没有人去算,都是输入一个数据测试一下到底是多少的情况。    def forward(self,x):        in_size=x.size(0)#注意这里我们其实是为了获得我们是在使用的数据集的大小的问题        x=F.relu(self.mp(self.conv1(x)))#从(1,28,28)到(10,24,24)再到(10,12,12)        x=self.incip1(x)#从(10,12,12)到(88,12,12)        x=f.relu(self.mp(self.conv2(x)))#从(88,12,12)到(10,8,8)再到(10,4,4)        x=self.incip2(x)#从(10,4,4)到(88,4,4)所以一个输出转化为一个一维张量就成了1408了        x=x.view(in_size,-1)#注意这里我们是将数据集中的,每一个数据,原来的矩阵换成一个行。        x=self.fullconnect(x)        return x

在MNIST数据集上的表现

在这里插入图片描述

在这里插入图片描述
你看这里其实经过几次训练之后,这个测试集的准确度就开始下降了,这时候,可能就是出现了过拟合的现象。已经达到了这种网络的极限。理论上我们需要画图来找到最大值。但是一般我们在实际的操作当中我们是每次达到新的高度就进行一次存盘。最后我们得到的就是一个最好的模型了。

转载地址:http://dkywi.baihongyu.com/

你可能感兴趣的文章
iptables学习
查看>>
fsck命令使用详解
查看>>
kvm快速安装部署
查看>>
apache三种工作模式及相关配置
查看>>
Apache与Nginx的优缺点比较
查看>>
select和epoll对比
查看>>
几种常见负载均衡比较
查看>>
虚拟网络
查看>>
Apache练习题
查看>>
sql常用命令
查看>>
CloudStack云基础架构的一些概念
查看>>
在centos7里安装zabbix3.4
查看>>
cloudstack搭建
查看>>
docker-compose使用
查看>>
springboot多个项目部署在tomcat服务器上的shiro的session污染问题
查看>>
mysql插入数据避免重复(Replace,IGNORE,on duplicate key update)
查看>>
mysql索引选择及优化
查看>>
MySQL数据类型、选择与优化
查看>>
Springboot系列(一)同属性名多对象处理
查看>>
mysql主从同步(复制)canal跨机房同步
查看>>