torch.utils.data
class torch.utils.data.Dataset(*args, **kwds)[source]
An abstract class representing a Dataset.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also optionally overwrite __len__(), which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.
Note
DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
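For illustration, a minimal map-style dataset (a hypothetical toy class, not part of torch.utils.data) might implement both methods like this:

import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Toy map-style dataset returning (x, x**2) pairs."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        # Fetch the sample for the given integer key.
        x = torch.tensor(float(index))
        return x, x ** 2

    def __len__(self):
        # Lets Sampler implementations and DataLoader infer the dataset size.
        return self.n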
class torch.utils.data.TensorDataset(*tensors: torch.Tensor)[source]
Dataset wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.
- Parameters
*tensors (Tensor) – tensors that have the same size in the first dimension.
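For example, wrapping a feature tensor and a label tensor that share the same first dimension (the shapes below are arbitrary):

>>> import torch
>>> from torch.utils.data import TensorDataset
>>> features = torch.randn(100, 3)
>>> labels = torch.randint(0, 2, (100,))
>>> dataset = TensorDataset(features, labels)
>>> len(dataset)
100
>>> sample = dataset[0]   # tuple (features[0], labels[0])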
class torch.utils.data.ConcatDataset(datasets: Iterable[torch.utils.data.dataset.Dataset])[source]
Dataset as a concatenation of multiple datasets.
This class is useful to assemble different existing datasets.
- Parameters
 datasets (sequence) – List of datasets to be concatenated
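For instance, two map-style datasets can be chained into a single one (the tensors below are placeholders):

>>> import torch
>>> from torch.utils.data import TensorDataset, ConcatDataset
>>> a = TensorDataset(torch.zeros(10, 2))
>>> b = TensorDataset(torch.ones(5, 2))
>>> combined = ConcatDataset([a, b])
>>> len(combined)
15
>>> combined[12]   # index 12 falls into the second dataset
(tensor([1., 1.]),)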
class torch.utils.data.Subset(dataset: torch.utils.data.dataset.Dataset[T_co], indices: Sequence[int])[source]
Subset of a dataset at specified indices.
- Parameters
 dataset (Dataset) – The whole Dataset
indices (sequence) – Indices in the whole set selected for subset
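For example, keeping every other sample of a dataset (the index list is arbitrary):

>>> import torch
>>> from torch.utils.data import TensorDataset, Subset
>>> full = TensorDataset(torch.arange(10))
>>> evens = Subset(full, indices=[0, 2, 4, 6, 8])
>>> len(evens)
5
>>> evens[1]   # maps to full[2]
(tensor(2),)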
class torch.utils.data.DataLoader(dataset: torch.utils.data.dataset.Dataset[T_co], batch_size: Optional[int] = 1, shuffle: bool = False, sampler: Optional[torch.utils.data.sampler.Sampler[int]] = None, batch_sampler: Optional[torch.utils.data.sampler.Sampler[Sequence[int]]] = None, num_workers: int = 0, collate_fn: Optional[Callable[List[T], Any]] = None, pin_memory: bool = False, drop_last: bool = False, timeout: float = 0, worker_init_fn: Optional[Callable[int, None]] = None, multiprocessing_context=None, generator=None, *, prefetch_factor: int = 2, persistent_workers: bool = False)[source]
Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
The DataLoader supports both map-style and iterable-style datasets with single- or multi-process loading, customizable loading order, and optional automatic batching (collation) and memory pinning.
See the torch.utils.data documentation page for more details.
- Parameters
dataset (Dataset) – dataset from which to load the data.
batch_size (int, optional) – how many samples per batch to load (default: 1).
shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when batched loading is enabled for a map-style dataset.
pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
drop_last (bool, optional) – set to True to drop the last incomplete batch if the dataset size is not divisible by the batch size. If False and the size of the dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
prefetch_factor (int, optional, keyword-only arg) – Number of samples loaded in advance by each worker. 2 means there will be a total of 2 * num_workers samples prefetched across all workers. (default: 2)
persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after the dataset has been consumed once. This keeps the workers' Dataset instances alive. (default: False)
Warning
If the spawn start method is used, worker_init_fn cannot be an unpicklable object, e.g., a lambda function. See the multiprocessing best practices notes for more details related to multiprocessing in PyTorch.
Warning
The len(dataloader) heuristic is based on the length of the sampler used. When dataset is an IterableDataset, it instead returns an estimate based on len(dataset) / batch_size, with proper rounding depending on drop_last, regardless of the multi-process loading configuration. This represents the best guess PyTorch can make, because PyTorch trusts the user's dataset code to handle multi-process loading correctly and avoid duplicate data.
However, if sharding results in multiple workers having incomplete last batches, this estimate can still be inaccurate, because (1) an otherwise complete batch can be broken into multiple ones and (2) more than one batch worth of samples can be dropped when drop_last is set. Unfortunately, PyTorch cannot detect such cases in general.
See `Dataset Types`_ for more details on these two types of datasets and how IterableDataset interacts with `Multi-process data loading`_.
Warning
See the reproducibility notes, the FAQ entry "My data loader workers return identical random numbers", and the data-loading randomness notes for questions related to random seeds.
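A typical single-process setup, sketched here with a placeholder TensorDataset and arbitrary batch_size/shuffle settings, ties the pieces above together:

import torch
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(torch.randn(1000, 20), torch.randint(0, 2, (1000,)))
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

for features, labels in loader:
    # Each iteration yields a collated mini-batch: features has shape (32, 20)
    # and labels has shape (32,), except for a possibly smaller final batch.
    pass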
torch.utils.data.random_split(dataset: torch.utils.data.dataset.Dataset[T], lengths: Sequence[int], generator: Optional[torch._C.Generator] = <torch._C.Generator object>) → List[torch.utils.data.dataset.Subset[T]][source]
Randomly split a dataset into non-overlapping new datasets of given lengths. Optionally fix the generator for reproducible results, e.g.:
>>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
- Parameters
 dataset (Dataset) – Dataset to be split
lengths (sequence) – lengths of splits to be produced
generator (Generator) – Generator used for the random permutation.
class torch.utils.data.Sampler(data_source: Optional[collections.abc.Sized])[source]
Base class for all Samplers.
Every Sampler subclass has to provide an __iter__() method, providing a way to iterate over indices of dataset elements, and a __len__() method that returns the length of the returned iterator.
Note
The __len__() method isn’t strictly required by DataLoader, but is expected in any calculation involving the length of a DataLoader.
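A custom sampler only needs those two methods. The reverse-order sampler below is a hypothetical illustration, not part of torch.utils.data:

from torch.utils.data import Sampler

class ReverseSampler(Sampler):
    """Yields the indices of a sized data source in reverse order."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Iterate from the last index down to 0.
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)

For example, list(ReverseSampler(range(4))) yields [3, 2, 1, 0].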
class torch.utils.data.SequentialSampler(data_source)[source]
Samples elements sequentially, always in the same order.
- Parameters
 data_source (Dataset) – dataset to sample from
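For example, iterating over a data source of length five yields its indices in order:

>>> from torch.utils.data import SequentialSampler
>>> list(SequentialSampler(range(5)))
[0, 1, 2, 3, 4]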
class torch.utils.data.RandomSampler(data_source: collections.abc.Sized, replacement: bool = False, num_samples: Optional[int] = None, generator=None)[source]
Samples elements randomly. If without replacement, then sample from a shuffled dataset. If with replacement, then the user can specify num_samples to draw.
- Parameters
data_source (Dataset) – dataset to sample from
replacement (bool) – samples are drawn on-demand with replacement if True (default: False)
num_samples (int) – number of samples to draw (default: len(dataset)). This argument should be specified only when replacement is True.
generator (Generator) – Generator used in sampling.
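For example, drawing a full permutation versus a fixed number of samples with replacement (the seed is arbitrary, and the resulting index order is not shown because it depends on the generator state):

>>> import torch
>>> from torch.utils.data import RandomSampler
>>> g = torch.Generator().manual_seed(0)
>>> perm = list(RandomSampler(range(5), generator=g))   # a permutation of 0..4
>>> draws = list(RandomSampler(range(5), replacement=True, num_samples=8, generator=g))   # 8 indices, duplicates possible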
class torch.utils.data.SubsetRandomSampler(indices: Sequence[int], generator=None)[source]
Samples elements randomly from a given list of indices, without replacement.
- Parameters
 indices (sequence) – a sequence of indices
generator (Generator) – Generator used in sampling.
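For example, shuffling only a chosen subset of indices (the index list is arbitrary):

>>> from torch.utils.data import SubsetRandomSampler
>>> sampler = SubsetRandomSampler([2, 5, 7, 9])
>>> sorted(sampler)   # always the same four indices, yielded in a random order each pass
[2, 5, 7, 9]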
class torch.utils.data.WeightedRandomSampler(weights: Sequence[float], num_samples: int, replacement: bool = True, generator=None)[source]
Samples elements from [0,..,len(weights)-1] with given probabilities (weights).
- Parameters
weights (sequence) – a sequence of weights, not necessarily summing up to one
num_samples (int) – number of samples to draw
replacement (bool) – if True, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.
generator (Generator) – Generator used in sampling.
Example
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]
class torch.utils.data.BatchSampler(sampler: torch.utils.data.sampler.Sampler[int], batch_size: int, drop_last: bool)[source]
Wraps another sampler to yield a mini-batch of indices.
- Parameters
sampler (Sampler or Iterable) – Base sampler. Can be any iterable object.
batch_size (int) – Size of mini-batch.
drop_last (bool) – If True, the sampler will drop the last batch if its size would be less than batch_size.
Example
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
class torch.utils.data.distributed.DistributedSampler(dataset: torch.utils.data.dataset.Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None, shuffle: bool = True, seed: int = 0, drop_last: bool = False)[source]
Sampler that restricts data loading to a subset of the dataset.
It is especially useful in conjunction with torch.nn.parallel.DistributedDataParallel. In such a case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it.
Note
The dataset is assumed to be of constant size.
- Parameters
 dataset – Dataset used for sampling.
num_replicas (int, optional) – Number of processes participating in distributed training. By default, world_size is retrieved from the current distributed group.
rank (int, optional) – Rank of the current process within num_replicas. By default, rank is retrieved from the current distributed group.
shuffle (bool, optional) – If True (default), sampler will shuffle the indices.
seed (int, optional) – random seed used to shuffle the sampler if shuffle=True. This number should be identical across all processes in the distributed group. Default: 0.
drop_last (bool, optional) – if True, then the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas. Default: False.
Warning
In distributed mode, calling the set_epoch() method at the beginning of each epoch, before creating the DataLoader iterator, is necessary to make shuffling work properly across multiple epochs. Otherwise, the same ordering will always be used.
Example:
>>> sampler = DistributedSampler(dataset) if is_distributed else None
>>> loader = DataLoader(dataset, shuffle=(sampler is None),
...                     sampler=sampler)
>>> for epoch in range(start_epoch, n_epochs):
...     if is_distributed:
...         sampler.set_epoch(epoch)
...     train(loader)