Skip to content

Routing

Magic routed module which allows dynamic routing of inputs.

This module is used to wrap classes from arbitrary packages into Routable classes. A routable class is augmented with additional constructor parameters that determine on which elements of a PyTree the methods should be applied.

This is accomplished using a simple trick: instead of passing the individual parameters to the methods of the class, the original method is wrapped. The wrapped method then selects the desired input arguments from an `inputs` argument and forwards them to the original class implementation of the method.

Example
import torch
import routed

non_routed_class = torch.nn.Sigmoid()
routed_class = routed.torch.nn.Sigmoid(input_path="my_sigmoid_source")

example_tensor = torch.randn(100)
inputs = {
    "my_sigmoid_source": example_tensor
}
assert torch.allclose(non_routed_class(example_tensor), routed_class(inputs=inputs))

RoutedClass

Class used to dynamically subclass routed classes.

Any subclasses of this class are automatically patched to support routing of input arguments.

Attributes:

Name Type Description
input_mapping Dict[str, List[str]]

Mapping from parameters of routed functions to paths in the inputs dict.

Source code in routed/__init__.py
class RoutedClass:
    """Class used to dynamically subclass routed classes.

    Any subclasses of this class are automatically patched to support routing of input arguments.
    For every routed method parameter ``<name>``, a constructor keyword ``<name>_path`` may be
    given; its value is a dot-separated path which is split on ``"."`` and stored in
    `input_mapping`.  The ``<name>_path`` keywords are removed again before the remaining
    arguments are forwarded to the next ``__init__`` in the MRO.

    Attributes:
        input_mapping: Mapping from parameters of routed functions to paths in the inputs dict.

    """

    input_mapping: Dict[str, List[str]]

    def __init__(self, *args, **kwargs):
        # The wrapped class does not know about the ``*_path`` keywords, so strip them first.
        self._remove_routed_parameters(kwargs)
        super().__init__(*args, **kwargs)

    def __new__(cls, *args, **kwargs):
        # Patch routed methods.
        # This needs to be done here as they are otherwise not considered methods.
        _routed_methods = _get_routed_methods(cls)
        input_mapping = {}
        for method_name in _routed_methods:
            org_method = getattr(cls, method_name)
            # `inspect.signature` follows `__wrapped__`, so this also yields the original
            # parameter names when the method has already been patched.
            for name in inspect.signature(org_method).parameters:
                path_name = f"{name}_path"
                if path_name in kwargs:
                    input_mapping[name] = kwargs[path_name].split(".")

            # `__new__` runs for every instance.  Without this check each instantiation would
            # wrap the already routed method once more, growing the call depth with the number
            # of instances created from the same class.
            if getattr(org_method, "_is_routed_method", False):
                continue
            routed_method = build_routed_method(org_method)
            routed_method._is_routed_method = True
            setattr(cls, method_name, routed_method)
        instance = super().__new__(cls)
        instance.input_mapping = input_mapping
        return instance

    def _remove_routed_parameters(self, kwargs: Dict[str, Any]):
        """Delete the ``<param>_path`` entries of all mapped parameters from `kwargs` in place."""
        for param in self.input_mapping:
            # `pop` with a default is a no-op when the key is absent.
            kwargs.pop(f"{param}_path", None)

WrappedModule

Bases: types.ModuleType

Module which automatically patches all classes within it to support routing.

Source code in routed/__init__.py
class WrappedModule(types.ModuleType):
    """Module which automatically patches all classes within it to support routing.

    Attribute access is forwarded to the wrapped module.  Submodules are returned as
    `WrappedModule` instances, classes are returned as freshly created subclasses that
    additionally inherit from `RoutedClass`, and all other objects are returned unchanged.

    Attributes:
        path: Dotted import path of the wrapped module.
        module: The wrapped module object.

    """

    def __init__(self, path: str, module):
        super().__init__(path, f"Module with routed versions of {path}")
        self.path = path
        self.module = module

    def __getattr__(self, name):
        """Resolve `name` on the wrapped module, importing it as a submodule if necessary.

        Raises:
            AttributeError: If `name` is neither an attribute nor an importable submodule of
                the wrapped module.

        """
        # `__getattr__` only runs when normal lookup failed.  If that happens for our own
        # attributes the instance was never initialised; reading `self.module` below would
        # then recurse without bound.
        if name in ("path", "module"):
            raise AttributeError(name)
        try:
            imported = getattr(self.module, name)
        except AttributeError:
            qualified_name = f"{self.path}.{name}"
            try:
                imported = importlib.import_module(qualified_name)
            except ModuleNotFoundError as err:
                missing = err.name or ""
                # Errors about some other module (e.g. a broken dependency inside an
                # existing submodule) are real import failures and must stay visible.
                if qualified_name != missing and not qualified_name.startswith(f"{missing}."):
                    raise
                # `hasattr` and `getattr(obj, name, default)` only understand AttributeError.
                raise AttributeError(
                    f"module {self.path!r} has no attribute {name!r}"
                ) from err
        if isinstance(imported, types.ModuleType):
            return WrappedModule(f"{self.path}.{name}", imported)
        if not isinstance(imported, type):
            # Functions, constants and other non-class objects cannot serve as a base class.
            return imported
        return type(f"{self.path}.Routed{name}", (RoutedClass, imported), {})

build_routed_method

Pass arguments to a function based on the mapping defined in self.input_mapping.

This method supports both filtering for parameters that match the arguments of the wrapped method and passing all arguments defined in input_mapping. If a non-optional argument is missing this will raise an exception. Additional arguments can also be passed to the method to override entries in the input dict. Non-keyword arguments are always directly passed to the method.

Parameters:

Name Type Description Default
method types.MethodType

The method to pass the arguments to.

required
filter_parameters bool

Only pass arguments to the wrapped method that match the method's signature. This is practical if different methods require different types of input.

True
Source code in routed/__init__.py
def build_routed_method(
    method: types.MethodType, filter_parameters: bool = True
) -> types.MethodType:
    """Wrap `method` so its keyword arguments are resolved through `self.input_mapping`.

    The returned wrapper accepts an additional keyword-only argument `inputs`.  If the instance
    has a non-empty `input_mapping`, every mapped parameter is looked up in `inputs` (or, when
    `inputs` is empty, in the keyword arguments themselves) using the stored path.  A failed
    lookup (`ValueError`) is ignored for parameters that have a default value and re-raised
    otherwise.  Explicit keyword arguments take precedence over routed values, and positional
    arguments are forwarded untouched.  If `input_mapping` is empty, the method is called
    with the given positional and keyword arguments and `inputs` is not forwarded.

    Args:
        method: The method to pass the arguments to.
        filter_parameters: If true, drop every routed or keyword argument whose name is not
            part of the signature of `method`.  This is practical if different methods
            require different types of input.

    Returns:
        The routing wrapper around `method`.

    """
    # Inspect once while building so that calling the wrapper stays cheap.
    parameters = inspect.signature(method).parameters
    accepted_names = frozenset(list(parameters)[1:])  # Everything after "self".
    # Parameters with a default may be absent from the inputs without causing a failure.
    names_with_default = frozenset(
        name
        for name, spec in parameters.items()
        if spec.default is not inspect.Parameter.empty
    )

    def _is_filtered_out(name):
        # True if filtering is on and `name` is not a parameter of the wrapped method.
        return filter_parameters and name not in accepted_names

    @functools.wraps(method)
    def method_with_routing(self: RoutedClass, *args, inputs=None, **kwargs):
        tree = inputs if inputs else {}
        mapping = self.input_mapping
        if not mapping:
            # Nothing to route: behave like the plain method.
            return method(self, *args, **kwargs)
        if not tree:
            # No inputs given, resolve the paths against the keyword arguments instead.
            tree = kwargs

        call_kwargs = {}
        for field, path in mapping.items():
            # Skip fields unknown to the method and fields given explicitly as kwargs.
            if _is_filtered_out(field) or field in kwargs:
                continue
            try:
                value = tree_utils.get_tree_element(tree, path)
            except ValueError:
                if field not in names_with_default:
                    raise
            else:
                call_kwargs[field] = value
        # Support for additional parameters passed via keyword arguments.
        # TODO(hornmax): This is not ideal as it mixes routing args from the input dict
        # and explicitly passed kwargs and thus could lead to collisions.
        call_kwargs.update(
            (name, value)
            for name, value in kwargs.items()
            if not _is_filtered_out(name)
        )
        return method(self, *args, **call_kwargs)

    return method_with_routing