pytorch打印自定义网络的每层的名称
发布日期:2021-07-01 04:36:55 浏览次数:2 分类:技术文章

本文共 1022 字,大约阅读时间需要 3 分钟。

pytorch打印自定义网络的每层的名称

import torchfrom torchvision import modelsfrom torchsummary import summaryfrom resnext_MulTask_clothes import resnext50_elasticdata_class=[8, 7]device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# vgg = models.vgg16().to(device)model = resnext50_elastic(num_classes=data_class) # 原模型model = torch.nn.DataParallel(model).cuda() # 并行处理# 已训练好的模型的pth文件checkpoint = torch.load('06-resnext50_elastic_checkpoint.pth.tar')model.load_state_dict(checkpoint['state_dict'], strict=False) # 参数加载summary(model, (3, 224, 224))

参考连接:https://www.jianshu.com/p/97c626d33924

另:

打印resnet152网络的每层的名称

import torchfrom torchvision import modelsfrom torchsummary import summaryfrom resnet_pretrained import resnet152device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = resnet152() # 原模型model = torch.nn.DataParallel(model).cuda() # 并行处理# 已训练好的模型的pth文件checkpoint = torch.load('resnet152-b121ed2d.pth')model.load_state_dict(checkpoint, strict=False) # 参数加载summary(model, (3, 224, 224))

转载地址:https://mymuli.blog.csdn.net/article/details/100834714 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:pytorch里面报错raise TypeError('tensor is not a torch image.')
下一篇:python将嵌套数组转为单层数组

发表评论

最新留言

做的很好,不错不错
[***.243.131.199]2024年04月13日 05时39分09秒