PruningContainer
-
class torch.nn.utils.prune.PruningContainer(*args)
-
Container holding a sequence of pruning methods for iterative pruning. Keeps track of the order in which pruning methods are applied and handles combining successive pruning calls.
Accepts as argument an instance of a BasePruningMethod or an iterable of them.
-
add_pruning_method(method)
-
Adds a child pruning method to the container.
- Parameters
-
method (subclass of BasePruningMethod) – child pruning method to be added to the container.
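As a rough illustration of the call above, the sketch below adds one child method to an empty container. Setting the private `_tensor_name` attribute by hand is an assumption made here so the container can be built without a module; in normal use that attribute is set when pruning is applied to a module.

```python
import torch.nn.utils.prune as prune

# Start from an empty container. Assigning the private `_tensor_name` by hand
# is an assumption of this sketch, not a documented part of the API.
container = prune.PruningContainer()
container._tensor_name = "weight"

method = prune.L1Unstructured(amount=2)
method._tensor_name = "weight"  # must match the container's tensor name

container.add_pruning_method(method)
print(len(container))  # number of child methods held by the container

# Objects that are not pruning methods are rejected with a TypeError.
try:
    container.add_pruning_method("not a pruning method")
    rejected = False
except TypeError:
    rejected = True
print(rejected)
```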
-
classmethod apply(module, name, *args, importance_scores=None, **kwargs)
-
Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.
- Parameters
-
- module (nn.Module) – module containing the tensor to prune.
- name (str) – parameter name within module on which pruning will act.
- args – arguments passed on to a subclass of BasePruningMethod.
- importance_scores (torch.Tensor) – tensor of importance scores (of the same shape as the module parameter) used to compute the mask for pruning. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. If unspecified or None, the parameter itself will be used in its place.
- kwargs – keyword arguments passed on to a subclass of BasePruningMethod.
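The reparametrization described above can be observed on a small layer. The sketch below is illustrative only: it prunes the same parameter twice through the functional helpers `prune.random_unstructured` and `prune.l1_unstructured`, then looks up the resulting hook through `_forward_pre_hooks`. The layer size and pruning amounts are arbitrary choices.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 3)

# First pruning call: a single pruning method is registered as a pre-hook.
prune.random_unstructured(layer, name="weight", amount=0.25)
# Second call on the same parameter: the two methods are combined in a container.
prune.l1_unstructured(layer, name="weight", amount=0.25)

# "weight" is now computed from the parameter "weight_orig" and the buffer "weight_mask".
param_names = [n for n, _ in layer.named_parameters()]
buffer_names = [n for n, _ in layer.named_buffers()]
print(param_names, buffer_names)

# Find the pruning hook that acts on "weight".
hook = next(
    h for h in layer._forward_pre_hooks.values()
    if isinstance(h, prune.BasePruningMethod) and h._tensor_name == "weight"
)
print(type(hook).__name__, [type(m).__name__ for m in hook])
```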
-
apply_mask(module)
-
Simply handles the multiplication between the parameter being pruned and the generated mask. Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.
- Parameters
-
module (nn.Module) – module containing the tensor to prune
- Returns
-
pruned version of the input tensor
- Return type
-
pruned_tensor (torch.Tensor)
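A minimal sketch of the multiplication described above, assuming a layer that has already been pruned twice so that its hook is a PruningContainer. The hook lookup through `_forward_pre_hooks` and the chosen amounts are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 3)
prune.l1_unstructured(layer, name="weight", amount=0.25)
prune.l1_unstructured(layer, name="weight", amount=0.25)

# Retrieve the pruning hook (a container after two pruning rounds).
hook = next(
    h for h in layer._forward_pre_hooks.values()
    if isinstance(h, prune.BasePruningMethod) and h._tensor_name == "weight"
)

# apply_mask fetches `weight_mask` and `weight_orig` from the module
# and returns their elementwise product.
pruned = hook.apply_mask(layer)
expected = layer.weight_mask * layer.weight_orig
print(torch.equal(pruned, expected))
```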
-
compute_mask(t, default_mask)
-
Applies the latest method by computing the new partial mask and returning its combination with the default_mask. The new partial mask should be computed on the entries or channels that were not zeroed out by the default_mask. Which portions of the tensor t the new mask will be calculated from depends on the PRUNING_TYPE (handled by the type handler):
- for 'unstructured', the mask will be computed from the raveled list of nonmasked entries;
- for 'structured', the mask will be computed from the nonmasked channels in the tensor;
- for 'global', the mask will be computed across all entries.
- Parameters
-
- t (torch.Tensor) – tensor representing the parameter to prune (of same dimensions as default_mask).
- default_mask (torch.Tensor) – mask from previous pruning iteration.
- Returns
-
new mask that combines the effects of the default_mask and the new mask from the current pruning method (of same dimensions as default_mask and t).
- Return type
-
mask (torch.Tensor)
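For the 'unstructured' case, the combination of masks can be sketched with a small hand-built tensor. The container is constructed by hand and its private `_tensor_name` is set manually, which is an assumption of this sketch. The mask is passed as a clone because the container may write into the mask it receives.

```python
import torch
import torch.nn.utils.prune as prune

# Hand-built container holding a single L1 unstructured method (sketch only).
container = prune.PruningContainer()
container._tensor_name = "weight"
method = prune.L1Unstructured(amount=2)  # prune the 2 smallest-magnitude entries
method._tensor_name = "weight"
container.add_pruning_method(method)

t = torch.tensor([[1., 2., 3., 4.],
                  [5., 6., 7., 8.]])
# Mask from a previous round: the entries holding 4 and 7 are already pruned.
default_mask = torch.tensor([[1., 1., 1., 0.],
                             [1., 1., 0., 1.]])

# Only the nonmasked entries (1, 2, 3, 5, 6, 8) are considered, so 1 and 2 are pruned.
new_mask = container.compute_mask(t, default_mask.clone())
print(new_mask)
# tensor([[0., 0., 1., 0.],
#         [1., 1., 0., 1.]])
```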
-
prune(t, default_mask=None, importance_scores=None)
-
Computes and returns a pruned version of input tensor t according to the pruning rule specified in compute_mask().
- Parameters
-
- t (torch.Tensor) – tensor to prune (of same dimensions as default_mask).
- importance_scores (torch.Tensor) – tensor of importance scores (of same shape as t) used to compute the mask for pruning t. The values in this tensor indicate the importance of the corresponding elements in the tensor t that is being pruned. If unspecified or None, the tensor t will be used in its place.
- default_mask (torch.Tensor, optional) – mask from previous pruning iteration, if any. To be considered when determining which portion of the tensor pruning should act on. If None, defaults to a mask of ones.
- Returns
-
pruned version of tensor t.
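The two optional arguments can be compared on a small tensor. As before, this is a sketch built on a hand-assembled container whose private `_tensor_name` is set manually (an assumption, not a documented part of the API). The first call prunes by magnitude on top of an earlier mask; the second ranks entries by a separate, made-up score tensor.

```python
import torch
import torch.nn.utils.prune as prune

container = prune.PruningContainer()
container._tensor_name = "weight"
method = prune.L1Unstructured(amount=2)
method._tensor_name = "weight"
container.add_pruning_method(method)

t = torch.tensor([[1., 2., 3., 4.],
                  [5., 6., 7., 8.]])

# 1) Respect an earlier mask; entries are ranked by the magnitude of t itself.
default_mask = torch.tensor([[1., 1., 1., 0.],
                             [1., 1., 0., 1.]])
pruned = container.prune(t, default_mask=default_mask.clone())
print(pruned)
# tensor([[0., 0., 3., 0.],
#         [5., 6., 0., 8.]])

# 2) No earlier mask; entries are ranked by hypothetical importance scores instead.
scores = torch.tensor([[8., 7., 6., 5.],
                       [4., 3., 2., 1.]])
pruned_by_score = container.prune(t, importance_scores=scores)
print(pruned_by_score)
# tensor([[1., 2., 3., 4.],
#         [5., 6., 0., 0.]])
```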
-
remove(module)
-
Removes the pruning reparameterization from a module. The pruned parameter named name remains permanently pruned, and the parameter named name+'_orig' is removed from the parameter list. Similarly, the buffer named name+'_mask' is removed from the buffers.
Note
Pruning itself is NOT undone or reversed!
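The effect of removal can be checked on a small layer. This sketch goes through the module-level helper `prune.remove(module, name)`, which calls the hook's `remove` and also unregisters the hook; the layer size and amounts are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 3)
# Two rounds of pruning, so the hook on "weight" is a PruningContainer.
prune.l1_unstructured(layer, name="weight", amount=0.25)
prune.l1_unstructured(layer, name="weight", amount=0.25)

pruned_weight = layer.weight.detach().clone()

# Make the pruning permanent and drop the reparameterization.
prune.remove(layer, "weight")

param_names_after = [n for n, _ in layer.named_parameters()]
buffer_names_after = [n for n, _ in layer.named_buffers()]
print(param_names_after, buffer_names_after)

# The zeros introduced by pruning are still present in the plain parameter.
print(torch.equal(layer.weight, pruned_weight))
```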
-
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.nn.utils.prune.PruningContainer.html