tf.contrib.seq2seq.GreedyEmbeddingHelper
A helper for use during inference.
Inherits From: Helper
tf.contrib.seq2seq.GreedyEmbeddingHelper(
embedding, start_tokens, end_token
)
Uses the argmax of the output (treated as logits) and passes the result through an embedding layer to get the next input.
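To make the argmax-then-embed step concrete, here is a minimal pure-Python sketch of one greedy step. It does not use TensorFlow. Nested lists stand in for tensors, `greedy_step` is a hypothetical name, and the embedding is assumed to be a plain lookup table with one row per token id.

```python
from typing import List, Sequence, Tuple


def greedy_step(logits: Sequence[Sequence[float]],
                embedding_table: Sequence[Sequence[float]]) -> Tuple[List[int], List[List[float]]]:
    """One greedy step: argmax each row of logits, then embed the winning ids."""
    # argmax per batch entry; max() returns the first maximal index on ties
    sample_ids = [max(range(len(row)), key=lambda j: row[j]) for row in logits]
    # the embedded ids become the decoder's next inputs
    next_inputs = [list(embedding_table[i]) for i in sample_ids]
    return sample_ids, next_inputs


table = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]      # vocab of 3, embedding dim 2
logits = [[0.1, 2.0, 0.3], [0.5, 0.2, 0.9]]       # batch of 2
print(greedy_step(logits, table))  # ([1, 2], [[1.0, 0.0], [0.0, 1.0]])
```

The tie-breaking rule here is a property of Python's `max`, not a claim about how `tf.argmax` resolves ties.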
| Args | |
|---|---|
| embedding | A callable that takes a vector tensor of ids (argmax ids), or the params argument for embedding_lookup. The returned tensor will be passed to the decoder input. |
| start_tokens | int32 vector shaped [batch_size], the start tokens. |
| end_token | int32 scalar, the token that marks end of decoding. |
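The embedding argument accepts either a callable or an embedding table. The sketch below shows one way to normalize the two forms into a single function of ids. It is pure Python with nested lists standing in for tensors; `make_embedding_fn` is a hypothetical name, and the list indexing only plays the role that embedding_lookup plays in TensorFlow.

```python
from typing import Callable, List, Sequence, Union

EmbeddingFn = Callable[[List[int]], List[List[float]]]


def make_embedding_fn(embedding: Union[EmbeddingFn, Sequence[Sequence[float]]]) -> EmbeddingFn:
    """Normalize the `embedding` argument into a function of ids."""
    if callable(embedding):
        return embedding  # already maps ids -> decoder inputs
    # Otherwise treat it as a params table and return one row per id.
    return lambda ids: [list(embedding[i]) for i in ids]


lookup = make_embedding_fn([[0.0, 1.0], [2.0, 3.0]])
print(lookup([1, 0]))  # [[2.0, 3.0], [0.0, 1.0]]
```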
| Raises | |
|---|---|
| ValueError | if start_tokens is not a 1D tensor or end_token is not a scalar. |
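The two ValueError conditions amount to rank checks. This is a pure-Python sketch in which nested lists stand in for tensors; `rank` and `validate_args` are hypothetical helpers, and returning the number of start tokens as the batch size is an assumption of the sketch, consistent with start_tokens being shaped [batch_size].

```python
def rank(value) -> int:
    """Nesting depth of a list-based stand-in for a tensor (a bare int is rank 0)."""
    depth = 0
    while isinstance(value, (list, tuple)):
        depth += 1
        value = value[0] if value else None
    return depth


def validate_args(start_tokens, end_token) -> int:
    """Mirror the documented ValueError checks; return the batch size."""
    if rank(start_tokens) != 1:
        raise ValueError("start_tokens must be a vector (rank 1)")
    if rank(end_token) != 0:
        raise ValueError("end_token must be a scalar (rank 0)")
    return len(start_tokens)  # one start token per batch entry


print(validate_args([1, 1, 1], 2))  # 3
```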
| Attributes | |
|---|---|
| batch_size | Batch size of tensor returned by sample. Returns a scalar int32 tensor. |
| sample_ids_dtype | DType of tensor returned by sample. Returns a DType. |
| sample_ids_shape | Shape of tensor returned by sample, excluding the batch dimension. Returns a TensorShape. |
Methods
initialize
initialize(
name=None
)
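Returns (initial_finished, initial_inputs): initial_finished is False for every batch entry, and initial_inputs is the embedding of start_tokens.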
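A pure-Python sketch of what initialize produces, assuming the embedding is a lookup table (nested lists stand in for tensors; this `initialize` is a hypothetical stand-alone function, not the TensorFlow method):

```python
from typing import List, Sequence, Tuple


def initialize(start_tokens: List[int],
               embedding_table: Sequence[Sequence[float]]) -> Tuple[List[bool], List[List[float]]]:
    """Nothing is finished yet, and the first decoder inputs are the embedded start tokens."""
    initial_finished = [False] * len(start_tokens)  # one flag per batch entry
    initial_inputs = [list(embedding_table[t]) for t in start_tokens]
    return initial_finished, initial_inputs


print(initialize([0, 1], [[0.5, 0.5], [1.0, 0.0]]))  # ([False, False], [[0.5, 0.5], [1.0, 0.0]])
```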
next_inputs
next_inputs(
time, outputs, state, sample_ids, name=None
)
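next_inputs_fn for GreedyEmbeddingHelper. Marks each batch entry as finished where sample_ids equals end_token and returns (finished, next_inputs, state), where next_inputs is the embedding of sample_ids. Once every entry has finished, the start inputs are returned instead as a placeholder.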
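The finished check and the embedding feedback can be sketched in pure Python as follows. Nested lists stand in for tensors, the embedding is assumed to be a lookup table, and this hypothetical `next_inputs` drops the `time` and `outputs` arguments because the greedy choice of the next input depends only on sample_ids. State is passed through unchanged.

```python
from typing import Any, List, Sequence, Tuple


def next_inputs(sample_ids: List[int], state: Any, end_token: int,
                start_inputs: List[List[float]],
                embedding_table: Sequence[Sequence[float]]) -> Tuple[List[bool], List[List[float]], Any]:
    """Flag entries that emitted end_token and embed the sampled ids as the next inputs."""
    finished = [i == end_token for i in sample_ids]
    if all(finished):
        # Every sequence is done, so the value is irrelevant; reuse the start inputs.
        inputs = start_inputs
    else:
        inputs = [list(embedding_table[i]) for i in sample_ids]
    return finished, inputs, state


print(next_inputs([1, 2], "s", 2, [[9.0], [9.0]], [[0.0], [1.0], [2.0]]))
# ([False, True], [[1.0], [2.0]], 's')
```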
sample
sample(
time, outputs, state, name=None
)
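sample for GreedyEmbeddingHelper. Returns sample_ids, the argmax of outputs (treated as logits) along the last axis, as an int32 tensor.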
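In TensorFlow 1.x this helper is normally handed to tf.contrib.seq2seq.BasicDecoder and run with tf.contrib.seq2seq.dynamic_decode. The loop below is only a pure-Python approximation of how initialize, sample, and next_inputs fit together, not the implementation of either class. `greedy_decode`, `shift_cell`, and the `max_steps` cap are hypothetical (dynamic_decode offers a `maximum_iterations` argument for a similar purpose), and the toy cell is assumed to receive one-hot embeddings.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Logits = List[List[float]]
# Hypothetical cell signature: (inputs, state) -> (logits, new_state).
Cell = Callable[[List[List[float]], Optional[object]], Tuple[Logits, Optional[object]]]


def argmax(row: Sequence[float]) -> int:
    """First index holding the largest logit (the sample() step)."""
    return max(range(len(row)), key=lambda j: row[j])


def greedy_decode(cell: Cell, start_tokens: List[int], end_token: int,
                  table: Sequence[Sequence[float]], max_steps: int) -> List[List[int]]:
    """Drive a cell with greedy embedding feedback until all sequences end."""
    finished = [False] * len(start_tokens)                  # initialize()
    inputs = [list(table[t]) for t in start_tokens]
    state = None
    decoded: List[List[int]] = [[] for _ in start_tokens]
    for _ in range(max_steps):
        logits, state = cell(inputs, state)
        sample_ids = [argmax(row) for row in logits]        # sample()
        for b, token in enumerate(sample_ids):
            if not finished[b]:
                decoded[b].append(token)                    # keeps tokens up to and including end_token
        # next_inputs(): a sequence stays finished once it has emitted end_token.
        finished = [f or t == end_token for f, t in zip(finished, sample_ids)]
        if all(finished):
            break
        inputs = [list(table[t]) for t in sample_ids]
    return decoded


def shift_cell(inputs, state):
    """Hypothetical toy cell: for one-hot inputs, the top logit is at (id + 1) % vocab."""
    return [[row[(j - 1) % len(row)] for j in range(len(row))] for row in inputs], state


one_hot = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]
print(greedy_decode(shift_cell, [0, 2], end_token=3, table=one_hot, max_steps=10))
# [[1, 2, 3], [3]]
```

Sequences finish at different steps, so the loop records tokens only for entries that are still running and stops as soon as every entry has emitted end_token.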
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/seq2seq/GreedyEmbeddingHelper