
ocl.data_decoding

Code to decode input data from file streams.

The code in this file is adapted from the torchdata and webdataset packages. It implements an extension-based decoder, which selects a decoding function based on the file extension of each key. In contrast to the implementations in torchdata and webdataset, the extension is removed from the data field name after decoding. Ideally, this makes the output format invariant to the exact decoding strategy.

Example

image.jpg will be decoded into a numpy array and will be accessible in the field image. image.npy.gz will also be decoded into a numpy array, which can likewise be accessed under image.
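The key renaming in this example can be sketched with `os.path.splitext` alone. The helper name `strip_extensions` is hypothetical and not part of `ocl`; the sketch also assumes that every extension is claimed by some handler, so it only illustrates how the field name is derived, not the decoding itself.

```python
import os


def strip_extensions(key: str) -> str:
    """Drop trailing extensions one at a time, as happens after each decoding step.

    Hypothetical helper for illustration; assumes every extension gets decoded.
    """
    name, extension = os.path.splitext(key)
    while extension:
        key = name
        name, extension = os.path.splitext(key)
    return key


print(strip_extensions("image.jpg"))     # image
print(strip_extensions("image.npy.gz"))  # image
```

Both keys end up under the same field name, which is what makes the output independent of how the data was stored.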

ExtensionBasedDecoder

Decode key/data based on extension using a list of handlers.

The input fields are assumed to be instances of StreamWrapper, which wrap an underlying file-like object.

Source code in ocl/data_decoding.py
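import os
from typing import Any, Callable, Dict, Optional

from torchdata.datapipes.utils import StreamWrapper


class ExtensionBasedDecoder: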
    """Decode key/data based on extension using a list of handlers.

    The input fields are assumed to be instances of
    [StreamWrapper][torchdata.datapipes.utils.StreamWrapper],
    which wrap an underlying file-like object.
    """

    def __init__(self, *handler: Callable[[str, StreamWrapper], Optional[Any]]):
        self.handlers = list(handler) if handler else []

    def decode1(self, name, data):
        if not data:
            # Return a (name, data) pair so that decode() can always unpack the result.
            return name, data

        new_name, extension = os.path.splitext(name)
        if not extension:
            return name, data

        for f in self.handlers:
            result = f(extension, data)
            if result is not None:
                # Remove decoded part of name.
                data = result
                name = new_name
                # Try to decode next part of name.
                new_name, extension = os.path.splitext(name)
                if extension == "":
                    # Stop decoding if there are no further extensions to be handled.
                    break
        return name, data

    def decode(self, data: dict):
        result = {}

        if data is not None:
            for k, v in data.items():
                if k[0] == "_":
                    if isinstance(v, StreamWrapper):
                        data_bytes = v.file_obj.read()
                        v.autoclose()
                        v = data_bytes
                    if isinstance(v, bytes):
                        v = v.decode("utf-8")
                        result[k] = v
                        continue
                decoded_key, decoded_data = self.decode1(k, v)
                result[decoded_key] = decoded_data
        return result

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Decode input dictionary."""
        return self.decode(data)
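As a rough illustration of the handler contract, the sketch below chains two handlers over the key `meta.json.gz`. It is a standalone stand-in that follows the same loop as `decode1` rather than importing `ocl`. The names `gzip_handler`, `json_handler` and `decode_key` are hypothetical, and an `io.BytesIO` object is assumed to be an adequate substitute for a `StreamWrapper`. Note that handlers receive the extension including its leading dot, as returned by `os.path.splitext`.

```python
import gzip
import io
import json
import os
from typing import Any, Callable, List, Optional, Tuple

Handler = Callable[[str, Any], Optional[Any]]


def gzip_handler(extension: str, data: Any) -> Optional[Any]:
    """Hypothetical handler: unwrap gzip compression, ignore other extensions."""
    if extension != ".gz":
        return None
    return gzip.GzipFile(fileobj=data)


def json_handler(extension: str, data: Any) -> Optional[Any]:
    """Hypothetical handler: parse JSON from a file-like object."""
    if extension != ".json":
        return None
    return json.load(data)


def decode_key(name: str, data: Any, handlers: List[Handler]) -> Tuple[str, Any]:
    """Stand-in that follows the same loop as decode1 above."""
    new_name, extension = os.path.splitext(name)
    if not extension:
        return name, data
    for handler in handlers:
        result = handler(extension, data)
        if result is not None:
            # The handler accepted this extension: keep its output, drop the extension.
            data = result
            name = new_name
            new_name, extension = os.path.splitext(name)
            if extension == "":
                break
    return name, data


payload = io.BytesIO(gzip.compress(json.dumps({"label": 3}).encode("utf-8")))
key, value = decode_key("meta.json.gz", payload, [gzip_handler, json_handler])
print(key, value)  # meta {'label': 3}
```

Because the loop walks the handler list only once, order matters: with `[json_handler, gzip_handler]` the gzip layer is removed but no handler is left to process `.json`, so the result stays under the key `meta.json`. Listing handlers for outer extensions (such as compression) first avoids this.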

__call__

Decode input dictionary.

Source code in ocl/data_decoding.py
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """Decode input dictionary."""
    return self.decode(data)
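The effect of calling the decoder on a whole sample dictionary can be sketched as follows. This is a simplified stand-in for `decode`, not the library code: it omits the `StreamWrapper` branch, and the names `decode_sample` and `decode_text` as well as the sample keys are illustrative assumptions. Keys starting with an underscore that hold bytes are decoded as UTF-8 text and bypass the handlers; all other keys go through extension-based decoding.

```python
import os
from typing import Any, Callable, Dict, Tuple


def decode_text(key: str, value: Any) -> Tuple[str, Any]:
    """Hypothetical per-key decoder: turns `.txt` bytes into a string."""
    name, extension = os.path.splitext(key)
    if extension == ".txt":
        return name, value.decode("utf-8")
    return key, value


def decode_sample(
    sample: Dict[str, Any],
    decode_key: Callable[[str, Any], Tuple[str, Any]],
) -> Dict[str, Any]:
    """Simplified stand-in for decode(): no StreamWrapper handling."""
    result = {}
    for key, value in sample.items():
        if key[0] == "_" and isinstance(value, bytes):
            # Metadata fields are stored as text and keep their key unchanged.
            result[key] = value.decode("utf-8")
            continue
        new_key, new_value = decode_key(key, value)
        result[new_key] = new_value
    return result


sample = {"__key__": b"000017", "caption.txt": b"a red cube"}
decoded = decode_sample(sample, decode_text)
print(decoded)  # {'__key__': '000017', 'caption': 'a red cube'}
```

Since the real decoder is callable, an instance can be passed directly wherever a function over sample dictionaries is expected, for example to a `map` step in a data pipeline.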