Point cleannet训练代码解析
发布日期:2022-02-28 07:22:43 浏览次数:35 分类:技术文章

本文共 16218 字,大约阅读时间需要 54 分钟。

在这里插入代码片from __future__ import print_functionimport argparseimport osimport sysimport randomimport mathimport shutilimport torchimport torch.nn.parallelimport torch.optim as optimimport torch.optim.lr_scheduler as lr_schedulerimport torch.utils.datafrom torch.autograd import Variablefrom tensorboardX import SummaryWriterfrom dataset import PointcloudPatchDataset, RandomPointcloudPatchSampler, SequentialShapeRandomPointcloudPatchSamplerfrom pcpnet import ResPCPNetdef parse_arguments():    parser = argparse.ArgumentParser()    # naming / file handling    parser.add_argument(        '--name', type=str, default='PoinCleanNetOutliers', help='training run name')    parser.add_argument(        '--desc', type=str, default='My training run for single-scale normal estimation.', help='description')    parser.add_argument('--indir', type=str, default='../data/pointCleanNetOutliersTrainingSet',                        help='input folder (point clouds)')    parser.add_argument('--outdir', type=str, default='../models',                        help='output folder (trained models)')    parser.add_argument('--logdir', type=str,                        default='./logs', help='training log folder')    parser.add_argument('--trainset', type=str,                        default='trainingset.txt', help='training set file name')    parser.add_argument('--testset', type=str,                        default='validationset.txt', help='test set file name')    parser.add_argument('--saveinterval', type=int,                        default='10', help='save model each n epochs')    parser.add_argument('--refine', type=str, default='',                        help='refine model at this path')    # training parameters    parser.add_argument('--nepoch', type=int, default=2000,                        help='number of epochs to train for')    parser.add_argument('--batchSize', type=int,                        default=64, help='input batch size')    parser.add_argument('--patch_radius', type=float, default=[                        0.05], nargs='+', help='patch radius in multiples of the shape\'s bounding box diagonal, multiple values for multi-scale.')    parser.add_argument('--patch_center', type=str, default='point', help='center patch at...\n'                        'point: center point\n'                        'mean: patch mean')    parser.add_argument('--patch_point_count_std', type=float, default=0,                        help='standard deviation of the number of points in a patch')    parser.add_argument('--patches_per_shape', type=int, default=100,                        help='number of patches sampled from each shape in an epoch')    parser.add_argument('--workers', type=int, default=1,                        help='number of data loading workers - 0 means same thread as main execution')    parser.add_argument('--cache_capacity', type=int, default=600,                        help='Max. number of dataset elements (usually shapes) to hold in the cache at the same time.')    parser.add_argument('--seed', type=int,                        default=3627473, help='manual seed')    parser.add_argument('--training_order', type=str, default='random', help='order in which the training patches are presented:\n'                        'random: fully random over the entire dataset (the set of all patches is permuted)\n'                        'random_shape_consecutive: random over the entire dataset, but patches of a shape remain consecutive (shapes and patches inside a shape are permuted)')    parser.add_argument('--identical_epochs', type=int, default=False,                        help='use same patches in each epoch, mainly for debugging')    parser.add_argument('--lr', type=float, default=0.0001,                        help='learning rate')    parser.add_argument('--momentum', type=float, default=0.9,                        help='gradient descent momentum')    parser.add_argument('--use_pca', type=int, default=False,                        help='Give both inputs and ground truth in local PCA coordinate frame')    # model hyperparameters    parser.add_argument('--outputs', type=str, nargs='+', default=['outliers'], help='outputs of the network')    parser.add_argument('--use_point_stn', type=int,                        default=True, help='use point spatial transformer')    parser.add_argument('--use_feat_stn', type=int,                        default=True, help='use feature spatial transformer')    parser.add_argument('--sym_op', type=str, default='max',                        help='symmetry operation')    parser.add_argument('--point_tuple', type=int, default=1,                        help='use n-tuples of points as input instead of single points')    parser.add_argument('--points_per_patch', type=int,                        default=500, help='max. number of points per patch')    return parser.parse_args()def check_path_existance(log_dirname, model_filename, opt):    if os.path.exists(log_dirname) or os.path.exists(model_filename):        if os.path.exists(log_dirname):            shutil.rmtree(os.path.join(opt.logdir, opt.name))def get_output_format(opt):    # get indices in targets and predictions corresponding to each output    target_features = []    output_target_ind = []    output_pred_ind = []    output_loss_weight = []    pred_dim = 0    for o in opt.outputs:        if o in ['unoriented_normals', 'oriented_normals']:            if 'normal' not in target_features:                target_features.append('normal')            output_target_ind.append(target_features.index('normal'))            output_pred_ind.append(pred_dim)            output_loss_weight.append(1.0)            pred_dim += 3        elif o in ['max_curvature', 'min_curvature']:            if o not in target_features:                target_features.append(o)            output_target_ind.append(target_features.index(o))            output_pred_ind.append(pred_dim)            if o == 'max_curvature':                output_loss_weight.append(0.7)            else:                output_loss_weight.append(0.3)            pred_dim += 1        elif o in ['clean_points']:            target_features.append(o)            output_target_ind.append(target_features.index(o))            output_pred_ind.append(pred_dim)            output_loss_weight.append(1.0)            pred_dim += 3        elif o in ['outliers']:            target_features.append(o)            output_target_ind.append(target_features.index(o))            output_pred_ind.append(pred_dim)            output_loss_weight.append(1.0)            pred_dim += 1        else:            raise ValueError('Unknown output: %s' % (o))    if pred_dim <= 0:        raise ValueError('Prediction is empty for the given outputs.')    return target_features, output_target_ind, output_pred_ind,  output_loss_weight, pred_dimdef get_data(target_features, opt, train=True):    # create train and test dataset loaders    if train:        shapes_list_file = opt.trainset    else:        shapes_list_file = opt.testset    dataset = PointcloudPatchDataset(        root=opt.indir,        shapes_list_file=shapes_list_file,        patch_radius=opt.patch_radius,        points_per_patch=opt.points_per_patch,        patch_features=target_features,        point_count_std=opt.patch_point_count_std,        seed=opt.seed,        identical_epochs=opt.identical_epochs,        use_pca=opt.use_pca,        center=opt.patch_center,        point_tuple=opt.point_tuple,        cache_capacity=opt.cache_capacity)    print('training_order ', opt.training_order)    if opt.training_order == 'random':        datasampler = RandomPointcloudPatchSampler(            dataset,            patches_per_shape=opt.patches_per_shape,            seed=opt.seed,            identical_epochs=opt.identical_epochs)    elif opt.training_order == 'random_shape_consecutive':        datasampler = SequentialShapeRandomPointcloudPatchSampler(            dataset,            patches_per_shape=opt.patches_per_shape,            seed=opt.seed,            identical_epochs=opt.identical_epochs)    else:        raise ValueError('Unknown training order: %s' % (opt.training_order))    dataloader = torch.utils.data.DataLoader(        dataset,        sampler=datasampler,        batch_size=opt.batchSize,        num_workers=int(opt.workers))    return dataloader, datasampler, dataseth'h'h'h'h'h'h'h'h'h'h'h'h'h'h'h'h'h'hhhhhh'h'h'h'h'h'h'h'h'h'h'h'h'h'h'h'h'h'h    def green(x): return '\033[92m' + x + '\033[0m'    def blue(x): return '\033[94m' + x + '\033[0m'    log_dirname = os.path.join(opt.logdir, opt.name)    params_filename = os.path.join(opt.outdir, '%s_params.pth' % (opt.name))    model_filename = os.path.join(opt.outdir, '%s_model.pth' % (opt.name))    desc_filename = os.path.join(opt.outdir, '%s_description.txt' % (opt.name))    check_path_existance(log_dirname, model_filename, opt)    target_features, output_target_ind, output_pred_ind, output_loss_weight, n_predicted_features = get_output_format(opt)    pcpnet = ResPCPNet(        num_points=opt.points_per_patch,        output_dim=n_predicted_features,        use_point_stn=opt.use_point_stn,        use_feat_stn=opt.use_feat_stn,        sym_op=opt.sym_op,        point_tuple=opt.point_tuple)    if opt.refine != '':        pcpnet.load_state_dict(torch.load(opt.refine))    if opt.seed < 0:        opt.seed = random.randint(1, 10000)生成一个1-10000的随机整数    print("Random Seed: %d" % (opt.seed))    random.seed(opt.seed)    torch.manual_seed(opt.seed)为CPU中设置种子,生成随机数    # create train and test dataset loaders    train_dataloader, train_datasampler, train_dataset = get_data(target_features, opt, train=True)获取数据(训练集)    test_dataloader, test_datasampler, test_dataset = get_data(target_features, opt, train=False)测试集    # keep the exact training shape names for later reference    opt.train_shapes = train_dataset.shape_names训练情况    opt.test_shapes = test_dataset.shape_names测试情况    print('training set: %d patches (in %d batches) - test set: %d patches (in %d batches)' %          (len(train_datasampler), len(train_dataloader), len(test_datasampler), len(test_dataloader)))    try:        os.makedirs(opt.outdir)    except OSError:        pass    train_writer = SummaryWriter(os.path.join(log_dirname, 'train'))    test_writer = SummaryWriter(os.path.join(log_dirname, 'test'))    optimizer = optim.SGD(pcpnet.parameters(), lr=opt.lr,                          momentum=opt.momentum)优化器,初始化学习率    # milestones in number of optimizer iterations    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[], gamma=0.1)学习速率调整,每经历一次学习率×0.1    pcpnet.cuda()    total_train_batches = len(train_dataloader)    total_test_batches = len(test_dataloader)    # save parameters    torch.save(opt, params_filename)    # save description    with open(desc_filename, 'w+') as text_file:        print(opt.desc, file=text_file)    criterion = torch.nn.L1Loss()定义L1LOSS    for epoch in range(opt.nepoch):        current_train_batch_index = -1        train_completion = 0.0        train_batches = enumerate(train_dataloader, 0)枚举训练patch        current_test_batch_index = -1        test_completion = 0.0        test_batches = enumerate(test_dataloader, 0)枚举测试patch        for current_train_batch_index, data in train_batches:            # update learning rate            scheduler.step(epoch * total_train_batches + current_train_batch_index)更新学习率(在epoch这个循环里)            # set to training mode            pcpnet.train()            # get trainingset batch, convert to variables and upload to GPU            points = data[0]            target = data[1:-1]            points = Variable(points)            points = points.transpose(2, 1)            points = points.cuda()            target = tuple(Variable(t) for t in target)            target = tuple(t.cuda() for t in target)            # zero gradients            optimizer.zero_grad()梯度初始化为0            # forward pass            pred, trans, _, _ = pcpnet(points)前向传播            loss= compute_loss(pred=pred, target=target,outputs=opt.outputs,output_pred_ind=output_pred_ind,                output_target_ind=output_target_ind,output_loss_weight=output_loss_weight,                patch_rot=trans if opt.use_point_stn else None, criterion=criterion)计算损失函数            # backpropagate through entire network to compute gradients of loss w.r.t. parameters            loss.backward()反向传播            # parameter optimization step            optimizer.step()更新参数            train_completion = (current_train_batch_index + 1) / total_train_batches 训练完成度--当前索引+1/总patch            # print info and update log file            print('[%s %d/%d: %d/%d] %s loss: %f' % (opt.name, epoch, opt.nepoch, current_train_batch_index,                                                  total_train_batches - 1, green('train'), loss.item()))            train_writer.add_scalar('loss', loss.item(),                                    (epoch + train_completion) * total_train_batches * opt.batchSize)            while test_completion <= train_completion and current_test_batch_index + 1 < total_test_batches:                # set to evaluation mode                pcpnet.eval()                current_test_batch_index, data = next(test_batches)                # get testset batch, convert to variables and upload to GPU                # volatile means that autograd is turned off for everything that depends on the volatile variable                # since we dont need autograd for inference (only for training)                points = data[0]                target = data[1:-1]                points = Variable(points, volatile=True)                points = points.transpose(2, 1)                points = points.cuda()                target = tuple(Variable(t, volatile=True) for t in target)                target = tuple(t.cuda() for t in target)                # forward pass                pred, trans, _, _ = pcpnet(points)                loss = compute_loss(                    pred=pred, target=target,                    outputs=opt.outputs,                    output_pred_ind=output_pred_ind,                    output_target_ind=output_target_ind,                    output_loss_weight=output_loss_weight,                    patch_rot=trans if opt.use_point_stn else None,                    criterion=criterion)                test_completion = (current_test_batch_index + 1) / total_test_batches                # print info and update log file                print('[%s %d: %d/%d] %s loss: %f' % (opt.name, epoch,                                                      current_train_batch_index, total_train_batches - 1, blue('test'), loss.item()))                # print('min normal len: %f' % (pred.data.norm(2,1).min()))                test_writer.add_scalar(                    'loss', loss.item(), (epoch + test_completion) *total_train_batches * opt.batchSize)        # save model, overwriting the old model        if epoch % opt.saveinterval == 0 or epoch == opt.nepoch - 1:            torch.save(pcpnet.state_dict(), model_filename)        # save model in a separate file in epochs 0,5,10,50,100,500,1000, ...        if epoch % (5 * 10**math.floor(math.log10(max(2, epoch - 1)))) == 0 or epoch % 100 == 0 or epoch == opt.nepoch - 1:            torch.save(pcpnet.state_dict(), os.path.join(                opt.outdir, '%s_model_%d.pth' % (opt.name, epoch)))def compute_accuracy(pred, target, l1_loss):    pred[pred >0.5]= 1    pred[pred<=0.5] = 0    return l1_loss(pred, target)def compute_outliers_loss(pred, output_pred_ind, output_index,  target, output_target_ind, output_loss_weight, criterion):计算离群LOSS    o_pred = pred[:, output_pred_ind[output_index]:output_pred_ind[output_index] + 1]    o_target = target[output_target_ind[output_index]]    o_target = o_target.cuda()    loss = criterion(o_pred[:,0], o_target)    return lossdef compute_loss(pred, target, outputs, output_pred_ind, output_target_ind, output_loss_weight, patch_rot, criterion = None, cleaning=False):计算损失函数    loss = 0    for output_index, output in enumerate(outputs):枚举输出        loss += compute_outliers_loss(pred, output_pred_ind, output_index,  target, output_target_ind, output_loss_weight, criterion)    return lossif __name__ == '__main__':    train_opt = parse_arguments()    train_pcpnet(train_opt)

转载地址:https://blog.csdn.net/weixin_45854106/article/details/108216816 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:Leetcode 1046.最后一块石头的重量(JS版)
下一篇:PointAugment:点云分类的自动增强框架

发表评论

最新留言

初次前来,多多关照!
[***.217.46.12]2024年04月09日 10时27分40秒