import bisect
import functools
import warnings
from typing import (
    Callable,
    Dict,
    Generic,
    Iterable,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
)

# No 'default_generator' in torch/__init__.pyi
from torch import default_generator, randperm
from torch._utils import _accumulate
from torch.utils.data._typing import _DataPipeMeta

from ... import Generator, Tensor

T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')

UNTRACABLE_DATAFRAME_PIPES = [
    'batch',  # As it returns DataChunks
    'groupby',  # As it returns DataChunks
    '_dataframes_as_tuples',  # As it unpacks DF
    'trace_as_dataframe',  # As it is used to mark DF for tracing
]


class DataChunk(list, Generic[T]):
    def __init__(self, items):
        super().__init__(items)
        self.items = items

    def as_str(self, indent=''):
        res = indent + "[" + ", ".join(str(i) for i in iter(self)) + "]"
        return res

    def __iter__(self) -> Iterator[T]:
        for i in super().__iter__():
            yield i

    def raw_iterator(self) -> T:
        for i in self.items:
            yield i
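
# Usage sketch (illustrative; not part of the upstream file). `DataChunk` behaves
# like a list but also keeps the original `items` around for `raw_iterator()`:
#
#     >>> chunk = DataChunk([1, 2, 3])
#     >>> list(chunk)                    # __iter__ delegates to the underlying list
#     [1, 2, 3]
#     >>> list(chunk.raw_iterator())     # iterates the stored `items` directly
#     [1, 2, 3]
#     >>> chunk.as_str(indent='  ')
#     '  [1, 2, 3]'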
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should
    subclass it. All subclasses should overwrite :meth:`__getitem__`,
    supporting fetching a data sample for a given key. Subclasses could also
    optionally overwrite :meth:`__len__`, which is expected to return the size
    of the dataset by many :class:`~torch.utils.data.Sampler` implementations
    and the default options of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs an index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
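
# Subclassing sketch (illustrative; the class below is hypothetical, not part of
# the upstream file). A map-style dataset only needs `__getitem__` and, usually,
# `__len__`; `__add__` concatenates two such datasets into a `ConcatDataset`:
#
#     >>> class SquaresDataset(Dataset):
#     ...     def __init__(self, n):
#     ...         self.n = n
#     ...     def __getitem__(self, index):
#     ...         return index * index
#     ...     def __len__(self):
#     ...         return self.n
#     ...
#     >>> ds = SquaresDataset(4)
#     >>> [ds[i] for i in range(len(ds))]
#     [0, 1, 4, 9]
#     >>> len(ds + SquaresDataset(2))    # ConcatDataset of both datasets
#     6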
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
# in pytorch/torch/utils/data/sampler.py
class MapDataPipe(Dataset[T_co], metaclass=_DataPipeMeta):
    r"""
    Map-style DataPipe.

    All datasets that represent a map from keys to data samples should subclass this.
    Subclasses should overwrite :meth:`__getitem__`, supporting fetching a data sample
    for a given, unique key. Subclasses can also optionally overwrite :meth:`__len__`,
    which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options of
    :class:`~torch.utils.data.DataLoader`.

    These DataPipes can be invoked in two ways, using the class constructor or applying
    their functional form onto an existing `MapDataPipe` (recommended, available to most
    but not all DataPipes).

    Note:
        :class:`~torch.utils.data.DataLoader` by default constructs an index sampler
        that yields integral indices. To make it work with a map-style DataPipe with
        non-integral indices/keys, a custom sampler must be provided.

    Example:
        >>> from torchdata.datapipes.map import SequenceWrapper, Mapper
        >>> dp = SequenceWrapper(range(10))
        >>> map_dp_1 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> map_dp_2 = Mapper(dp, lambda x: x + 1)  # Using class constructor
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> batch_dp = map_dp_1.batch(batch_size=2)
        >>> list(batch_dp)
        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
    """
    functions: Dict[str, Callable] = {}

    def __getattr__(self, attribute_name):
        if attribute_name in MapDataPipe.functions:
            function = functools.partial(MapDataPipe.functions[attribute_name], self)
            return function
        else:
            raise AttributeError(
                "'{0}' object has no attribute '{1}'".format(self.__class__.__name__, attribute_name))

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register):
        if function_name in cls.functions:
            raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))

        def class_function(cls, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            return result_pipe

        function = functools.partial(class_function, cls_to_register)
        cls.functions[function_name] = function
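
# Registration sketch (illustrative; the class and function name below are
# hypothetical, not part of the upstream file). `register_datapipe_as_function`
# stores a constructor in `MapDataPipe.functions`, and `__getattr__` then exposes
# it as a method on every `MapDataPipe` instance:
#
#     >>> class HypotheticalDoubler(MapDataPipe):
#     ...     def __init__(self, source_dp):
#     ...         self.source_dp = source_dp
#     ...     def __getitem__(self, index):
#     ...         return self.source_dp[index] * 2
#     ...     def __len__(self):
#     ...         return len(self.source_dp)
#     ...
#     >>> MapDataPipe.register_datapipe_as_function('double', HypotheticalDoubler)
#     >>> # assuming `dp` is an existing MapDataPipe (e.g., the SequenceWrapper above)
#     >>> doubled = dp.double()          # resolved through __getattr__ -> functions dict
#     >>> doubled[3]
#     6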
class IterableDataset(Dataset[T_co]):
    r"""An iterable Dataset.

    All datasets that represent an iterable of data samples should subclass it.
    Such form of datasets is particularly useful when data come from a stream.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this dataset.

    When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
    item in the dataset will be yielded from the :class:`~torch.utils.data.DataLoader`
    iterator. When :attr:`num_workers > 0`, each worker process will have a
    different copy of the dataset object, so it is often desired to configure
    each copy independently to avoid having duplicate data returned from the
    workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
    process, returns information about the worker. It can be used in either the
    dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
    :attr:`worker_init_fn` option to modify each copy's behavior.

    Example 1: splitting workload across all workers in :meth:`__iter__`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         worker_info = torch.utils.data.get_worker_info()
        ...         if worker_info is None:  # single-process data loading, return the full iterator
        ...             iter_start = self.start
        ...             iter_end = self.end
        ...         else:  # in a worker process
        ...             # split workload
        ...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
        ...             worker_id = worker_info.id
        ...             iter_start = self.start + worker_id * per_worker
        ...             iter_end = min(iter_start + per_worker, self.end)
        ...         return iter(range(iter_start, iter_end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]

        >>> # Multi-process loading with two worker processes
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
        [3, 4, 5, 6]

    Example 2: splitting workload across all workers using :attr:`worker_init_fn`::

        >>> class MyIterableDataset(torch.utils.data.IterableDataset):
        ...     def __init__(self, start, end):
        ...         super().__init__()
        ...         assert end > start, "this example code only works with end > start"
        ...         self.start = start
        ...         self.end = end
        ...
        ...     def __iter__(self):
        ...         return iter(range(self.start, self.end))
        ...
        >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
        >>> ds = MyIterableDataset(start=3, end=7)

        >>> # Single-process loading
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
        [3, 4, 5, 6]
        >>>
        >>> # Directly doing multi-process loading yields duplicate data
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
        [3, 3, 4, 4, 5, 5, 6, 6]

        >>> # Define a `worker_init_fn` that configures each dataset copy differently
        >>> def worker_init_fn(worker_id):
        ...     worker_info = torch.utils.data.get_worker_info()
        ...     dataset = worker_info.dataset  # the dataset copy in this worker process
        ...     overall_start = dataset.start
        ...     overall_end = dataset.end
        ...     # configure the dataset to only process the split workload
        ...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
        ...     worker_id = worker_info.id
        ...     dataset.start = overall_start + worker_id * per_worker
        ...     dataset.end = min(dataset.start + per_worker, overall_end)
        ...
        >>> # Multi-process loading with the custom `worker_init_fn`
        >>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
        [3, 5, 4, 6]

        >>> # With even more workers
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
        [3, 4, 5, 6]
    """

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])
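
# Sketch (illustrative): `__add__` chains two iterable datasets lazily via
# `ChainDataset`, defined further below.
#
#     >>> # assuming `ds_a` and `ds_b` are IterableDataset instances
#     >>> chained = ds_a + ds_b          # equivalent to ChainDataset([ds_a, ds_b])
#     >>> # iterating `chained` yields all samples of ds_a, then all of ds_b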
# No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
class IterDataPipe(IterableDataset[T_co], metaclass=_DataPipeMeta):
    r"""
    Iterable-style DataPipe.

    All DataPipes that represent an iterable of data samples should subclass this.
    This style of DataPipes is particularly useful when data come from a stream, or
    when the number of samples is too large to fit them all in memory.

    All subclasses should overwrite :meth:`__iter__`, which would return an
    iterator of samples in this DataPipe.

    `IterDataPipe` is lazily initialized and its elements are computed only when
    ``next()`` is called on its iterator.

    These DataPipes can be invoked in two ways, using the class constructor or applying
    their functional form onto an existing `IterDataPipe` (recommended, available to most
    but not all DataPipes). You can chain multiple `IterDataPipe` together to form a
    pipeline that will perform multiple operations in succession.

    Note:
        When a subclass is used with :class:`~torch.utils.data.DataLoader`, each
        item in the DataPipe will be yielded from the :class:`~torch.utils.data.DataLoader`
        iterator. When :attr:`num_workers > 0`, each worker process will have a
        different copy of the DataPipe object, so it is often desired to configure
        each copy independently to avoid having duplicate data returned from the
        workers. :func:`~torch.utils.data.get_worker_info`, when called in a worker
        process, returns information about the worker. It can be used in either the
        dataset's :meth:`__iter__` method or the :class:`~torch.utils.data.DataLoader` 's
        :attr:`worker_init_fn` option to modify each copy's behavior.

    Example:
        >>> from torchdata.datapipes.iter import IterableWrapper, Mapper
        >>> dp = IterableWrapper(range(10))
        >>> map_dp_1 = Mapper(dp, lambda x: x + 1)  # Using class constructor
        >>> map_dp_2 = dp.map(lambda x: x + 1)  # Using functional form (recommended)
        >>> list(map_dp_1)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> list(map_dp_2)
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        >>> filter_dp = map_dp_1.filter(lambda x: x % 2 == 0)
        >>> list(filter_dp)
        [2, 4, 6, 8, 10]
    """
    functions: Dict[str, Callable] = {}
    reduce_ex_hook: Optional[Callable] = None
    getstate_hook: Optional[Callable] = None

    def __getattr__(self, attribute_name):
        if attribute_name in IterDataPipe.functions:
            function = functools.partial(IterDataPipe.functions[attribute_name], self)
            return function
        else:
            raise AttributeError(
                "'{0}' object has no attribute '{1}'".format(self.__class__.__name__, attribute_name))

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
        if function_name in cls.functions:
            raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))

        def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            if isinstance(result_pipe, IterDataPipe):
                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
                        result_pipe = result_pipe.trace_as_dataframe()
            return result_pipe

        function = functools.partial(class_function, cls_to_register, enable_df_api_tracing)
        cls.functions[function_name] = function

    def __getstate__(self):
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(self)
        return self.__dict__

    def __reduce_ex__(self, *args, **kwargs):
        if IterDataPipe.reduce_ex_hook is not None:
            try:
                return IterDataPipe.reduce_ex_hook(self)
            except NotImplementedError:
                pass
        return super().__reduce_ex__(*args, **kwargs)

    @classmethod
    def set_getstate_hook(cls, hook_fn):
        if IterDataPipe.getstate_hook is not None and hook_fn is not None:
            raise Exception("Attempt to override existing getstate_hook")
        IterDataPipe.getstate_hook = hook_fn

    @classmethod
    def set_reduce_ex_hook(cls, hook_fn):
        if IterDataPipe.reduce_ex_hook is not None and hook_fn is not None:
            raise Exception("Attempt to override existing reduce_ex_hook")
        IterDataPipe.reduce_ex_hook = hook_fn


class DFIterDataPipe(IterDataPipe):
    def _is_dfpipe(self):
        return True
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Args:
        *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self, *tensors: Tensor) -> None:
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"
        self.tensors = tensors

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        return self.tensors[0].size(0)
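
# Usage sketch (illustrative; not part of the upstream file):
#
#     >>> import torch
#     >>> features = torch.arange(8.).reshape(4, 2)   # 4 samples, 2 features each
#     >>> labels = torch.tensor([0, 1, 0, 1])
#     >>> ds = TensorDataset(features, labels)
#     >>> len(ds)
#     4
#     >>> ds[2]                                       # one (feature, label) pair
#     (tensor([4., 5.]), tensor(0))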
class ConcatDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ConcatDataset, self).__init__()
        self.datasets = list(datasets)
        assert len(self.datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        for d in self.datasets:
            assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
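
# Usage sketch (illustrative; not part of the upstream file). Global indices are
# mapped to (dataset, local index) through the cumulative sizes:
#
#     >>> import torch
#     >>> ds_a = TensorDataset(torch.tensor([0, 1, 2]))
#     >>> ds_b = TensorDataset(torch.tensor([10, 11]))
#     >>> ds = ConcatDataset([ds_a, ds_b])
#     >>> len(ds)                        # 3 + 2
#     5
#     >>> ds[3]                          # index 3 falls into ds_b at local index 0
#     (tensor(10),)
#     >>> ds[-1]                         # negative indices are supported
#     (tensor(11),)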
class ChainDataset(IterableDataset):
    r"""Dataset for chaining multiple :class:`IterableDataset` s.

    This class is useful to assemble different existing dataset streams. The
    chaining operation is done on-the-fly, so concatenating large-scale
    datasets with this class will be efficient.

    Args:
        datasets (iterable of IterableDataset): datasets to be chained together
    """
    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(ChainDataset, self).__init__()
        self.datasets = datasets

    def __iter__(self):
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            for x in d:
                yield x

    def __len__(self):
        total = 0
        for d in self.datasets:
            assert isinstance(d, IterableDataset), "ChainDataset only supports IterableDataset"
            total += len(d)
        return total
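
# Usage sketch (illustrative; the helper class below is hypothetical):
#
#     >>> class RangeIterable(IterableDataset):
#     ...     def __init__(self, n):
#     ...         self.n = n
#     ...     def __iter__(self):
#     ...         return iter(range(self.n))
#     ...     def __len__(self):
#     ...         return self.n
#     ...
#     >>> chained = ChainDataset([RangeIterable(3), RangeIterable(2)])
#     >>> list(chained)                  # streams the first dataset, then the second
#     [0, 1, 2, 0, 1]
#     >>> len(chained)
#     5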
class Subset(Dataset[T_co]):
    r"""
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        if isinstance(idx, list):
            return self.dataset[[self.indices[i] for i in idx]]
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)
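
# Usage sketch (illustrative; not part of the upstream file):
#
#     >>> import torch
#     >>> ds = TensorDataset(torch.arange(10))
#     >>> evens = Subset(ds, indices=[0, 2, 4, 6, 8])
#     >>> len(evens)
#     5
#     >>> evens[1]                       # maps to index 2 of the underlying dataset
#     (tensor(2),)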
def random_split(dataset: Dataset[T], lengths: Sequence[int],
                 generator: Optional[Generator] = default_generator) -> List[Subset[T]]:
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.
    Optionally fix the generator for reproducible results, e.g.:

    >>> random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))

    Args:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
        generator (Generator): Generator used for the random permutation.
    """
    # Cannot verify that dataset is Sized
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths), generator=generator).tolist()
    return [Subset(dataset, indices[offset - length: offset])
            for offset, length in zip(_accumulate(lengths), lengths)]
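
# Usage sketch (illustrative; not part of the upstream file). Each split is a
# `Subset` viewing disjoint, shuffled indices of the original dataset:
#
#     >>> import torch
#     >>> ds = TensorDataset(torch.arange(10))
#     >>> train_ds, val_ds = random_split(ds, [8, 2], generator=torch.Generator().manual_seed(0))
#     >>> len(train_ds), len(val_ds)
#     (8, 2)
#     >>> sorted(train_ds.indices + val_ds.indices) == list(range(10))
#     True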