torch.multinomial
torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → LongTensor
Returns a tensor where each row contains num_samples indices sampled from the multinomial probability distribution located in the corresponding row of tensor input.

Note
The rows of input do not need to sum to one (in which case we use the values as weights), but must be non-negative, finite and have a non-zero sum.

Indices are ordered from left to right according to when each was sampled (first samples are placed in first column).
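As a rough illustration of the note above, the sketch below draws many indices with replacement from unnormalized weights and compares the empirical frequencies with the normalized weights. The seed, the number of draws, and the variable names are arbitrary choices made for this example, not part of the API.

```python
import torch

# Unnormalized, non-negative weights; they sum to 13, not 1.
weights = torch.tensor([0.0, 10.0, 3.0, 0.0])

# A seeded generator (the generator keyword from the signature) keeps the run repeatable.
gen = torch.Generator().manual_seed(0)

# Draw many indices with replacement so the empirical frequencies are meaningful.
num_draws = 100_000
samples = torch.multinomial(weights, num_draws, replacement=True, generator=gen)

# Count how often each index was drawn and compare with weights / weights.sum().
counts = torch.bincount(samples, minlength=weights.numel())
empirical = counts.float() / num_draws
expected = weights / weights.sum()

print(empirical)
print(expected)
```

The two printed tensors should agree closely, which suggests the weights are treated as relative probabilities rather than requiring explicit normalization.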
If input is a vector, out is a vector of size num_samples.

If input is a matrix with m rows, out is a matrix of shape (m × num_samples).

If replacement is True, samples are drawn with replacement.

If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.
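The shape rules and the two sampling modes described above can be checked with a short sketch. The tensors, the seed, and the sample counts here are assumptions picked for illustration; all weights are strictly positive so that every request without replacement is valid.

```python
import torch

gen = torch.Generator().manual_seed(0)

# Vector input: the result is a vector of size num_samples.
vec = torch.tensor([1.0, 2.0, 3.0, 4.0])
vec_out = torch.multinomial(vec, 3, generator=gen)
print(vec_out.shape)  # torch.Size([3])

# Matrix input with m = 2 rows: the result has shape (m, num_samples).
mat = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                    [4.0, 3.0, 2.0, 1.0]])
mat_out = torch.multinomial(mat, 3, generator=gen)
print(mat_out.shape)  # torch.Size([2, 3])

# Without replacement (the default), no index repeats within a row.
rows_unique = all(row.unique().numel() == row.numel() for row in mat_out)
print(rows_unique)

# With replacement, num_samples may exceed the number of categories.
rep_out = torch.multinomial(mat, 10, replacement=True, generator=gen)
print(rep_out.shape)  # torch.Size([2, 10])
```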
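Note
When drawn without replacement, num_samples must not exceed the number of non-zero elements in input (or the minimum number of non-zero elements across the rows of input if it is a matrix).

Parameters
- input (Tensor) – the input tensor containing probabilities
- num_samples (int) – number of samples to draw
- replacement (bool, optional) – whether to draw with replacement or not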
Keyword Arguments
- generator (torch.Generator, optional) – a pseudorandom number generator for sampling
- out (Tensor, optional) – the output tensor.
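A minimal sketch of the two keyword arguments, assuming CPU tensors: two generators seeded with the same value should produce the same draws, and out= writes the result into a preallocated LongTensor. The seed value and the names gen_a, gen_b, and buf are illustrative only.

```python
import torch

weights = torch.tensor([0.1, 0.2, 0.3, 0.4])

# Two generators seeded identically should yield identical draws.
gen_a = torch.Generator().manual_seed(1234)
gen_b = torch.Generator().manual_seed(1234)

draw_a = torch.multinomial(weights, 3, generator=gen_a)
draw_b = torch.multinomial(weights, 3, generator=gen_b)
print(torch.equal(draw_a, draw_b))

# out= stores the sampled indices in an existing tensor of dtype torch.long.
buf = torch.empty(3, dtype=torch.long)
torch.multinomial(weights, 3, generator=gen_a, out=buf)
print(buf)
```

Passing an explicit generator leaves the global random state untouched, which can be convenient when only the sampling step needs to be reproducible.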
Example:
>>> import torch
>>> weights = torch.tensor([0, 10, 3, 0], dtype=torch.float) # create a tensor of weights
>>> torch.multinomial(weights, 2)
tensor([1, 2])
>>> torch.multinomial(weights, 4) # ERROR!
RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False,
not enough non-negative category to sample) at ../aten/src/TH/generic/THTensorRandom.cpp:320
>>> torch.multinomial(weights, 4, replacement=True)
tensor([ 2, 1, 1, 1])
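The failing call in the example can be avoided by counting the non-zero weights before sampling. The sketch below does this for 1-D input only. The helper sample_without_replacement is hypothetical and not part of torch; the exact error the library raises on its own may differ between versions, so the helper raises its own ValueError instead.

```python
import torch

def sample_without_replacement(weights, num_samples, generator=None):
    """Hypothetical helper: validate 1-D weights, then sample without replacement."""
    if weights.dim() != 1:
        raise ValueError("this sketch handles 1-D weights only")
    if not torch.isfinite(weights).all() or (weights < 0).any():
        raise ValueError("weights must be finite and non-negative")
    nonzero = int(torch.count_nonzero(weights))
    if nonzero == 0:
        raise ValueError("weights must have a non-zero sum")
    if num_samples > nonzero:
        raise ValueError(
            f"requested {num_samples} samples but only {nonzero} weights are non-zero"
        )
    return torch.multinomial(weights, num_samples, replacement=False, generator=generator)

weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)

# Two non-zero weights, two samples: both non-zero indices are returned in some order.
picked = sample_without_replacement(weights, 2)
print(picked)

# Four samples cannot be drawn from two non-zero weights; the helper reports this up front.
try:
    sample_without_replacement(weights, 4)
    guard_triggered = False
except ValueError as err:
    guard_triggered = True
    print(err)
```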
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.multinomial.html