from torch.utils.data import DataLoader DataLoader类
发布日期:2021-06-29 11:44:33 浏览次数:3 分类:技术文章

本文共 5426 字,大约阅读时间需要 18 分钟。

from torch.utils.data import DataLoader

# Instantiate a loader over `dataset` (a Dataset object the reader must supply):
# batches of 5 samples, reshuffled every epoch, loaded by 2 worker subprocesses.
dataloader = DataLoader(dataset, batch_size=5, shuffle=True, num_workers=2)

 参数dataset是要从中加载数据的数据集对象(Dataset),个人认为官方文档对这一点的描述比较笼统

batch_size默认是1,表示每个batch一次性读取多少个样本(例如多少张图片),下文称为每批的采样个数

shuffle默认是False,即不打乱顺序;设为True时,每个epoch都会重新打乱数据

sampler定义从数据集中抽取样本的策略。如果指定了sampler,那么shuffle必须是False

batch_sampler与sampler类似,但每次返回的是一个batch的索引,与batch_size、shuffle、sampler和drop_last是互斥的

num_workers 表示用多少个子进程加载数据,默认是0,表示只在主进程中加载数据

timeout是从worker收集一个batch的超时时间,必须是非负的数值,默认是0

drop_last这个参数决定当数据集大小不能被batch_size整除时,是否丢弃最后一个不完整的batch.举例:有图片13张,batch_size=4,那么整除得3余1;如果该参数值为False(默认),余下的1张单独作为最后一个batch,总共batch=4;如果为True,则丢弃余下的1张,总共batch=3.

其他参数就不解释了

对DataLoader对象调用len()函数会调用它的魔法方法__len__,返回len(self.batch_sampler),得到的是总的batch数目

class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: 1).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: False).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, ``shuffle`` must be False.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with batch_size, shuffle,
            sampler, and drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: 0)
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: False)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: 0)
        worker_init_fn (callable, optional): If not None, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: None)

    .. note:: By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use ``torch.initial_seed()`` to access the PyTorch seed for each
              worker in :attr:`worker_init_fn`, and use it to set other seeds
              before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.
    """

    # Class-level default so that ``__setattr__`` can run during ``__init__``
    # before the instance-level flag is set to True on its last line.
    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None,
                 num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        """Store the options, validate them and build the batch sampler.

        Raises:
            ValueError: if ``timeout`` or ``num_workers`` is negative, if
                ``batch_sampler`` is combined with ``batch_size``, ``shuffle``,
                ``sampler`` or ``drop_last``, or if ``sampler`` is combined
                with ``shuffle``.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
        self.drop_last = drop_last
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if batch_sampler is not None:
            # A user-supplied batch sampler already decides batch size, order
            # and remainder handling, so any non-default value for the
            # overlapping options is rejected.
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if batch_sampler is None:
            if sampler is None:
                # No explicit sampler: pick random or sequential order
                # according to ``shuffle``.
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        # From here on ``__setattr__`` refuses to change the frozen attributes.
        self.__initialized = True

    def __setattr__(self, attr, val):
        """Set an attribute, rejecting frozen ones once ``__init__`` has finished.

        ``batch_size``, ``sampler`` and ``drop_last`` are only consumed while
        the batch sampler is built in ``__init__``, so they are locked afterwards.

        Raises:
            ValueError: if a frozen attribute is assigned after initialization.
        """
        if self.__initialized and attr in ('batch_size', 'sampler', 'drop_last'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(DataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        """Return a new ``_DataLoaderIter`` built from this loader."""
        # NOTE(review): ``_DataLoaderIter`` is not defined in this snippet; it
        # is presumably provided by the surrounding module -- confirm.
        return _DataLoaderIter(self)

    def __len__(self):
        """Return the number of batches, i.e. ``len(self.batch_sampler)``."""
        return len(self.batch_sampler)

 

转载地址:https://blog.csdn.net/zz2230633069/article/details/82684425 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:scipy.misc.imread函数,读取图片
下一篇:关于类型为numpy,TensorFlow.tensor,torch.tensor的shape变化以及相互转化

发表评论

最新留言

路过,博主的博客真漂亮。。
[***.116.15.85]2024年04月30日 20时50分16秒