How to Make a Network Lightweight (UNet as an Example)
Published: 2021-06-29 12:25:51
First, an overview of the parameter drop across the whole process (shown as a chart in the original post): by the end, the model has roughly 20% of its original parameters.
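Collected from the steps below, the parameter counts at each stage are:

- Baseline UNet: 31,155,586
- Downsampling convs factorized into 3×3 + 1×1: 28,719,900
- up_conv factorized into 1×1 + 3×3: 26,285,660
- Grouped convolutions (groups=4): 11,635,292
- Inception-style 1×3/3×1 factorization: 9,035,594
- Depthwise convolutions (groups = in_channels): 5,919,434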
The code used to count the network's parameters:
```python
# Count total and trainable parameters of a network.
def get_parameter_number(net):
    total_num = sum(p.numel() for p in net.parameters())
    trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}
```
I first wrote a basic UNet:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def down_sample(in_channel, channels):
    # Two unpadded 3x3 convs, each followed by BatchNorm and ReLU.
    return nn.Sequential(
        nn.Conv2d(in_channel, channels, kernel_size=3, padding=0, stride=1),
        nn.BatchNorm2d(channels),
        nn.ReLU(),
        nn.Conv2d(channels, channels, kernel_size=3, padding=0, stride=1),
        nn.BatchNorm2d(channels),
        nn.ReLU(),
    )


def up_sample(in_channel):
    # Transposed conv that doubles the spatial size and halves the channels.
    return nn.Sequential(
        nn.ConvTranspose2d(in_channel, in_channel // 2, kernel_size=2, stride=2),
    )


def up_conv(in_channel):
    return nn.Sequential(
        nn.Conv2d(in_channel, in_channel // 2, kernel_size=3, padding=0, stride=1),
        nn.BatchNorm2d(in_channel // 2),
        nn.ReLU(),
        nn.Conv2d(in_channel // 2, in_channel // 2, kernel_size=3, padding=0, stride=1),
        nn.BatchNorm2d(in_channel // 2),
        nn.ReLU(),
    )


def cat(x1, x2):
    # Pad x2 to x1's spatial size, then concatenate along the channel axis.
    diff = x1.size()[2] - x2.size()[2]
    x2 = F.pad(x2, [diff // 2, diff - diff // 2, diff // 2, diff - diff // 2])
    return torch.cat([x1, x2], dim=1)


class Unet(nn.Module):
    def __init__(self, in_channel, out_channel):
        # Note: out_channel is accepted but unused; out_conv hardcodes 2 outputs.
        super(Unet, self).__init__()
        self.down_sample1 = down_sample(in_channel, 64)
        self.down_sample2 = down_sample(64, 128)
        self.down_sample3 = down_sample(128, 256)
        self.down_sample4 = down_sample(256, 512)
        self.down_sample5 = down_sample(512, 1024)
        self.pooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.up_sample1 = up_sample(1024)
        self.up_sample2 = up_sample(512)
        self.up_sample3 = up_sample(256)
        self.up_sample4 = up_sample(128)
        self.up_conv1 = up_conv(1024)
        self.up_conv2 = up_conv(512)
        self.up_conv3 = up_conv(256)
        # up_conv4 is registered (so it is counted by get_parameter_number)
        # but never called in forward.
        self.up_conv4 = up_conv(128)
        self.out_conv = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=0, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 2, kernel_size=3, padding=0, stride=1),
        )

    def forward(self, x):
        x1 = self.down_sample1(x)
        p1 = self.pooling(x1)
        x2 = self.down_sample2(p1)
        p2 = self.pooling(x2)
        x3 = self.down_sample3(p2)
        p3 = self.pooling(x3)
        x4 = self.down_sample4(p3)
        p4 = self.pooling(x4)
        x5 = self.down_sample5(p4)  # e.g. torch.Size([1, 1024, 32, 32])
        x6 = self.up_sample1(x5)
        x6 = cat(x4, x6)
        x7 = self.up_conv1(x6)
        x7 = self.up_sample2(x7)
        x7 = cat(x3, x7)
        x8 = self.up_conv2(x7)
        x8 = self.up_sample3(x8)
        x8 = cat(x2, x8)
        x9 = self.up_conv3(x8)
        x9 = self.up_sample4(x9)
        x9 = cat(x1, x9)
        output = self.out_conv(x9)
        return output
```
Then I measured its parameter count:
{'Total': 31155586, 'Trainable': 31155586}
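For reference, a minimal way to reproduce this measurement; the in_channel=3 / out_channel=2 configuration is my assumption, but it matches the reported total exactly:

```python
# Hypothetical configuration: 3 input channels, 2 output classes.
net = Unet(in_channel=3, out_channel=2)
print(get_parameter_number(net))
# {'Total': 31155586, 'Trainable': 31155586}
```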
Next, I factorized the ordinary convolutions in the downsampling path in the spirit of depthwise separable convolution, i.e. I changed the down_sample function above. (Note that the 3×3 here does not set groups, so it is not yet truly depthwise; that happens in the final step.)
```python
def down_sample(in_channel, channels):
    return nn.Sequential(
        # 3x3 spatial conv keeping the channel count, followed by a 1x1
        # pointwise conv that does the channel expansion.
        nn.Conv2d(in_channel, in_channel, kernel_size=3),
        nn.BatchNorm2d(in_channel),
        nn.ReLU(),
        nn.Conv2d(in_channel, channels, kernel_size=1, padding=0, stride=1),
        nn.BatchNorm2d(channels),
        nn.ReLU(),
        nn.Conv2d(channels, channels, kernel_size=3, padding=0, stride=1),
        nn.BatchNorm2d(channels),
        nn.ReLU(),
    )
```
Then I checked the parameter count again:
{'Total': 28719900, 'Trainable': 28719900}
As you can see, the parameter count drops noticeably, by about 2.4M (roughly 8%).
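Most of that saving comes from the deepest blocks. In the 512→1024 block, the first convolution originally costs 3·3·512·1024 ≈ 4.72M weights; after factorization it costs 3·3·512·512 + 512·1024 ≈ 2.89M, saving about 1.8M parameters on its own. The 256→512 block contributes roughly another 0.46M, and the shallower blocks the small remainder.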
Next, I applied the same factorization to the upsampling path, changing the up_conv function:
```python
def up_conv(in_channel):
    return nn.Sequential(
        nn.Conv2d(in_channel, in_channel // 2, kernel_size=1, padding=0, stride=1),
        nn.BatchNorm2d(in_channel // 2),
        nn.ReLU(),
        nn.Conv2d(in_channel // 2, in_channel // 2, kernel_size=3, padding=0, stride=1),
        nn.BatchNorm2d(in_channel // 2),
        nn.ReLU(),
        nn.Conv2d(in_channel // 2, in_channel // 2, kernel_size=3, padding=0, stride=1),
        nn.BatchNorm2d(in_channel // 2),
        nn.ReLU(),
    )
```
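The ordering matters here: the expensive in→in/2 channel reduction, which used to be a 3×3 (≈4.5·in² weights), is now a cheap 1×1 (≈0.5·in² weights), and the remaining 3×3 convolutions run on the halved channel count. Per block this saves roughly 1.75·in² weights, about 1.8M parameters for the in=1024 block alone.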
Checking the parameters again:
{'Total': 26285660, 'Trainable': 26285660}
Another drop of about 2.4M parameters.
The reduction so far is modest because only the channel-mixing convolution at the head of each block has been factorized; the remaining full 3×3 convolutions still hold most of the parameters. (Parameter count depends only on kernel sizes and channel counts, not on feature-map size, so these savings are the same at every resolution.)
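A quick way to confirm where the parameters sit is to run the same counter on individual blocks; a small illustrative check reusing the functions defined above:

```python
# Count each downsampling stage separately; the deepest blocks dominate.
for in_c, out_c in [(64, 128), (128, 256), (256, 512), (512, 1024)]:
    print(f"{in_c} -> {out_c}:", get_parameter_number(down_sample(in_c, out_c)))
```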
Next, I switched to grouped convolution with groups=4, modifying both functions once more:
```python
def down_sample(in_channel, channels):
    return nn.Sequential(
        nn.Conv2d(in_channel, in_channel, kernel_size=3),
        nn.BatchNorm2d(in_channel),
        nn.ReLU(),
        nn.Conv2d(in_channel, channels, kernel_size=1, padding=0, stride=1),
        nn.BatchNorm2d(channels),
        nn.ReLU(),
        nn.Conv2d(channels, channels, kernel_size=3, padding=0, stride=1, groups=4),
        nn.BatchNorm2d(channels),
        nn.ReLU(),
    )
```
```python
def up_conv(in_channel):
    return nn.Sequential(
        nn.Conv2d(in_channel, in_channel // 2, kernel_size=1, padding=0, stride=1, groups=4),
        nn.BatchNorm2d(in_channel // 2),
        nn.ReLU(),
        nn.Conv2d(in_channel // 2, in_channel // 2, kernel_size=3, padding=0, stride=1, groups=4),
        nn.BatchNorm2d(in_channel // 2),
        nn.ReLU(),
        nn.Conv2d(in_channel // 2, in_channel // 2, kernel_size=3, padding=0, stride=1, groups=4),
        nn.BatchNorm2d(in_channel // 2),
        nn.ReLU(),
    )
```
Now the parameter count is:
{'Total': 11635292, 'Trainable': 11635292}
The parameter count is now down to about 37% of where we started, roughly a 2.7× reduction.
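The reason the drop is so large: a grouped k×k convolution splits the input channels into g groups and convolves each separately, costing k²·(C_in/g)·C_out weights instead of k²·C_in·C_out, so groups=4 cuts those layers to a quarter. A quick sanity check:

```python
full = nn.Conv2d(256, 256, kernel_size=3)               # 256*256*9 + 256 = 590,080 params
grouped = nn.Conv2d(256, 256, kernel_size=3, groups=4)  # 256*(256//4)*9 + 256 = 147,712 params
print(sum(p.numel() for p in full.parameters()),
      sum(p.numel() for p in grouped.parameters()))
```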
Next, I borrowed the factorization trick from Inception v3 and replaced the 3×3 convolutions with a 1×3 followed by a 3×1 convolution. The parameters:
{'Total': 9035594, 'Trainable': 9035594}
Another drop, as shown by the sketch below.
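The post does not show the code for this step, so here is a minimal sketch of what the factorization might look like; the helper name and the carried-over groups=4 setting are my assumptions:

```python
# Hedged sketch: replace one grouped 3x3 conv with a 1x3 followed by a 3x1.
# Two 1-D kernels cost 3 + 3 = 6 weights per filter slice instead of 9.
def factorized_3x3(channels, groups=4):
    return nn.Sequential(
        nn.Conv2d(channels, channels, kernel_size=(1, 3), groups=groups),
        nn.BatchNorm2d(channels),
        nn.ReLU(),
        nn.Conv2d(channels, channels, kernel_size=(3, 1), groups=groups),
        nn.BatchNorm2d(channels),
        nn.ReLU(),
    )
```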
Finally, I converted all remaining ordinary convolutions to depthwise convolutions, i.e. groups = in_channels. The resulting parameters:
{'Total': 5919434, 'Trainable': 5919434}
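The final code is not shown either; here is a hedged sketch of what a fully depthwise down_sample could look like (plain 3×3 kernels shown for clarity, rather than the 1×3/3×1 pairs from the previous step):

```python
# Hedged sketch: every spatial conv becomes depthwise (groups = its input
# channels); all cross-channel mixing is left to the 1x1 pointwise conv.
def down_sample(in_channel, channels):
    return nn.Sequential(
        nn.Conv2d(in_channel, in_channel, kernel_size=3, groups=in_channel),  # depthwise 3x3
        nn.BatchNorm2d(in_channel),
        nn.ReLU(),
        nn.Conv2d(in_channel, channels, kernel_size=1),                       # pointwise 1x1
        nn.BatchNorm2d(channels),
        nn.ReLU(),
        nn.Conv2d(channels, channels, kernel_size=3, groups=channels),        # depthwise 3x3
        nn.BatchNorm2d(channels),
        nn.ReLU(),
    )
```

That brings the total to 5,919,434 parameters, about 19% of the original 31,155,586 — the "roughly 20%" figure quoted at the top.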
Source: https://bupt-xbz.blog.csdn.net/article/details/105278072