Code to decode input data from file streams.
The code in this file is adapted from the torchdata and webdataset packages.
It implements an extension-based decoder, which selects a different
decoding function based on the extension. In contrast to the
implementations in torchdata and webdataset, the extension is
removed from the field name after decoding. This ideally makes
the output format invariant to the exact decoding strategy.
Example
`image.jpg` will be decoded into a numpy array and will be accessible in the field `image`.
`image.npy.gz` will be decoded into a numpy array which can also be accessed under the field `image`.
ExtensionBasedDecoder
Decode key/data based on extension using a list of handlers.
The input fields are assumed to be instances of StreamWrapper, which wrap an underlying file-like object.
Source code in ocl/data_decoding.py
class ExtensionBasedDecoder:
    """Decode key/data based on extension using a list of handlers.

    The input fields are assumed to be instances of
    [StreamWrapper][torchdata.datapipes.utils.StreamWrapper],
    which wrap an underlying file like object.

    Each handler is called as ``handler(extension, data)``, where
    ``extension`` includes its leading dot (e.g. ``".gz"``). A handler
    returns the decoded value, or ``None`` if it does not handle that
    extension. When a handler succeeds, the extension is stripped from the
    field name and *all* handlers are tried again on the next extension.
    As a result, ``image.npy.gz`` is fully decoded into the field ``image``
    independently of the order in which the handlers were passed.

    Fields whose name starts with ``_`` bypass the handlers. If such a field
    is a StreamWrapper it is read fully, and bytes are decoded as UTF-8.
    """

    def __init__(self, *handler: Callable[[str, StreamWrapper], Optional[Any]]):
        # Handlers are tried in the given order for every extension.
        self.handlers = list(handler)

    def decode1(self, name, data):
        """Decode a single field by repeatedly stripping handled extensions.

        Args:
            name: Field name, possibly carrying one or more extensions.
            data: Field value passed to the handlers.

        Returns:
            Tuple ``(name, data)``. ``name`` has every successfully decoded
            extension removed, and ``data`` is the (possibly) decoded value.
            Decoding stops at the first extension that no handler accepts;
            that extension and any before it stay part of the name.
        """
        if not data:
            # Nothing to decode. Still return a (name, data) pair so that
            # callers can always unpack the result.
            return name, data
        new_name, extension = os.path.splitext(name)
        while extension:
            for f in self.handlers:
                result = f(extension, data)
                if result is not None:
                    break
            else:
                # No handler accepts this extension: keep it in the name.
                break
            # Remove the decoded part of the name and retry all handlers on
            # the next extension (e.g. ".npy" after ".gz" was handled).
            data = result
            name = new_name
            new_name, extension = os.path.splitext(name)
        return name, data

    @staticmethod
    def _read_text(value):
        """Read an underscore-prefixed field, converting bytes to ``str``."""
        if isinstance(value, StreamWrapper):
            stream = value
            try:
                value = stream.file_obj.read()
            finally:
                # Release the stream even if reading fails.
                stream.autoclose()
        if isinstance(value, bytes):
            value = value.decode("utf-8")
        return value

    def decode(self, data: dict):
        """Decode all fields of a sample dictionary.

        Args:
            data: Mapping from field name to field value, or ``None``.

        Returns:
            A new dict with decoded names and values; empty if ``data`` is
            ``None``. If two fields decode to the same name, the one iterated
            later overwrites the earlier one.
        """
        result = {}
        if data is None:
            return result
        for k, v in data.items():
            if k.startswith("_"):
                # Underscore-prefixed fields are not run through the handlers.
                result[k] = self._read_text(v)
                continue
            decoded_key, decoded_data = self.decode1(k, v)
            result[decoded_key] = decoded_data
        return result

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Decode input dictionary."""
        return self.decode(data)
|
__call__
Decode input dictionary.
Source code in ocl/data_decoding.py
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """Decode every field of the input dictionary via ``self.decode``."""
    decoded = self.decode(data)
    return decoded
|