ResNet是何恺明等人于2015年提出的神经网络结构,该网络凭借其优秀的性能夺得了多项机器视觉领域竞赛的冠军,而后在2016年发表的论文《Deep Residual Learning for Image Recognition》也获得了CVPR2016最佳论文奖。本文整理了笔者对ResNet的理解,详细解释了ResNet34、ResNet50等具体结构,并使用PyTorch实现了一个使用ResNet训练CIFAR-10数据集的具体实例。
卷积神经网络在图像分类领域具有非常广泛的应用,从理论上讲,越深的网络结构,其拟合能力应该越强,如16层的VGG16的拟合能力要强于5层的LeNet。然而,何恺明等人通过实验发现,当网络的深度达到一定程度时,网络的性能不升反降,并且这种性能的下降不是由过拟合引起的,因为深度网络的训练误差和测试误差都比浅层网络高,如图1所示。
为改善上述问题,何恺明等人提出了深度残差学习框架。
2.1 残差学习
常规的神经网络,是将卷积层、全连接层等结构按照一定的顺序简单地连接到一起,每层结构仅接受来自上一层的信息,并在本层处理后传递给下一层。当网络层次加深后,这种单一的连接方式会导致神经网络性能退化。所谓的残差学习,就是在上述的单一连接方式的基础上,加入了“短连接”(shortcut connections),如图所示。
短连接能跨越几个层,将输入x直接映射到输出端(类似于电路中的短路),与输出相加。这样做导致的直接结果就是,在神经网络中加入上述的结构不再会导致神经网络的退化,考虑最坏的情况也不过是F(x)为0,这相当于网络没有加深,与之前一致。但是如果加入的层有效,就能使网络的性能得到提升。从数学的角度来看,上述结构的输入为x,输出H(x)为:
[H(x) = F(x)+x ]
其中F(x)称为残差函数,变换上式可得:
[F(x) = H(x)-x ]
显然,该结构中加入的短连接是没有参数的,因此学习H(x)的参数与学习F(x)的参数本质上是一样的,但是将函数优化为0或者在0附近显然更容易。因此,残差学习,学习的就是残差函数F(x)的参数。
2.2 ResNet基本单元
论文原文中给出了两种ResNet的基本单元结构,其中左边的单元用于较浅的网络,如ResNet18、ResNet34,右边的单元则用于较深的网络,如ResNet50、ResNet101等。
首先,无论是左边还是右边的单元,都是由卷积层和激活函数构成的。左边的单元由2个大小为3x3的卷积层与2个ReLU激活函数构成,右边的单元由2个1x1、1个3x3的卷积层和3个ReLU激活函数构成。至于卷积核的的通道(channel)数,则由单元的具体位置决定。通常卷积神经网络在提取图像特征的过程种,卷积核都会从“大卷积核、低通道数量“层层递进到到”小卷积核、多通道数量“。因此,在实际应用的过程种,短连接跨越的卷积层的通道数量有所改变是非常常见的。原论文中给出了直接连接、填零连接和映射连接三种具体的短连接形式。当实际单元的输入通道数与输出相同时,如上图左,则短连接直接连接即可。若输入通道数与输出不相同,如上图右,输入的维度为64,输出却是256,此时输入x无法与F(x)直接相加。填零连接就是将多出的维度全部填充0后再进行相加,这个方法不会引入额外的参数。映射连接则是通过矩阵变换,利用大小为1x1的卷积层扩展输入的维度后再进行相加。短连接跨越通道数不同的卷积层时,卷积执行的步长为2,这刚好缩小特征矩阵的大小,并增加了特征的通道数。具体的细节将在下文中的代码实现,此处不再赘述。
2.3 ResNet
原论文中给出了如下5个具体的ResNet网络,这些网络都由上述的基本单元构成。
为了便于阐述,我们直接来讨论ResNet34,该网络主要由16个基本单元、1个7x7卷积层与1个全连接层构成,共计34层。
其中,实线的短连接表示直接连接,前后通道数未发生改变。虚线的短连接则表示前后维度发生了改变,需要通过填0或者映射后进行连接。
上面讲了这么多白开水般的原理,相信各位看官早已疲倦,接下来上紧张刺激的代码,毕竟看了那么多,最终目的还是为了应用。对于ResNet34,其基本单元实现代码如下。
只要理解了上文所叙述的原理,代码还是很容易理解的。其中的扩展系数expansion表示的是单元输出与输入张量的通道数之比,对于ResNet34,这个比是1。而对于ResNet50,这个比是4。如果无法理解,可以回顾上文2.2中的图。ResNet50的基本单元代码如下。
有了建房子的”砖块“,就可以着手建房了,ResNet的实现代码如下。
该代码可以实现论文中的5个网络ResNet18、ResNet34、ResNet50、ResNet101、ResNet152。ResNet34的实例化代码如下:
具体的参数num_blocks可查阅2.3中的表确定。
搭建好了神经网络,就可以用于训练模型了。本文代码参考了Gihub上的项目pytorch-cifar与torchvision.models中的代码,特此声明,以示感谢。
本文介绍了一些我在学习过程中对于ResNet的理解,并给出了代码与具体的实例。从实质上讲,ResNet与传统卷积神经网络的主要区别就是引入了短连接,而这个短连接极大地提升了神经网络的性能。至于数学原理,本文并未深究,毕竟包括我在内的许多学习者都是关心如何应用,而不是深挖那些数学公式的来龙去脉。
由于本人才疏学浅,如有疏漏谬误之处,还请指出,万分感谢。
到此这篇resnet原文(resnet作者)的文章就介绍到这了,更多相关内容请继续浏览下面的相关推荐文章,希望大家都能在编程的领域有一番成就!
版权声明:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若内容造成侵权、违法违规、事实不符,请将相关资料发送至xkadmin@xkablog.com进行投诉反馈,一经查实,立即处理!
转载请注明出处,原文链接:https://www.xkablog.com/rfx/36794.html