torch.nn.utils.prune.custom_from_mask
torch.nn.utils.prune.custom_from_mask(module, name, mask)
Prunes the tensor corresponding to the parameter called name in module by applying the pre-computed mask in mask. Modifies module in place (and also returns the modified module) by:
1) adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter name by the pruning method;
2) replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.
Parameters
- module (nn.Module): module containing the tensor to prune
- name (str): parameter name within module on which pruning will act
- mask (Tensor): binary mask to be applied to the parameter
Returns
modified (i.e. pruned) version of the input module
Return type
module (nn.Module)
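The in-place modifications described above can be checked directly on a small layer. The following is a minimal sketch, assuming an nn.Linear(4, 2) layer and an arbitrary illustrative mask that keeps only the first two input columns of the weight; it verifies that the same module object is returned, that weight_orig and weight_mask are registered, and that weight now holds weight_orig * weight_mask.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 2)

# Illustrative mask (an assumption for this sketch): keep the first two
# input columns of the weight and zero out the last two.
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                     [1.0, 1.0, 0.0, 0.0]])

pruned = prune.custom_from_mask(layer, name="weight", mask=mask)

# The same module object is modified in place and returned.
assert pruned is layer

# The original weight now lives in the parameter "weight_orig",
# and the mask is registered as the buffer "weight_mask".
param_names = {n for n, _ in layer.named_parameters()}
buffer_names = {n for n, _ in layer.named_buffers()}
assert param_names == {"bias", "weight_orig"}
assert buffer_names == {"weight_mask"}

# "weight" is no longer a parameter: it is a plain tensor attribute
# holding the pruned version, i.e. weight_orig * weight_mask.
assert not isinstance(layer.weight, nn.Parameter)
assert torch.equal(layer.weight, layer.weight_orig * layer.weight_mask)
```

Because only weight_orig remains a registered parameter, an optimizer built from layer.parameters() after pruning updates the unpruned values, while the mask itself is stored as a buffer rather than a parameter and so is not trained.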
Examples
>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.utils.prune as prune
>>> m = prune.custom_from_mask(
...     nn.Linear(5, 3), name='bias', mask=torch.Tensor([0, 1, 0])
... )
>>> print(m.bias_mask)
tensor([0., 1., 0.])
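The example above only inspects the mask buffer. As a further sketch (the layer size, mask values and random input are arbitrary choices for illustration), one can check that a forward pass actually uses the pruned weight: the output matches a manual linear map computed with weight_orig * weight_mask, and every masked-out entry of weight is zero.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(3, 2)

# Illustrative binary mask with the same shape as layer.weight (2 x 3).
mask = torch.tensor([[1.0, 0.0, 1.0],
                     [0.0, 1.0, 0.0]])
prune.custom_from_mask(layer, name="weight", mask=mask)

x = torch.randn(5, 3)
out = layer(x)

# The forward pass uses the masked weight, not the original one.
expected = F.linear(x, layer.weight_orig * layer.weight_mask, layer.bias)
assert torch.allclose(out, expected)

# Entries where the mask is 0 are exactly zero in the pruned weight.
assert bool((layer.weight[layer.weight_mask == 0] == 0).all())
```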
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.nn.utils.prune.custom_from_mask.html