torch.combinations
torch.combinations(input, r=2, with_replacement=False) → seq
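Compute combinations of length r of the given tensor. The behavior is similar to Python's itertools.combinations when with_replacement is set to False, and to itertools.combinations_with_replacement when with_replacement is set to True.
- Parameters
input (Tensor): a 1-D tensor.
r (int, optional): number of elements to combine. Default: 2.
with_replacement (bool, optional): whether to allow an element to appear more than once in a combination. Default: False.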
- Returns
A tensor equivalent to converting the input tensor into a list, applying itertools.combinations or itertools.combinations_with_replacement to that list, and converting the resulting list of tuples back into a tensor.
- Return type
Tensor
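The recipe described under Returns can be mirrored in plain Python as a rough sketch. `combinations_reference` is a hypothetical helper, not part of PyTorch; it works on a list rather than a tensor so that it runs without torch installed, and each inner list it returns corresponds to one row of the 2-D result.

```python
import itertools
from typing import List, Sequence


def combinations_reference(values: Sequence[int], r: int = 2,
                           with_replacement: bool = False) -> List[List[int]]:
    """Hypothetical pure-Python mirror of the documented recipe:
    treat the 1-D input as a list, run itertools, emit one row per tuple."""
    # Pick the itertools function that matches the with_replacement flag.
    fn = (itertools.combinations_with_replacement if with_replacement
          else itertools.combinations)
    # Each tuple produced by itertools becomes one row of the result.
    return [list(combo) for combo in fn(list(values), r)]


rows = combinations_reference([1, 2, 3])
print(rows)  # [[1, 2], [1, 3], [2, 3]]
```

With PyTorch installed, `torch.tensor(combinations_reference(t.tolist(), r))` should hold the same values as `torch.combinations(t, r)` for a 1-D tensor `t` and typical inputs such as those in the example; degenerate cases (for instance r = 0, or r larger than the input length without replacement) are not checked by this sketch.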
Example:
>>> import itertools
>>> import torch
>>> a = [1, 2, 3]
>>> list(itertools.combinations(a, r=2))
[(1, 2), (1, 3), (2, 3)]
>>> list(itertools.combinations(a, r=3))
[(1, 2, 3)]
>>> list(itertools.combinations_with_replacement(a, r=2))
[(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
>>> tensor_a = torch.tensor(a)
>>> torch.combinations(tensor_a)
tensor([[1, 2],
        [1, 3],
        [2, 3]])
>>> torch.combinations(tensor_a, r=3)
tensor([[1, 2, 3]])
>>> torch.combinations(tensor_a, with_replacement=True)
tensor([[1, 1],
        [1, 2],
        [1, 3],
        [2, 2],
        [2, 3],
        [3, 3]])
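The row counts in the example above (3 rows for r=2, 1 row for r=3, 6 rows with replacement) follow from the itertools equivalence: for a 1-D input of length n, each row has r elements, and there should be C(n, r) rows without replacement and C(n + r - 1, r) rows with replacement. The sketch below checks those closed forms against itertools. `expected_num_rows` is a hypothetical helper, it assumes n >= 1, and `math.comb` requires Python 3.8 or later.

```python
import itertools
import math


def expected_num_rows(n: int, r: int, with_replacement: bool = False) -> int:
    """Hypothetical helper: number of combinations of size r drawn from
    n distinct positions (assumes n >= 1)."""
    if with_replacement:
        # Multisets of size r drawn from n items: C(n + r - 1, r).
        return math.comb(n + r - 1, r)
    # Subsets of size r drawn from n items: C(n, r).
    return math.comb(n, r)


a = [1, 2, 3]
# Cross-check the closed forms against itertools on the example input.
assert expected_num_rows(3, 2) == len(list(itertools.combinations(a, 2)))
assert expected_num_rows(3, 3) == len(list(itertools.combinations(a, 3)))
assert expected_num_rows(3, 2, True) == len(
    list(itertools.combinations_with_replacement(a, 2)))
```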
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.combinations.html