Spaces:
Paused
Paused
| r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers to | |
| collate samples fetched from dataset into Tensor(s). | |
| These **need** to be in global scope since Py2 doesn't support serializing | |
| static methods. | |
| `default_collate` and `default_convert` are exposed to users via 'dataloader.py'. | |
| """ | |
import collections
import collections.abc
import re

import torch

try:
    from torch._six import string_classes
except ImportError:
    # `torch._six` was removed in PyTorch 2.0; it defined this name as `(str, bytes)`.
    string_classes = (str, bytes)
# Searched against `ndarray.dtype.str` to detect arrays whose dtype code is one of
# S, a, U or O (string and object arrays), which cannot be turned into tensors.
np_str_obj_array_pattern = re.compile(
    r'[SaUO]'
)
def default_convert(data):
    r"""
    Function that converts each NumPy array element into a :class:`torch.Tensor`. If the input is a `Sequence`,
    `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`.
    If the input is not a NumPy array, it is left unchanged.
    This is used as the default function for collation when both `batch_sampler` and
    `batch_size` are NOT defined in :class:`~torch.utils.data.DataLoader`.

    The general input type to output type mapping is similar to that
    of :func:`~torch.utils.data.default_collate`. See the description there for more details.

    Args:
        data: a single data point to be converted

    Returns:
        The converted data point. Tensors are returned as-is; NumPy values are
        converted with :func:`torch.as_tensor` unless they are string/object
        arrays (returned unchanged); mappings, namedtuples and sequences are
        converted element by element (plain tuples come back as lists);
        everything else, including `str` and `bytes`, is returned unchanged.

    Examples:
        >>> # Example with `int`
        >>> default_convert(0)
        0
        >>> # Example with NumPy array
        >>> # xdoctest: +SKIP
        >>> default_convert(np.array([0, 1]))
        tensor([0, 1])
        >>> # Example with NamedTuple
        >>> Point = namedtuple('Point', ['x', 'y'])
        >>> default_convert(Point(0, 0))
        Point(x=0, y=0)
        >>> default_convert(Point(np.array(0), np.array(0)))
        Point(x=tensor(0), y=tensor(0))
        >>> # Example with List
        >>> default_convert([np.array([0, 1]), np.array([2, 3])])
        [tensor([0, 1]), tensor([2, 3])]
    """
    elem_type = type(data)
    if isinstance(data, torch.Tensor):
        return data
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ not in ('str_', 'string_'):
        # Arrays of string classes and objects cannot become tensors; leave them
        # untouched. `memmap` is handled like `ndarray`, matching `default_collate`.
        if elem_type.__name__ in ('ndarray', 'memmap') \
                and np_str_obj_array_pattern.search(data.dtype.str) is not None:
            return data
        return torch.as_tensor(data)
    elif isinstance(data, collections.abc.Mapping):
        try:
            return elem_type({key: default_convert(data[key]) for key in data})
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return {key: default_convert(data[key]) for key in data}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return elem_type(*(default_convert(d) for d in data))
    elif isinstance(data, tuple):
        return [default_convert(d) for d in data]  # Backwards compatibility.
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
        # `(str, bytes)` is spelled out instead of `torch._six.string_classes`,
        # which no longer exists in PyTorch >= 2.0.
        try:
            return elem_type([default_convert(d) for d in data])
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return [default_convert(d) for d in data]
    else:
        return data
# Template for the TypeError raised by `default_collate`; the `{}` placeholder
# is filled with the offending element type or NumPy dtype.
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found {}"
)
def default_collate(batch):
    r"""
    Function that takes in a batch of data and puts the elements within the batch
    into a tensor with an additional outer dimension - batch size. The exact output type can be
    a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
    Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
    This is used as the default function for collation when
    `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.

    Here is the general input type (based on the type of the element within the batch) to output type mapping:

        * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
        * NumPy Arrays -> :class:`torch.Tensor`
        * `float` -> :class:`torch.Tensor`
        * `int` -> :class:`torch.Tensor`
        * `str` -> `str` (unchanged)
        * `bytes` -> `bytes` (unchanged)
        * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
        * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
          default_collate([V2_1, V2_2, ...]), ...]`
        * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
          default_collate([V2_1, V2_2, ...]), ...]`

    Args:
        batch: a single batch to be collated

    Raises:
        TypeError: if the first element of the batch is of an unsupported type,
            or is a NumPy array of strings or objects.
        RuntimeError: if the elements are sequences of unequal length.

    Examples:
        >>> # Example with a batch of `int`s:
        >>> default_collate([0, 1, 2, 3])
        tensor([0, 1, 2, 3])
        >>> # Example with a batch of `str`s:
        >>> default_collate(['a', 'b', 'c'])
        ['a', 'b', 'c']
        >>> # Example with `Map` inside the batch:
        >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
        {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
        >>> # Example with `NamedTuple` inside the batch:
        >>> # xdoctest: +SKIP
        >>> Point = namedtuple('Point', ['x', 'y'])
        >>> default_collate([Point(0, 0), Point(1, 1)])
        Point(x=tensor([0, 1]), y=tensor([0, 1]))
        >>> # Example with `Tuple` inside the batch:
        >>> default_collate([(0, 1), (2, 3)])
        [tensor([0, 2]), tensor([1, 3])]
        >>> # Example with `List` inside the batch:
        >>> default_collate([[0, 1], [2, 3]])
        [tensor([0, 2]), tensor([1, 3])]
    """
    # Dispatch is decided by the first element only.
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel, device=elem.device)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ not in ('str_', 'string_'):
        if elem_type.__name__ in ('ndarray', 'memmap'):
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
        # Any other NumPy type falls through to the TypeError at the bottom.
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        # `(str, bytes)` is spelled out instead of `torch._six.string_classes`,
        # which no longer exists in PyTorch >= 2.0. This check must precede the
        # Sequence branch, because strings are sequences too.
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        try:
            return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(other) == elem_size for other in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.

        if isinstance(elem, tuple):
            return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
        else:
            try:
                return elem_type([default_collate(samples) for samples in transposed])
            except TypeError:
                # The sequence type may not support `__init__(iterable)` (e.g., `range`).
                return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))