MultiheadAttention
-
class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None)
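Allows the model to jointly attend to information from different representation subspaces. See “Attention Is All You Need”.
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$
where $\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$.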
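The formula can be checked against the module with a short sketch. It assumes the default configuration (kdim and vdim unset, bias=True, dropout=0.0), reads the projection parameters from the module's in_proj_weight, in_proj_bias, and out_proj attributes, and uses scaled dot-product attention as in the paper. The helper name manual_multihead and the tensor sizes are illustrative only.

```python
import math

import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes: target length, source length, batch, embed_dim, heads.
L, S, N, E, H = 4, 6, 2, 8, 2
mha = nn.MultiheadAttention(E, H)

query = torch.randn(L, N, E)
key = torch.randn(S, N, E)
value = torch.randn(S, N, E)


def manual_multihead(mha, query, key, value):
    """Recompute multi-head attention by hand (hypothetical helper)."""
    E, H = mha.embed_dim, mha.num_heads
    d = E // H  # per-head dimension
    L, N, _ = query.shape
    S = key.shape[0]

    # in_proj_weight stacks the query, key, and value projections along dim 0.
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)
    q = query @ w_q.T + b_q  # (L, N, E)
    k = key @ w_k.T + b_k    # (S, N, E)
    v = value @ w_v.T + b_v  # (S, N, E)

    # Split the embedding into heads: (len, N, E) -> (N, H, len, d).
    q = q.reshape(L, N, H, d).permute(1, 2, 0, 3)
    k = k.reshape(S, N, H, d).permute(1, 2, 0, 3)
    v = v.reshape(S, N, H, d).permute(1, 2, 0, 3)

    # Scaled dot-product attention per head.
    weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d), dim=-1)
    heads = weights @ v  # (N, H, L, d)

    # Concat(head_1, ..., head_h) followed by the output projection W^O.
    concat = heads.permute(2, 0, 1, 3).reshape(L, N, E)
    return mha.out_proj(concat), weights.mean(dim=1)


out_manual, w_manual = manual_multihead(mha, query, key, value)
out_ref, w_ref = mha(query, key, value)
print(torch.allclose(out_manual, out_ref, atol=1e-5))
```

The comparison against the module's own output is only a sanity check of the formula; it is not a description of how the library computes the result internally.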
- Parameters
-
- embed_dim – total dimension of the model.
- num_heads – number of parallel attention heads. embed_dim must be divisible by num_heads.
- dropout – dropout probability applied to attn_output_weights. Default: 0.0.
- bias – if True, adds bias to the input and output projection layers. Default: True.
- add_bias_kv – add bias to the key and value sequences at dim=0. Default: False.
- add_zero_attn – add a new batch of zeros to the key and value sequences at dim=1. Default: False.
- kdim – total number of features in key. Default: None.
- vdim – total number of features in value. Default: None.
Note that if kdim and vdim are None, they will be set to embed_dim such that query, key, and value have the same number of features.
Examples:
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
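The example above leaves embed_dim, num_heads, and the input tensors undefined. A self-contained sketch with concrete, arbitrarily chosen sizes might look like the following; the second module illustrates kdim and vdim differing from embed_dim, and the names cross_attn, key_k, and value_v are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)

embed_dim, num_heads = 16, 4
L, S, N = 5, 7, 3  # target length, source length, batch size

# Default case: query, key, and value all have embed_dim features.
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
query = torch.randn(L, N, embed_dim)
key = torch.randn(S, N, embed_dim)
value = torch.randn(S, N, embed_dim)
attn_output, attn_output_weights = multihead_attn(query, key, value)
print(attn_output.shape, attn_output_weights.shape)

# Key and value with their own feature sizes (kdim=10, vdim=12).
cross_attn = nn.MultiheadAttention(embed_dim, num_heads, kdim=10, vdim=12)
key_k = torch.randn(S, N, 10)
value_v = torch.randn(S, N, 12)
cross_output, cross_weights = cross_attn(query, key_k, value_v)
print(cross_output.shape, cross_weights.shape)
```

In both cases the output keeps the query's sequence length and embed_dim features, regardless of the feature sizes of key and value.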
-
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)
- Parameters
-
- query, key, value – map a query and a set of key-value pairs to an output. See “Attention Is All You Need” for more details.
- key_padding_mask – if provided, the specified padding elements in the key will be ignored by the attention. For a boolean mask, positions where the value is True are ignored. For a byte mask, positions where the value is non-zero are ignored.
- need_weights – if True, return attn_output_weights in addition to attn_output. Default: True.
- attn_mask – 2D or 3D mask that prevents attention to certain positions. A 2D mask is broadcast across all batches, while a 3D mask allows a different mask to be specified for each entry of the batch.
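As an illustration of key_padding_mask, the sketch below marks the last two key positions of the second batch entry as padding and inspects the returned weights. The sizes and the choice of padded positions are arbitrary assumptions made for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)

L, S, N, E, H = 3, 5, 2, 8, 2
mha = nn.MultiheadAttention(E, H)

query = torch.randn(L, N, E)
key = torch.randn(S, N, E)
value = torch.randn(S, N, E)

# Boolean mask of shape (N, S): True marks key positions to ignore.
key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
key_padding_mask[1, 3:] = True  # pad the last two keys of batch entry 1

out, weights = mha(query, key, value, key_padding_mask=key_padding_mask)

# Padded keys receive zero attention weight; each row still sums to 1.
print(weights[1, :, 3:])
print(weights.sum(dim=-1))
```

Batch entry 0 is unaffected because its row of the mask is all False.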
- Shapes for inputs:
-
- query: $(L, N, E)$ where L is the target sequence length, N is the batch size, E is the embedding dimension.
- key: $(S, N, E)$ where S is the source sequence length, N is the batch size, E is the embedding dimension.
- value: $(S, N, E)$ where S is the source sequence length, N is the batch size, E is the embedding dimension.
- key_padding_mask: $(N, S)$ where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of True will be ignored while the positions with the value of False will be unchanged.
- attn_mask: if a 2D mask: $(L, S)$ where L is the target sequence length, S is the source sequence length. If a 3D mask: $(N \cdot \text{num\_heads}, L, S)$ where N is the batch size, L is the target sequence length, S is the source sequence length. attn_mask ensures that position i is allowed to attend only to the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with True are not allowed to attend while positions with False will be unchanged. If a FloatTensor is provided, it will be added to the attention weight.
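The attn_mask variants can be compared with a small causal-masking sketch for self-attention, where the target and source lengths are equal. It assumes a boolean mask with True above the diagonal, an equivalent float mask holding -inf at the same positions, and a 3D mask built by repeating the 2D one for every batch entry and head. All names and sizes are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)

T, N, E, H = 4, 2, 8, 2  # sequence length, batch size, embed_dim, heads
mha = nn.MultiheadAttention(E, H)
x = torch.randn(T, N, E)

# 2D boolean mask (L, S): True means "position i may not attend to j".
bool_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)

# Equivalent float mask: added to the attention scores before the softmax.
float_mask = torch.zeros(T, T).masked_fill(bool_mask, float("-inf"))

# 3D mask (N * num_heads, L, S): here the same mask for every entry and head.
mask_3d = bool_mask.unsqueeze(0).repeat(N * H, 1, 1)

out_b, w_b = mha(x, x, x, attn_mask=bool_mask)
out_f, w_f = mha(x, x, x, attn_mask=float_mask)
out_3, w_3 = mha(x, x, x, attn_mask=mask_3d)

print(torch.allclose(out_b, out_f, atol=1e-6))
print(torch.allclose(out_b, out_3, atol=1e-6))
print(w_b[0])  # lower-triangular: no weight on future positions
```

A 3D mask is only needed when different batch entries or heads require different masking patterns; otherwise the 2D form is simpler.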
- Shapes for outputs:
-
- attn_output: $(L, N, E)$ where L is the target sequence length, N is the batch size, E is the embedding dimension.
- attn_output_weights: $(N, L, S)$ where N is the batch size, L is the target sequence length, S is the source sequence length.
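A brief sketch of the two return values, using arbitrary sizes: with the default need_weights=True the second element holds the attention weights, and with need_weights=False it is expected to be None while attn_output stays the same up to floating-point tolerance.

```python
import torch
from torch import nn

torch.manual_seed(0)

L, S, N, E, H = 3, 6, 2, 8, 4
mha = nn.MultiheadAttention(E, H)

query = torch.randn(L, N, E)
key = torch.randn(S, N, E)
value = torch.randn(S, N, E)

# Default: both the output and the attention weights are returned.
attn_output, attn_output_weights = mha(query, key, value)
print(attn_output.shape)          # (L, N, E)
print(attn_output_weights.shape)  # (N, L, S)

# Skipping the weights: the second return value is None.
output_only, no_weights = mha(query, key, value, need_weights=False)
print(no_weights is None)
```

Passing need_weights=False is a reasonable choice when only the attended output is consumed downstream.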
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.nn.MultiheadAttention.html