pytorch 训练数据以及测试 全部代码(2)
发布日期:2021-06-29 11:44:35
浏览次数:3
分类:技术文章
本文共 4745 字,大约阅读时间需要 15 分钟。
p={‘trainBatch’:6, 'nAveGrad':1, 'lr':1e-07, 'wd':0.0005, 'momentum':0.9,'epoch_size':10, 'optimizer':'SGD()'}最后一个optimizer的值是很长的字符串就不全部写出来了。这个字典长度是7。
其中的net 和criterion在稍后来进行讲解
if resume_epoch==0,那么从头开始训练 training from scratch;否则权重的初始化时一个已经训练好的模型,使用net.load_state_dict函数,这个函数是在torch.nn.Module类里面定义的一个函数。
def load_state_dict(self, state_dict, strict=True): r"""Copies parameters and buffers from :attr:`state_dict` into this module and its descendants. If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Arguments: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` """ missing_keys = [] unexpected_keys = [] error_msgs = [] # copy state_dict so _load_from_state_dict can modify it metadata = getattr(state_dict, '_metadata', None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata def load(module, prefix=''): module._load_from_state_dict( state_dict, prefix, strict, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + '.') load(self)
而里面的torch.load函数定义如下.map_location参数有三种形式:函数,字符串,字典
def load(f, map_location=None, pickle_module=pickle): """Loads an object saved with :func:`torch.save` from a file. :meth:`torch.load` uses Python's unpickling facilities but treats storages, which underlie tensors, specially. They are first deserialized on the CPU and are then moved to the device they were saved from. If this fails (e.g. because the run time system doesn't have certain devices), an exception is raised. However, storages can be dynamically remapped to an alternative set of devices using the `map_location` argument. If `map_location` is a callable, it will be called once for each serialized storage with two arguments: storage and location. The storage argument will be the initial deserialization of the storage, residing on the CPU. Each serialized storage has a location tag associated with it which identifies the device it was saved from, and this tag is the second argument passed to map_location. The builtin location tags are `'cpu'` for CPU tensors and `'cuda:device_id'` (e.g. `'cuda:2'`) for CUDA tensors. `map_location` should return either None or a storage. If `map_location` returns a storage, it will be used as the final deserialized object, already moved to the right device. Otherwise, :math:`torch.load` will fall back to the default behavior, as if `map_location` wasn't specified. If `map_location` is a string, it should be a device tag, where all tensors should be loaded. Otherwise, if `map_location` is a dict, it will be used to remap location tags appearing in the file (keys), to ones that specify where to put the storages (values). User extensions can register their own location tags and tagging and deserialization methods using `register_package`. Args: f: a file-like object (has to implement read, readline, tell, and seek), or a string containing a file name map_location: a function, string or a dict specifying how to remap storage locations pickle_module: module used for unpickling metadata and objects (has to match the pickle_module used to serialize file) Example: >>> torch.load('tensors.pt') # Load all tensors onto the CPU >>> torch.load('tensors.pt', map_location='cpu') # Load all tensors onto the CPU, using a function >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage) # Load all tensors onto GPU 1 >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) # Map tensors from GPU 1 to GPU 0 >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'}) # Load tensor from io.BytesIO object >>> with open('tensor.pt') as f: buffer = io.BytesIO(f.read()) >>> torch.load(buffer) """
设置使用GPU,这里是
torch.cuda.set_device(device=0) 告诉编码器cuda使用gpu0号
net.cuda() 将模型放在gpu0号上面
关于writer = SummaryWriter(log_dir=log_dir)这个函数在后面会讲解
num_img_tr = len(trainloader)# 1764num_img_ts = len(testloader)# 242 这是batch数目
转载地址:https://blog.csdn.net/zz2230633069/article/details/82868093 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
不错!
[***.144.177.141]2024年04月04日 02时28分55秒
关于作者
喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
10年老兵!从大学毕业生到嵌入式系统工程师的修炼之道……
2019-04-29
如何才能学好单片机?
2019-04-29
一根网线有这么多“花样”,你知道吗?
2019-04-29
雷军1994年写的诗一样的代码,我把它运行起来了!
2019-04-29
2020年大学生电子设计竞赛,B题,单相在线式不间断电源,详细技术方案!
2019-04-29
大佬终于把鸿蒙OS讲明白了,收藏了!
2019-04-29
C语言指针,这可能是史上最干最全的讲解啦(附代码)!!!
2019-04-29
国内大陆有哪些芯片公司处于世界前10?一起看看!
2019-04-29
单精度、双精度、多精度和混合精度计算的区别是什么?
2019-04-29
中国35位“大国工匠”榜单出炉!西工大、西电合计占半壁江山!清华仅1人!...
2019-04-29
知乎热议:嵌入式开发中C++好用吗?
2019-04-29
2020,Python 已死?
2019-04-29
漫画:程序员相亲?哈哈哈哈哈哈
2019-04-29
30种EMC标准电路分享,再不收藏就晚了!
2019-04-29
这100道Linux常见面试题,看看你会多少?
2019-04-29
十年硬件老司机,结合实际案例,带你探索单片机低功耗设计!
2019-04-29
“2020年嵌入式软件秋招经验和对嵌入式软件未来的一点思考”
2019-04-29
嵌入式的坑在哪方面?
2019-04-29
三种常见嵌入式设备通信协议
2019-04-29
硬核,这个充电宝居然烧煤气!
2019-04-29