torch.overrides
This module exposes various helper functions for the __torch_function__
protocol. See Extending torch for more detail on the __torch_function__
protocol.
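Here is a minimal sketch of the protocol these helpers support. It is loosely modelled on the handler-registry pattern from Extending torch. Any object whose type defines __torch_function__ can intercept a call such as torch.add. Scalar, HANDLED and implements are hypothetical names used only for illustration.

```python
import torch

# Hypothetical registry mapping torch functions to replacement implementations.
HANDLED = {}


def implements(torch_function):
    """Register the decorated function as the handler for torch_function."""
    def decorator(impl):
        HANDLED[torch_function] = impl
        return impl
    return decorator


class Scalar:
    """Hypothetical Tensor-like that wraps a single Python number."""

    def __init__(self, value):
        self.value = value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func not in HANDLED:
            # Returning NotImplemented lets torch report a TypeError.
            return NotImplemented
        return HANDLED[func](*args, **kwargs)


@implements(torch.add)
def scalar_add(input, other):
    return Scalar(input.value + other.value)


print(torch.add(Scalar(1), Scalar(2)).value)  # 3
```

Unregistered functions return NotImplemented. As a result, a call such as torch.mul(Scalar(1), Scalar(2)) raises a TypeError instead of guessing at an implementation.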
Functions
torch.overrides.get_ignored_functions()
Return public functions that cannot be overridden by __torch_function__.
Returns
A set of functions that are publicly available in the torch API but cannot be overridden with __torch_function__. Mostly this is because none of the arguments of these functions are tensors or tensor-likes.
Return type
Set[Callable]
Examples
>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False
torch.overrides.get_overridable_functions()
List functions that are overridable via __torch_function__.
Returns
A dictionary that maps namespaces that contain overridable functions to functions in that namespace that can be overridden.
Return type
Dict[Any, List[Callable]]
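This short sketch inspects the returned mapping. It assumes, as the description suggests, that the keys are the namespace objects themselves, such as the torch module, the torch.Tensor class and torch.nn.functional.

```python
import torch
from torch.overrides import get_overridable_functions

overridable = get_overridable_functions()

# Membership checks against the per-namespace lists of overridable functions.
print(torch.add in overridable[torch])                # True
print(torch.Tensor.add in overridable[torch.Tensor])  # True

# Count how many overridable functions a few namespaces expose.
for namespace in (torch, torch.Tensor, torch.nn.functional):
    print(namespace.__name__, len(overridable[namespace]))
```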
torch.overrides.get_testing_overrides()
Return a dict containing dummy overrides for all overridable functions.
Returns
A dictionary that maps overridable functions in the PyTorch API to lambda functions that have the same signature as the real function and unconditionally return -1. These lambda functions are useful for testing API coverage for a type that defines __torch_function__.
Return type
Dict[Callable, Callable]
Examples
>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
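The dummy lambdas can drive a rough coverage check. The sketch reads each dummy's signature, fills every required positional parameter with a Tensor-like, and confirms that the real function routes the call to __torch_function__. RecordingTensorLike and check_coverage are hypothetical names. Treating every required parameter as a tensor argument is a simplifying assumption that only holds for functions such as torch.add and torch.mul.

```python
import inspect

import torch
from torch.overrides import get_testing_overrides


class RecordingTensorLike:
    """Hypothetical Tensor-like that records every function routed to it."""

    handled = []

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        cls.handled.append(func)
        return "intercepted"


def check_coverage(funcs):
    """Return the subset of funcs whose calls reach __torch_function__."""
    dummies = get_testing_overrides()
    covered = []
    for func in funcs:
        sig = inspect.signature(dummies[func])
        # Assumption: every required positional parameter accepts a tensor.
        args = [
            RecordingTensorLike()
            for param in sig.parameters.values()
            if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
            and param.default is inspect.Parameter.empty
        ]
        if func(*args) == "intercepted":
            covered.append(func)
    return covered


print(check_coverage([torch.add, torch.mul]) == [torch.add, torch.mul])  # True
```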
torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)
Implement a function with checks for __torch_function__ overrides.
See torch::autograd::handle_torch_function for the equivalent of this function in the C++ implementation.
Parameters
- public_api (function) – Function exposed by the public torch API originally called like public_api(*args, **kwargs) on which arguments are now being checked.
- relevant_args (iterable) – Iterable of arguments to check for __torch_function__ methods.
- args (tuple) – Arbitrary positional arguments originally passed into public_api.
- kwargs (dict) – Arbitrary keyword arguments originally passed into public_api.
Returns
Result from calling the __torch_function__ method of the first overloaded argument that does not return NotImplemented.
Raises
TypeError – if no implementation is found.
Example
>>> def func(a):
...     if type(a) is not torch.Tensor:  # This will make func dispatchable by __torch_function__
...         return handle_torch_function(func, (a,), a)
...     return a + 0
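This fuller sketch uses a hypothetical public function scaled_sum and a hypothetical LoggingTensorLike. The has_torch_function check keeps the fast path for ordinary tensors. handle_torch_function then forwards the original positional and keyword arguments to the Tensor-like's __torch_function__.

```python
import torch
from torch.overrides import handle_torch_function, has_torch_function


def scaled_sum(a, scale=1.0):
    """Hypothetical public function that supports __torch_function__."""
    if has_torch_function((a,)):
        # Let the Tensor-like decide what to do, forwarding every argument.
        return handle_torch_function(scaled_sum, (a,), a, scale=scale)
    return a.sum() * scale


class LoggingTensorLike:
    """Hypothetical Tensor-like that reports what reached __torch_function__."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return func.__name__, args, kwargs or {}


print(scaled_sum(torch.ones(3), scale=2.0))  # tensor(6.)
name, args, kwargs = scaled_sum(LoggingTensorLike(), scale=2.0)
print(name, kwargs)                          # scaled_sum {'scale': 2.0}
```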
torch.overrides.has_torch_function(relevant_args)
Check for __torch_function__ implementations in the elements of an iterable. Considers exact Tensors and Parameters non-dispatchable.
Parameters
relevant_args (iterable) – Iterable of arguments to check for __torch_function__ methods.
Returns
True if any of the elements of relevant_args have __torch_function__ implementations, False otherwise.
Return type
bool
See also
torch.overrides.is_tensor_like()
Checks if something is a Tensor-like, including an exact Tensor.
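This sketch contrasts has_torch_function with is_tensor_like. TensorLike is a hypothetical class that opts into the protocol. An exact Tensor counts as Tensor-like but is not considered dispatchable, and the same holds for an exact Parameter.

```python
import torch
from torch.overrides import has_torch_function, is_tensor_like


class TensorLike:
    """Hypothetical non-Tensor type that opts into the protocol."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented


t = torch.tensor([1.0, 2.0])
p = torch.nn.Parameter(torch.ones(2))

# Exact Tensors and Parameters are Tensor-like but not dispatchable.
print(is_tensor_like(t), has_torch_function((t,)))  # True False
print(has_torch_function((p,)))                     # False
# A single dispatchable element is enough for the whole iterable.
print(has_torch_function((t, TensorLike())))        # True
# Plain Python values never dispatch.
print(has_torch_function((1, "a", None)))           # False
```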
torch.overrides.is_tensor_like(inp)
Returns True if the passed-in input is a Tensor-like.
Currently, this occurs whenever there’s a __torch_function__ attribute on the type of the input.
Examples
A subclass of tensor is generally a Tensor-like.
>>> class SubTensor(torch.Tensor):
...     ...
>>> is_tensor_like(SubTensor([0]))
True
Built-in or user types aren’t usually Tensor-like.
>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor:
...     ...
>>> is_tensor_like(NotATensor())
False
However, they can be made Tensor-like by implementing __torch_function__.
>>> class TensorLike:
...     def __torch_function__(self, func, types, args, kwargs):
...         return -1
>>> is_tensor_like(TensorLike())
True
torch.overrides.is_tensor_method_or_property(func)
Returns True if the function passed in is a handler for a method or property belonging to torch.Tensor, as passed into __torch_function__.
Note
For properties, their __get__ method must be passed in.
This may be needed, in particular, for the following reasons:
- Methods/properties sometimes don’t contain a __module__ slot.
- They require that the first passed-in argument is an instance of torch.Tensor.
Examples
>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
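As the note says, properties are checked through their __get__ method. This small sketch uses torch.Tensor.grad as the example property. That choice is an assumption made only for illustration.

```python
import torch
from torch.overrides import is_tensor_method_or_property

# A regular Tensor method is recognised directly.
print(is_tensor_method_or_property(torch.Tensor.add))           # True
# A property (here Tensor.grad) is recognised through its __get__ method.
print(is_tensor_method_or_property(torch.Tensor.grad.__get__))  # True
# Functions in the torch namespace are not Tensor methods.
print(is_tensor_method_or_property(torch.add))                  # False
```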
torch.overrides.wrap_torch_function(dispatcher)
Wraps a given function with __torch_function__-related functionality.
Parameters
dispatcher (Callable) – A callable that returns an iterable of Tensor-likes passed into the function.
Note
This decorator may reduce the performance of your code. Generally, it’s enough to express your code as a series of functions that, themselves, support __torch_function__. If you find yourself in the rare situation where this is not the case, e.g. if you’re wrapping a low-level library and you also need it to work for Tensor-likes, then this function is available.
Examples
>>> def dispatcher(a):  # Must have the same signature as func
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
... def func(a):  # This will make func dispatchable by __torch_function__
...     return a + 0
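This self-contained sketch of the decorator in use relies on a hypothetical weighted_add and a hypothetical TensorLike. Ordinary tensors take the normal code path. When the dispatcher returns a Tensor-like, the call is routed to that object's __torch_function__ instead.

```python
import torch
from torch.overrides import wrap_torch_function


def _weighted_add_dispatcher(a, b):
    # Same signature as weighted_add; returns the arguments to inspect.
    return (a, b)


@wrap_torch_function(_weighted_add_dispatcher)
def weighted_add(a, b):
    """Hypothetical function made dispatchable by the decorator."""
    return a + 0.5 * b


class TensorLike:
    """Hypothetical Tensor-like that intercepts any function passed to it."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return f"{func.__name__} handled by {cls.__name__}"


print(weighted_add(torch.ones(2), torch.ones(2)))  # tensor([1.5000, 1.5000])
print(weighted_add(torch.ones(2), TensorLike()))   # weighted_add handled by TensorLike
```

The dispatcher only decides which arguments are inspected. The decorated function body runs unchanged whenever none of them is dispatchable.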
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/torch.overrides.html