PointCleanNet 代码解析
发布日期:2022-02-28 07:22:43
浏览次数:36
分类:技术文章
本文共 20175 字,大约阅读时间需要 67 分钟。
PCPNet 部分
from __future__ import print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.nn.functional as F
from torch.autograd import Variable
import utils


def _float_buffer(like, *size):
    """Allocate a float tensor of shape ``size`` on CPU or CUDA, matching ``like``.

    The contents are not initialised; callers are expected to overwrite every
    element before reading it.
    """
    tensor_type = torch.cuda.FloatTensor if like.is_cuda else torch.FloatTensor
    return Variable(tensor_type(*size))


class STN(nn.Module):
    """Spatial transformer network that predicts one transform per sample.

    With ``quaternion=False`` the network outputs a ``dim x dim`` matrix
    (identity plus a learned offset). With ``quaternion=True`` it outputs four
    quaternion components (identity quaternion plus a learned offset) which
    are converted with ``utils.batch_quat_to_rotmat``.

    Args:
        num_scales: number of point-set scales concatenated along the point
            axis of the input.
        num_points: number of points per scale (also the max-pool window).
        dim: number of input channels per point.
        sym_op: stored on the instance; this class always pools with max.
        quaternion: predict a quaternion instead of a full matrix.
    """

    def __init__(self, num_scales=1, num_points=500, dim=3, sym_op='max', quaternion=False):
        super(STN, self).__init__()

        self.quaternion = quaternion
        self.dim = dim
        self.sym_op = sym_op
        self.num_scales = num_scales
        self.num_points = num_points

        # 1x1 1-D convolutions: every point is lifted independently,
        # dim -> 64 -> 128 -> 1024 channels.
        self.conv1 = torch.nn.Conv1d(self.dim, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        # 1-D max pooling over the points of one scale.
        self.mp1 = torch.nn.MaxPool1d(num_points)

        # Fully connected head: 1024 -> 512 -> 256 -> output.
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        if not quaternion:
            # Full matrix: dim*dim output values.
            self.fc3 = nn.Linear(256, self.dim*self.dim)
        else:
            # Quaternion: 4 output values.
            self.fc3 = nn.Linear(256, 4)

        # Batch normalisation after each convolution / hidden FC layer.
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        if self.num_scales > 1:
            # With several scales, first fuse the concatenated per-scale
            # features back into a single 1024-vector.
            self.fc0 = nn.Linear(1024*self.num_scales, 1024)
            self.bn0 = nn.BatchNorm1d(1024)

    def forward(self, x):
        """Predict the transform for ``x``.

        ``x`` is fed straight into ``Conv1d(dim, 64, 1)``, so it is presumably
        laid out as (batch, dim, num_scales*num_points) -- confirm with caller.

        Returns a (batch, dim, dim) tensor when ``quaternion`` is False,
        otherwise the result of ``utils.batch_quat_to_rotmat``.
        """
        batchsize = x.size()[0]

        # conv -> batch norm -> ReLU; ends with 1024 channels per point.
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        # symmetric operation over all points
        if self.num_scales == 1:
            # Single scale: pool directly over the point axis.
            x = self.mp1(x)
        else:
            # Multiple scales: pool each scale's slice of points separately
            # and stack the results along the channel axis.
            x_scales = _float_buffer(x, x.size(0), 1024*self.num_scales, 1)
            for s in range(self.num_scales):
                x_scales[:, s*1024:(s+1)*1024, :] = self.mp1(
                    x[:, :, s*self.num_points:(s+1)*self.num_points])
            x = x_scales

        x = x.view(-1, 1024*self.num_scales)

        if self.num_scales > 1:
            x = F.relu(self.bn0(self.fc0(x)))

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        if not self.quaternion:
            # Add the flattened identity matrix, so a zero network output
            # yields the identity transform.
            iden = Variable(torch.from_numpy(np.identity(self.dim, 'float32')).clone())
            iden = iden.view(1, self.dim*self.dim).repeat(batchsize, 1)
            if x.is_cuda:
                iden = iden.cuda()
            x = x + iden
            x = x.view(-1, self.dim, self.dim)
        else:
            # add identity quaternion (so the network can output 0 to leave the point cloud identical)
            iden = Variable(torch.FloatTensor([1, 0, 0, 0]))
            if x.is_cuda:
                iden = iden.cuda()
            x = x + iden

            # convert quaternion to rotation matrix
            # NOTE(review): `trans` is an uninitialised buffer handed to the
            # helper, presumably as its output storage -- confirm in utils.
            trans = _float_buffer(x, batchsize, 3, 3)
            x = utils.batch_quat_to_rotmat(x, trans)

        return x


class PointNetfeat(nn.Module):
    """PointNet-style feature extractor with optional point/feature STNs.

    Args:
        num_scales: number of scales concatenated along the point axis.
        num_points: number of points per scale.
        use_point_stn: apply the first (input, quaternion) transformer.
        use_feat_stn: apply the second (64-dim feature) transformer.
        sym_op: symmetric pooling operation, ``'max'`` or ``'sum'``.
        get_pointfvals: also return the per-point features before pooling.
        point_tuple: number of points grouped into one input tuple.

    Raises:
        ValueError: if ``sym_op`` is neither ``'max'`` nor ``'sum'``.
    """

    def __init__(self, num_scales=1, num_points=500, use_point_stn=True, use_feat_stn=True,
                 sym_op='max', get_pointfvals=False, point_tuple=1):
        super(PointNetfeat, self).__init__()
        self.num_points = num_points
        self.num_scales = num_scales
        self.use_point_stn = use_point_stn
        self.use_feat_stn = use_feat_stn
        self.sym_op = sym_op
        self.get_pointfvals = get_pointfvals
        self.point_tuple = point_tuple

        if self.use_point_stn:
            # First transformer: works on single 3-D points, so it sees
            # num_points*point_tuple points per scale.
            self.stn1 = STN(num_scales=self.num_scales, num_points=num_points*self.point_tuple,
                            dim=3, sym_op=self.sym_op, quaternion=True)

        if self.use_feat_stn:
            # Second transformer: works on the 64-channel features.
            self.stn2 = STN(num_scales=self.num_scales, num_points=num_points,
                            dim=64, sym_op=self.sym_op)

        # 1x1 convolutions (3*point_tuple -> 64 -> 64).
        self.conv0a = torch.nn.Conv1d(3*self.point_tuple, 64, 1)
        self.conv0b = torch.nn.Conv1d(64, 64, 1)
        self.bn0a = nn.BatchNorm1d(64)
        self.bn0b = nn.BatchNorm1d(64)

        self.conv1 = torch.nn.Conv1d(64, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

        # Multi-scale case: widen the per-point features to 1024*num_scales.
        if self.num_scales > 1:
            self.conv4 = torch.nn.Conv1d(1024, 1024*self.num_scales, 1)
            self.bn4 = nn.BatchNorm1d(1024*self.num_scales)

        if self.sym_op == 'max':
            self.mp1 = torch.nn.MaxPool1d(num_points)
        elif self.sym_op == 'sum':
            # Summation is done with torch.sum in forward(); no module needed.
            self.mp1 = None
        else:
            raise ValueError('Unsupported symmetric operation: %s' % (self.sym_op))

    def forward(self, x):
        """Extract a global feature vector from ``x``.

        Returns:
            Tuple ``(x, trans, trans2, pointfvals)``: the pooled features
            viewed as (-1, 1024*num_scales**2), the point transform (or None),
            the feature transform (or None), and the un-pooled per-point
            features (or None when ``get_pointfvals`` is False).
        """
        # input transform
        if self.use_point_stn:
            # from tuples to list of single points
            x = x.view(x.size(0), 3, -1)
            trans = self.stn1(x)
            # Swap channel and point axes so bmm multiplies each point by trans.
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans)
            x = x.transpose(2, 1)
            x = x.contiguous().view(x.size(0), 3*self.point_tuple, -1)
        else:
            trans = None

        # mlp (64,64)
        x = F.relu(self.bn0a(self.conv0a(x)))
        x = F.relu(self.bn0b(self.conv0b(x)))

        # feature transform
        if self.use_feat_stn:
            trans2 = self.stn2(x)
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans2)
            x = x.transpose(2, 1)
        else:
            trans2 = None

        # mlp (64,128,1024)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))

        # mlp (1024,1024*num_scales)
        if self.num_scales > 1:
            x = self.bn4(self.conv4(F.relu(x)))

        if self.get_pointfvals:
            pointfvals = x
        else:
            pointfvals = None  # so the intermediate result can be forgotten if it is not needed

        # symmetric max operation over all points
        if self.num_scales == 1:
            if self.sym_op == 'max':
                x = self.mp1(x)
            elif self.sym_op == 'sum':
                # Sum over the point axis (dim 2).
                x = torch.sum(x, 2, keepdim=True)
            else:
                raise ValueError('Unsupported symmetric operation: %s' % (self.sym_op))
        else:
            # Each scale contributes 1024*num_scales pooled channels, hence
            # 1024*num_scales**2 in total.
            x_scales = _float_buffer(x, x.size(0), 1024*self.num_scales**2, 1)
            if self.sym_op == 'max':
                for s in range(self.num_scales):
                    x_scales[:, s*self.num_scales*1024:(s+1)*self.num_scales*1024, :] = self.mp1(
                        x[:, :, s*self.num_points:(s+1)*self.num_points])
            elif self.sym_op == 'sum':
                for s in range(self.num_scales):
                    x_scales[:, s*self.num_scales*1024:(s+1)*self.num_scales*1024, :] = torch.sum(
                        x[:, :, s*self.num_points:(s+1)*self.num_points], 2, keepdim=True)
            else:
                raise ValueError('Unsupported symmetric operation: %s' % (self.sym_op))
            x = x_scales

        x = x.view(-1, 1024*self.num_scales**2)

        return x, trans, trans2, pointfvals


class BasicBlock(nn.Module):
    """Two-layer residual block built from Linear or 1x1 Conv1d layers.

    Args:
        in_planes: number of input features / channels.
        planes: number of output features / channels.
        stride: accepted but not used by this block.
        conv: use ``Conv1d(kernel_size=1)`` layers instead of ``Linear``.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, conv=False):
        super(BasicBlock, self).__init__()
        if conv:
            self.l1 = torch.nn.Conv1d(in_planes, planes, 1)
            self.l2 = torch.nn.Conv1d(planes, planes, 1)
        else:
            self.l1 = nn.Linear(in_planes, planes)
            self.l2 = nn.Linear(planes, planes)

        stdv = 0.001  # for working small initialisation
        self.l1.weight.data.uniform_(-stdv, stdv)
        self.l2.weight.data.uniform_(-stdv, stdv)
        self.l1.bias.data.uniform_(-stdv, stdv)
        self.l2.bias.data.uniform_(-stdv, stdv)

        self.bn1 = nn.BatchNorm1d(planes, momentum=0.01)
        # Identity shortcut unless the width changes.
        self.shortcut = nn.Sequential()
        if in_planes != planes:
            # Projection shortcut so the residual matches the output width.
            if conv:
                self.l0 = nn.Conv1d(in_planes, planes, 1)
            else:
                self.l0 = nn.Linear(in_planes, planes)

            self.l0.weight.data.uniform_(-stdv, stdv)
            self.l0.bias.data.uniform_(-stdv, stdv)

            self.shortcut = nn.Sequential(self.l0, nn.BatchNorm1d(planes))
        self.bn2 = nn.BatchNorm1d(planes, momentum=0.01)

    def forward(self, x):
        """Return ``relu(bn2(l2(relu(bn1(l1(x))))) + shortcut(x))``."""
        out = F.relu(self.bn1(self.l1(x)))
        out = self.bn2(self.l2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResSTN(nn.Module):
    """Residual variant of :class:`STN` built from :class:`BasicBlock` layers.

    Takes the same arguments as :class:`STN`; ``sym_op`` is stored but
    pooling is always max.
    """

    def __init__(self, num_scales=1, num_points=500, dim=3, sym_op='max', quaternion=False):
        super(ResSTN, self).__init__()
        self.quaternion = quaternion
        self.dim = dim
        self.sym_op = sym_op
        self.num_scales = num_scales
        self.num_points = num_points

        self.b1 = BasicBlock(self.dim, 64, conv=True)
        self.b2 = BasicBlock(64, 128, conv=True)
        self.b3 = BasicBlock(128, 1024, conv=True)
        self.mp1 = torch.nn.MaxPool1d(num_points)

        self.bfc1 = BasicBlock(1024, 512)
        self.bfc2 = BasicBlock(512, 256)
        if not quaternion:
            self.bfc3 = BasicBlock(256, self.dim*self.dim)
        else:
            self.bfc3 = BasicBlock(256, 4)

        if self.num_scales > 1:
            self.bfc0 = BasicBlock(1024*self.num_scales, 1024)

    def forward(self, x):
        """Predict the transform for ``x``; same contract as ``STN.forward``."""
        batchsize = x.size()[0]
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)

        # symmetric operation over all points
        if self.num_scales == 1:
            x = self.mp1(x)
        else:
            x_scales = _float_buffer(x, x.size(0), 1024*self.num_scales, 1)
            for s in range(self.num_scales):
                x_scales[:, s*1024:(s+1)*1024, :] = self.mp1(
                    x[:, :, s*self.num_points:(s+1)*self.num_points])
            x = x_scales

        x = x.view(-1, 1024*self.num_scales)

        if self.num_scales > 1:
            x = self.bfc0(x)

        x = self.bfc1(x)
        x = self.bfc2(x)
        x = self.bfc3(x)

        if not self.quaternion:
            # Add the flattened identity matrix to the prediction.
            iden = Variable(torch.from_numpy(np.identity(self.dim, 'float32')).clone())
            iden = iden.view(1, self.dim*self.dim).repeat(batchsize, 1)
            if x.is_cuda:
                iden = iden.cuda()
            x = x + iden
            x = x.view(-1, self.dim, self.dim)
        else:
            # add identity quaternion (so the network can output 0 to leave the point cloud identical)
            iden = Variable(torch.FloatTensor([1, 0, 0, 0]))
            if x.is_cuda:
                iden = iden.cuda()
            x = x + iden

            # convert quaternion to rotation matrix
            trans = _float_buffer(x, batchsize, 3, 3)
            x = utils.batch_quat_to_rotmat(x, trans)

        return x


class ResPointNetfeat(nn.Module):
    """Residual variant of :class:`PointNetfeat` using :class:`BasicBlock`.

    Takes the same arguments as :class:`PointNetfeat`.

    Raises:
        ValueError: if ``sym_op`` is neither ``'max'`` nor ``'sum'``.
    """

    def __init__(self, num_scales=1, num_points=500, use_point_stn=True, use_feat_stn=True,
                 sym_op='max', get_pointfvals=False, point_tuple=1):
        super(ResPointNetfeat, self).__init__()
        self.num_points = num_points
        self.num_scales = num_scales
        self.use_point_stn = use_point_stn
        self.use_feat_stn = use_feat_stn
        self.sym_op = sym_op
        self.get_pointfvals = get_pointfvals
        self.point_tuple = point_tuple

        if self.use_point_stn:
            self.stn1 = ResSTN(num_scales=self.num_scales, num_points=num_points*self.point_tuple,
                               dim=3, sym_op=self.sym_op, quaternion=True)

        if self.use_feat_stn:
            self.stn2 = ResSTN(num_scales=self.num_scales, num_points=num_points,
                               dim=64, sym_op=self.sym_op)

        self.b0a = BasicBlock(3*self.point_tuple, 64, conv=True)
        self.b0b = BasicBlock(64, 64, conv=True)

        self.b1 = BasicBlock(64, 64, conv=True)
        self.b2 = BasicBlock(64, 128, conv=True)
        self.b3 = BasicBlock(128, 1024, conv=True)

        if self.num_scales > 1:
            self.b4 = BasicBlock(1024, 1024*self.num_scales, conv=True)

        if self.sym_op == 'max':
            self.mp1 = torch.nn.MaxPool1d(num_points)
        elif self.sym_op == 'sum':
            self.mp1 = None
        else:
            raise ValueError('Unsupported symmetric operation: %s' % (self.sym_op))

    def forward(self, x):
        """Extract features; returns ``(x, trans, trans2, pointfvals)``."""
        # input transform
        if self.use_point_stn:
            # from tuples to list of single points
            x = x.view(x.size(0), 3, -1)
            trans = self.stn1(x)
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans)
            x = x.transpose(2, 1)
            x = x.contiguous().view(x.size(0), 3*self.point_tuple, -1)
        else:
            trans = None

        # mlp (64,64)
        x = self.b0a(x)
        x = self.b0b(x)

        # feature transform
        if self.use_feat_stn:
            trans2 = self.stn2(x)
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans2)
            x = x.transpose(2, 1)
        else:
            trans2 = None

        # mlp (64,128,1024)
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)

        # mlp (1024,1024*num_scales)
        if self.num_scales > 1:
            x = self.b4(x)

        if self.get_pointfvals:
            pointfvals = x
        else:
            pointfvals = None  # so the intermediate result can be forgotten if it is not needed

        # symmetric max operation over all points
        if self.num_scales == 1:
            if self.sym_op == 'max':
                x = self.mp1(x)
            elif self.sym_op == 'sum':
                x = torch.sum(x, 2, keepdim=True)
            else:
                raise ValueError('Unsupported symmetric operation: %s' % (self.sym_op))
        else:
            x_scales = _float_buffer(x, x.size(0), 1024*self.num_scales**2, 1)
            if self.sym_op == 'max':
                for s in range(self.num_scales):
                    x_scales[:, s*self.num_scales*1024:(s+1)*self.num_scales*1024, :] = self.mp1(
                        x[:, :, s*self.num_points:(s+1)*self.num_points])
            elif self.sym_op == 'sum':
                for s in range(self.num_scales):
                    x_scales[:, s*self.num_scales*1024:(s+1)*self.num_scales*1024, :] = torch.sum(
                        x[:, :, s*self.num_points:(s+1)*self.num_points], 2, keepdim=True)
            else:
                raise ValueError('Unsupported symmetric operation: %s' % (self.sym_op))
            x = x_scales

        x = x.view(-1, 1024*self.num_scales**2)

        return x, trans, trans2, pointfvals


class ResPCPNet(nn.Module):
    """Single-scale residual network: ResPointNetfeat plus three BasicBlocks.

    The head maps 1024 -> 512 -> 256 -> ``output_dim``.
    """

    def __init__(self, num_points=500, output_dim=3, use_point_stn=True, use_feat_stn=True,
                 sym_op='max', get_pointfvals=False, point_tuple=1):
        super(ResPCPNet, self).__init__()
        self.num_points = num_points

        self.feat = ResPointNetfeat(
            num_points=num_points,
            num_scales=1,
            use_point_stn=use_point_stn,
            use_feat_stn=use_feat_stn,
            sym_op=sym_op,
            get_pointfvals=get_pointfvals,
            point_tuple=point_tuple)
        self.b1 = BasicBlock(1024, 512)
        self.b2 = BasicBlock(512, 256)
        self.b3 = BasicBlock(256, output_dim)

    def forward(self, x):
        """Return ``(prediction, trans, trans2, pointfvals)``."""
        x, trans, trans2, pointfvals = self.feat(x)

        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)

        return x, trans, trans2, pointfvals


class ResMSPCPNet(nn.Module):
    """Multi-scale residual network: ResPointNetfeat plus four BasicBlocks.

    The head maps 1024*num_scales**2 -> 1024 -> 512 -> 256 -> ``output_dim``.
    """

    def __init__(self, num_scales=2, num_points=500, output_dim=3, use_point_stn=True,
                 use_feat_stn=True, sym_op='max', get_pointfvals=False, point_tuple=1):
        super(ResMSPCPNet, self).__init__()
        self.num_points = num_points

        self.feat = ResPointNetfeat(
            num_points=num_points,
            num_scales=num_scales,
            use_point_stn=use_point_stn,
            use_feat_stn=use_feat_stn,
            sym_op=sym_op,
            get_pointfvals=get_pointfvals,
            point_tuple=point_tuple)
        self.b0 = BasicBlock(1024*num_scales**2, 1024)
        self.b1 = BasicBlock(1024, 512)
        self.b2 = BasicBlock(512, 256)
        self.b3 = BasicBlock(256, output_dim)

    def forward(self, x):
        """Return ``(prediction, trans, trans2, pointfvals)``."""
        x, trans, trans2, pointfvals = self.feat(x)

        x = self.b0(x)
        x = self.b1(x)
        x = self.b2(x)
        x = self.b3(x)

        return x, trans, trans2, pointfvals


class PCPNet(nn.Module):
    """Single-scale PCPNet: PointNetfeat followed by a fully connected head.

    The head is 1024 -> 512 -> 256 -> ``output_dim`` with batch norm, ReLU and
    dropout (p=0.3) after each of the two hidden layers.
    """

    def __init__(self, num_points=500, output_dim=3, use_point_stn=True, use_feat_stn=True,
                 sym_op='max', get_pointfvals=False, point_tuple=1):
        super(PCPNet, self).__init__()
        self.num_points = num_points

        # Feature extractor runs first; the FC head consumes its output.
        self.feat = PointNetfeat(
            num_points=num_points,
            num_scales=1,
            use_point_stn=use_point_stn,
            use_feat_stn=use_feat_stn,
            sym_op=sym_op,
            get_pointfvals=get_pointfvals,
            point_tuple=point_tuple)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, output_dim)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.do1 = nn.Dropout(p=0.3)
        self.do2 = nn.Dropout(p=0.3)

    def forward(self, x):
        """Return ``(prediction, trans, trans2, pointfvals)``."""
        x, trans, trans2, pointfvals = self.feat(x)

        x = F.relu(self.bn1(self.fc1(x)))
        x = self.do1(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.do2(x)
        x = self.fc3(x)

        return x, trans, trans2, pointfvals


class MSPCPNet(nn.Module):
    """Multi-scale PCPNet: PointNetfeat followed by a fully connected head.

    The head is 1024*num_scales**2 -> 1024 -> 512 -> 256 -> ``output_dim``;
    dropout (p=0.3) follows the 512- and 256-wide layers.
    """

    def __init__(self, num_scales=2, num_points=500, output_dim=3, use_point_stn=True,
                 use_feat_stn=True, sym_op='max', get_pointfvals=False, point_tuple=1):
        super(MSPCPNet, self).__init__()
        self.num_points = num_points

        self.feat = PointNetfeat(
            num_points=num_points,
            num_scales=num_scales,
            use_point_stn=use_point_stn,
            use_feat_stn=use_feat_stn,
            sym_op=sym_op,
            get_pointfvals=get_pointfvals,
            point_tuple=point_tuple)
        self.fc0 = nn.Linear(1024*num_scales**2, 1024)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, output_dim)
        self.bn0 = nn.BatchNorm1d(1024)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.do1 = nn.Dropout(p=0.3)
        self.do2 = nn.Dropout(p=0.3)

    def forward(self, x):
        """Return ``(prediction, trans, trans2, pointfvals)``."""
        x, trans, trans2, pointfvals = self.feat(x)

        x = F.relu(self.bn0(self.fc0(x)))
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.do1(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.do2(x)
        x = self.fc3(x)

        return x, trans, trans2, pointfvals
转载地址:https://blog.csdn.net/weixin_45854106/article/details/108071200 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
不错!
[***.144.177.141]2024年04月01日 23时20分12秒
关于作者
喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
一个 Java 方法,最多能定义多少参数?
2019-04-27
我司用了 6 年的 Redis 分布式限流器,很牛逼了!
2019-04-27
10 分钟快速上手 Shiro 新手教程
2019-04-27
玩大发了,Tomcat 8.5 升级有坑…
2019-04-27
我把 Spring Boot Banner 换成了美女背景后……
2019-04-27
常用的 Git 命令,给你准备好了!
2019-04-27
Java虚拟机最多支持多少个线程?
2019-04-27
Linux 与 Unix 到底有啥区别和联系?
2019-04-27
带着问题学 Kubernetes 架构!
2019-04-27
Redis VS Memcached,到底怎么选?
2019-04-27
最新!Dubbo 远程代码执行漏洞通告,速度升级
2019-04-27
你还在使用 try-catch-finally 关闭资源?
2019-04-27
Tomcat 又爆出高危漏洞!!Tomcat 8.5 ~ 10 中招…
2019-04-27
分布式事务不理解?一次给你讲清楚!
2019-04-27
存储过程用还是不用?我有话说!
2019-04-27
牛逼哄哄的布隆过滤器,到底有什么用?
2019-04-27
Spring Boot & Restful API 构建实战!
2019-04-27
10 个牛逼的单行代码编程技巧,你会用吗?
2019-04-27
因为一个跨域请求,我差点丢了饭碗!
2019-04-27
最流行的 RESTful API 要怎么设计?
2019-04-27