torch.nn.utils.prune.random_unstructured
torch.nn.utils.prune.random_unstructured(module, name, amount)
Prunes the tensor corresponding to the parameter called name in module by removing the specified amount of (currently unpruned) units selected at random. Modifies module in place (and also returns the modified module) by:
1) adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter name by the pruning method;
2) replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
- Parameters
- module (nn.Module) – module containing the tensor to prune
- name (str) – parameter name within module on which pruning will act.
- amount (int or float) – quantity of parameters to prune. If float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.
- Returns
modified (i.e. pruned) version of the input module
- Return type
module (nn.Module)
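The two forms of amount and the in-place return value described above can be checked with a short sketch. The layer sizes and the variable names (layer, fresh, returned) are illustrative assumptions, not part of the API.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 5)  # weight has 5 * 4 = 20 entries

# Integer amount: prune exactly that many entries of the weight.
returned = prune.random_unstructured(layer, name="weight", amount=3)
pruned_after_int = int((layer.weight_mask == 0).sum())

# The call modifies the module in place and returns the same object.
same_object = returned is layer

# Float amount on a fresh layer: prune that fraction of the entries.
fresh = nn.Linear(4, 5)
prune.random_unstructured(fresh, name="weight", amount=0.5)
pruned_after_float = int((fresh.weight_mask == 0).sum())

print(pruned_after_int, same_object, pruned_after_float)
```

Which entries are zeroed differs from run to run because they are chosen at random; only the count is determined by amount.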
Examples
>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.utils.prune as prune
>>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
>>> torch.sum(m.weight_mask == 0)
tensor(1)
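As a further sketch, the snippet below inspects what the call leaves on the module: the name+'_mask' buffer, the name+'_orig' parameter, and the pruned attribute that replaces the original parameter. It also calls the function a second time to illustrate that later calls select only among currently unpruned units. The layer shape and variable names are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(2, 3)  # weight has 3 * 2 = 6 entries
prune.random_unstructured(layer, name="weight", amount=2)

# 'weight' is no longer a parameter; 'weight_orig' holds the unpruned values
# and 'weight_mask' is registered as a buffer.
param_names = sorted(n for n, _ in layer.named_parameters())
buffer_names = sorted(n for n, _ in layer.named_buffers())

# The attribute 'weight' is the original tensor with the mask applied.
matches = torch.equal(layer.weight, layer.weight_orig * layer.weight_mask)
masked_are_zero = bool((layer.weight[layer.weight_mask == 0] == 0).all())

# A second call picks only among the 4 entries that are still unpruned,
# so the combined mask ends up with 2 + 1 = 3 zeros.
prune.random_unstructured(layer, name="weight", amount=1)
total_pruned = int((layer.weight_mask == 0).sum())

print(param_names, buffer_names, matches, total_pruned)
```

Because weight_orig remains a registered parameter, an optimizer built from layer.parameters() updates the unpruned values while the mask keeps the pruned positions at zero in layer.weight.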
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.nn.utils.prune.random_unstructured.html