Magic routed module which allows dynamic routing of inputs.
This module is used to wrap classes from arbitrary packages into routable classes.
A routable class is augmented with additional constructor parameters that
determine which elements of a PyTree its methods
should be applied to.
This is accomplished using a simple trick: instead of passing the individual
parameters to methods of the class, the original method is wrapped. This wrapped method
then selects the desired input arguments from an `inputs` argument and forwards them to
the original class implementation of the method.
  Example
  import torch
import routed
non_routed_class = torch.nn.Sigmoid()
routed_class = routed.torch.nn.Sigmoid(input_path="my_sigmoid_source")
example_tensor = torch.randn(100)
inputs = {
    "my_sigmoid_source": example_tensor
}
assert torch.allclose(non_routed_class(example_tensor), routed_class(inputs=inputs))
 
  
  
        RoutedClass
  
  
      Class used to dynamically subclass routed classes.
Any subclasses of this class are automatically patched to support routing of input arguments.
  Attributes:
  
    
      
        | Name | Type | Description | 
    
    
        
          | input_mapping | Dict[str, List[str]] | Mapping from parameters of routed functions to paths in the inputs dict. | 
    
  
        
          Source code in routed/__init__.py
          |  | class RoutedClass:
    """Class used to dynamically subclass routed classes.
    Any subclasses of this class are automatically patched to support routing of input arguments.
    Attributes:
        input_mapping: Mapping from parameters of routed functions to paths in the inputs dict.
    """
    input_mapping: Dict[str, List[str]]
    def __init__(self, *args, **kwargs):
        self._remove_routed_parameters(kwargs)
        super().__init__(*args, **kwargs)
    def __new__(cls, *args, **kwargs):
        # Patch routed methods.
        # This needs to be done here as they are otherwise not considered methods.
        _routed_methods = _get_routed_methods(cls)
        input_mapping = {}
        for method_name in _routed_methods:
            org_method = getattr(cls, method_name)
            for name in inspect.signature(org_method).parameters:
                path_name = f"{name}_path"
                if path_name in kwargs:
                    input_mapping[name] = kwargs[path_name].split(".")
            setattr(cls, method_name, build_routed_method(org_method))
        instance = super().__new__(cls)
        instance.input_mapping = input_mapping
        return instance
    def _remove_routed_parameters(self, kwargs: Dict[str, Any]):
        for param in self.input_mapping:
            path = f"{param}_path"
            if path in kwargs:
                del kwargs[path]
 | 
 
  
  
  
   
 
        WrappedModule
  
      
        Bases: types.ModuleType
  
      Module which automatically patches all classes within it to support routing.
        
          Source code in routed/__init__.py
          |  | class WrappedModule(types.ModuleType):
    """Module which automatically patches all classes within it to support routing."""
    def __init__(self, path: str, module):
        super().__init__(path, f"Module with routed versions of {path}")
        self.path = path
        self.module = module
    def __getattr__(self, name):
        try:
            imported = getattr(self.module, name)
        except AttributeError:
            imported = importlib.import_module(f"{self.path}.{name}")
        if isinstance(imported, types.ModuleType):
            return WrappedModule(f"{self.path}.{name}", imported)
        return type(f"{self.path}.Routed{name}", (RoutedClass, imported), {})
 | 
 
  
  
  
   
 
build_routed_method
  
  
      Pass arguments to a function based on the mapping defined in self.input_mapping.
This method supports both filtering for parameters that match the arguments of the wrapped
method and passing all arguments defined in input_mapping.  If a non-optional argument is
missing this will raise an exception.  Additional arguments can also be passed to the method
to override entries in the input dict.  Non-keyword arguments are always directly passed to
the method.
  Parameters:
  
    
      
        | Name | Type | Description | Default | 
    
    
        
          | method | types.MethodType | The method to pass the arguments to. | required | 
        
          | filter_parameters | bool | Only pass arguments to the wrapped method that match the method's
signature.  This is practical if different methods require different types of input. | True | 
    
  
      
        Source code in routed/__init__.py
        |  | def build_routed_method(
    method: types.MethodType, filter_parameters: bool = True
) -> types.MethodType:
    """Pass arguments to a function based on the mapping defined in `self.input_mapping`.
    This method supports both filtering for parameters that match the arguments of the wrapped
    method and passing all arguments defined in `input_mapping`.  If a non-optional argument is
    missing this will raise an exception.  Additional arguments can also be passed to the method
    to override entries in the input dict.  Non-keyword arguments are always directly passed to
    the method.
    Args:
        method: The method to pass the arguments to.
        filter_parameters: Only pass arguments to wrapped method that match the methods
            signature.  This is practical if different methods require different types of input.
    """
    # Run inspection here to reduce compute time when calling method.
    signature = inspect.signature(method)
    valid_parameters = list(signature.parameters)  # Returns the parameter names.
    valid_parameters = valid_parameters[1:]  # Discard "self".
    # Keep track of default parameters. For these we should not fail if they are not in
    # the input dict.
    with_defaults = [
        name
        for name, param in signature.parameters.items()
        if param.default is not inspect.Parameter.empty
    ]
    @functools.wraps(method)
    def method_with_routing(self: RoutedClass, *args, inputs=None, **kwargs):
        if not inputs:
            inputs = {}
        if self.input_mapping:
            if not inputs:  # Empty dict.
                inputs = kwargs
            routed_inputs = {}
            for input_field, input_path in self.input_mapping.items():
                if filter_parameters and input_field not in valid_parameters:
                    # Skip parameters that are not the function signature.
                    continue
                if input_field in kwargs.keys():
                    # Skip parameters that are directly provided as kwargs.
                    continue
                try:
                    element = tree_utils.get_tree_element(inputs, input_path)
                    routed_inputs[input_field] = element
                except ValueError as e:
                    if input_field in with_defaults:
                        continue
                    else:
                        raise e
            # Support for additional parameters passed via keyword arguments.
            # TODO(hornmax): This is not ideal as it mixes routing args from the input dict
            # and explicitly passed kwargs and thus could lead to collisions.
            for name, element in kwargs.items():
                if filter_parameters and name not in valid_parameters:
                    continue
                else:
                    routed_inputs[name] = element
            return method(self, *args, **routed_inputs)
        else:
            return method(self, *args, **kwargs)
    return method_with_routing
 |