torch.overrides

This module exposes various helper functions for the __torch_function__ protocol. See Extending torch for more detail on the __torch_function__ protocol.

Functions

torch.overrides.get_ignored_functions()

Return public functions that cannot be overridden by __torch_function__.

Returns

A set of functions that are publicly available in the torch API but cannot be overridden with __torch_function__. In most cases this is because none of the arguments of these functions are tensors or tensor-likes.

Return type

Set[Callable]

Examples

>>> torch.Tensor.as_subclass in torch.overrides.get_ignored_functions()
True
>>> torch.add in torch.overrides.get_ignored_functions()
False

torch.overrides.get_overridable_functions()

List functions that are overridable via __torch_function__.

Returns

A dictionary that maps namespaces that contain overridable functions to functions in that namespace that can be overridden.

Return type

Dict[Any, List[Callable]]
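
A minimal sketch of inspecting the returned dictionary. It assumes that the top-level torch module is one of the namespace keys; the variable names are illustrative only.

```python
import torch
from torch.overrides import get_overridable_functions

overridable = get_overridable_functions()

# Look up the functions that live directly in the top-level torch namespace.
torch_functions = overridable[torch]
add_is_overridable = torch.add in torch_functions

# Count the overridable functions per namespace (names are for display only).
counts = {
    getattr(namespace, "__name__", repr(namespace)): len(funcs)
    for namespace, funcs in overridable.items()
}
```

The counts differ between PyTorch releases, so they are best treated as a diagnostic rather than a fixed value.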

torch.overrides.get_testing_overrides()

Return a dict containing dummy overrides for all overridable functions.

Returns

A dictionary that maps overridable functions in the PyTorch API to lambda functions that have the same signature as the real function and unconditionally return -1. These lambda functions are useful for testing API coverage for a type that defines __torch_function__.

Return type

Dict[Callable, Callable]

Examples

>>> import inspect
>>> my_add = torch.overrides.get_testing_overrides()[torch.add]
>>> inspect.signature(my_add)
<Signature (input, other, out=None)>
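
One way the dummy overrides might be used for an API coverage check is sketched below. HANDLED is a hypothetical set standing in for the functions that some Tensor-like type supports; it is not part of torch.overrides.

```python
import torch
from torch.overrides import get_testing_overrides

overrides = get_testing_overrides()

# Every dummy override ignores its arguments and returns -1.
dummy_add = overrides[torch.add]
dummy_result = dummy_add(torch.ones(2), torch.ones(2))

# HANDLED is a hypothetical set of functions that a Tensor-like type
# claims to support; everything else is reported as not covered.
HANDLED = {torch.add, torch.mul}
not_covered = [func for func in overrides if func not in HANDLED]
```

A real test suite would typically derive the handled set from the type under test instead of hard-coding it.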

torch.overrides.handle_torch_function(public_api, relevant_args, *args, **kwargs)

Implement a function with checks for __torch_function__ overrides.

See torch::autograd::handle_torch_function for the equivalent of this function in the C++ implementation.

Parameters
  • public_api (function) – Function exposed by the public torch API originally called like public_api(*args, **kwargs) on which arguments are now being checked.
  • relevant_args (iterable) – Iterable of arguments to check for __torch_function__ methods.
  • args (tuple) – Arbitrary positional arguments originally passed into public_api.
  • kwargs (dict) – Arbitrary keyword arguments originally passed into public_api.
Returns

Result from calling implementation or an __torch_function__ method, as appropriate.

Return type

object

Raises

TypeError – if no implementation is found.

Example

>>> from torch.overrides import handle_torch_function
>>> def func(a):
...     if type(a) is not torch.Tensor:  # This will make func dispatchable by __torch_function__
...         return handle_torch_function(func, (a,), a)
...     return a + 0
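
The sketch below shows the dispatch end to end under stated assumptions. Wrapper and Unsupported are hypothetical Tensor-like classes invented for illustration, and add_zero plays the role of func from the example above.

```python
import torch
from torch.overrides import handle_torch_function, has_torch_function


def add_zero(a):
    # Dispatch to __torch_function__ when the argument is not an exact Tensor.
    if has_torch_function((a,)):
        return handle_torch_function(add_zero, (a,), a)
    return a + 0


class Wrapper:
    """Hypothetical Tensor-like that unwraps, calls func, and rewraps."""

    def __init__(self, value):
        self.value = value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        unwrapped = [a.value if isinstance(a, Wrapper) else a for a in args]
        return Wrapper(func(*unwrapped, **kwargs))


class Unsupported:
    """Hypothetical Tensor-like that declines every function."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented


wrapped_out = add_zero(Wrapper(torch.tensor([1.0, 2.0])))  # a Wrapper
plain_out = add_zero(torch.tensor([1.0, 2.0]))             # a plain Tensor

# When every override returns NotImplemented, a TypeError is raised.
try:
    add_zero(Unsupported())
    raised_type_error = False
except TypeError:
    raised_type_error = True
```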

torch.overrides.has_torch_function(relevant_args)

Check for __torch_function__ implementations in the elements of an iterable. Exact Tensors and Parameters are considered non-dispatchable.

Parameters

relevant_args (iterable) – Iterable of arguments to check for __torch_function__ methods.

Returns

True if any of the elements of relevant_args have __torch_function__ implementations, False otherwise.

Return type

bool

See also

torch.overrides.is_tensor_like()

Checks if something is a Tensor-like, including an exact Tensor.
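
A short sketch of the rule described above: exact Tensors and Parameters do not count, while Tensor subclasses do. SubTensor is a hypothetical subclass defined only for this illustration.

```python
import torch
from torch.overrides import has_torch_function


class SubTensor(torch.Tensor):
    """Hypothetical subclass; it inherits Tensor's __torch_function__."""


plain = torch.ones(2)
sub = plain.as_subclass(SubTensor)
param = torch.nn.Parameter(torch.ones(2))

exact_tensor = has_torch_function((plain,))      # False: exact Tensor
parameter = has_torch_function((param,))         # False: exact Parameter
subclass = has_torch_function((sub,))            # True: Tensor subclass
mixed = has_torch_function((plain, sub))         # True: one element suffices
builtins_only = has_torch_function((1, "x", None))  # False: no overrides
```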

torch.overrides.is_tensor_like(inp)

Returns True if the passed-in input is a Tensor-like.

Currently, this occurs whenever there’s a __torch_function__ attribute on the type of the input.

Examples

A subclass of tensor is generally a Tensor-like.

>>> from torch.overrides import is_tensor_like
>>> class SubTensor(torch.Tensor): ...
>>> is_tensor_like(SubTensor([0]))
True

Built-in or user types aren’t usually Tensor-like.

>>> is_tensor_like(6)
False
>>> is_tensor_like(None)
False
>>> class NotATensor: ...
>>> is_tensor_like(NotATensor())
False

But, they can be made Tensor-like by implementing __torch_function__.

>>> class TensorLike:
...     @classmethod
...     def __torch_function__(cls, func, types, args, kwargs=None):
...         return -1
>>> is_tensor_like(TensorLike())
True

torch.overrides.is_tensor_method_or_property(func)

Returns True if the function passed in is a handler for a method or property belonging to torch.Tensor, as passed into __torch_function__.

Note

For properties, their __get__ method must be passed in.

This may be needed, in particular, for the following reasons:

  1. Methods/properties sometimes don’t contain a __module__ slot.
  2. They require that the first passed-in argument is an instance of torch.Tensor.

Examples

>>> from torch.overrides import is_tensor_method_or_property
>>> is_tensor_method_or_property(torch.Tensor.add)
True
>>> is_tensor_method_or_property(torch.add)
False
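
The note about properties can be illustrated with a small sketch. torch.Tensor.shape is used here only as a convenient example of a Tensor property; any other property would be passed the same way, through its __get__ method.

```python
import torch
from torch.overrides import is_tensor_method_or_property

# A regular Tensor method is passed in directly.
method_result = is_tensor_method_or_property(torch.Tensor.add)

# For a property, pass its __get__ method rather than the property itself.
property_result = is_tensor_method_or_property(torch.Tensor.shape.__get__)

# A function from the torch namespace is not a Tensor method or property.
function_result = is_tensor_method_or_property(torch.add)
```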

torch.overrides.wrap_torch_function(dispatcher)

Wraps a given function with __torch_function__-related functionality.

Parameters

dispatcher (Callable) – A callable that returns an iterable of Tensor-likes passed into the function.

Note

This decorator may reduce the performance of your code. Generally, it’s enough to express your code as a series of functions that, themselves, support __torch_function__. If you find yourself in the rare situation where this is not the case, e.g. if you’re wrapping a low-level library and you also need it to work for Tensor-likes, then this function is available.

Examples

>>> def dispatcher(a): # Must have the same signature as func
...     return (a,)
>>> @torch.overrides.wrap_torch_function(dispatcher)
... def func(a): # This will make func dispatchable by __torch_function__
...     return a + 0
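
To see the decorator in action, the sketch below pairs it with a hypothetical Tensor-like class, Recorder, that logs the name of each dispatched function. Recorder and add_zero are invented for this illustration and are not part of PyTorch.

```python
import torch
from torch.overrides import wrap_torch_function


class Recorder:
    """Hypothetical Tensor-like that records dispatched function names."""

    calls = []

    def __init__(self, value):
        self.value = value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        cls.calls.append(func.__name__)
        unwrapped = [a.value if isinstance(a, Recorder) else a for a in args]
        return Recorder(func(*unwrapped, **kwargs))


def dispatcher(a):
    # Same signature as add_zero; returns the arguments to check.
    return (a,)


@wrap_torch_function(dispatcher)
def add_zero(a):
    return a + 0


recorded_out = add_zero(Recorder(torch.tensor([1.0])))  # goes through Recorder
plain_out = add_zero(torch.tensor([2.0]))               # runs the body directly
```

Plain tensors skip the override entirely, so only the first call is recorded.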

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/torch.overrides.html