tf.contrib.seq2seq.GreedyEmbeddingHelper
A helper for use during inference.
Inherits From: Helper
tf.contrib.seq2seq.GreedyEmbeddingHelper( embedding, start_tokens, end_token )
Uses the argmax of the output (treated as logits) and passes the result through an embedding layer to get the next input.
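This is a minimal pure-Python sketch of the greedy feedback loop described above. Plain lists stand in for tensors. The decoder `toy_step`, the one-hot table `EMB` and the driver `greedy_decode` are hypothetical illustrations of the argmax-then-embed idea, not the TensorFlow implementation.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value; the first index wins on ties."""
    return max(range(len(row)), key=lambda i: row[i])


def greedy_decode(
    step: Callable[[List[Vector]], List[Vector]],
    embedding: List[Vector],
    start_tokens: List[int],
    end_token: int,
    max_steps: int,
) -> List[List[int]]:
    """Greedy loop: embed ids, run one decoder step, take the argmax, repeat."""
    inputs = [embedding[t] for t in start_tokens]  # embed the start tokens
    finished = [False] * len(start_tokens)
    decoded: List[List[int]] = [[] for _ in start_tokens]
    for _ in range(max_steps):
        logits = step(inputs)                  # one row of logits per batch entry
        ids = [argmax(row) for row in logits]  # greedy choice per entry
        for b, token in enumerate(ids):
            if not finished[b]:
                decoded[b].append(token)       # keep tokens up to and including end_token
        finished = [f or token == end_token for f, token in zip(finished, ids)]
        if all(finished):
            break
        inputs = [embedding[token] for token in ids]  # embed the chosen ids
    return decoded


# Hypothetical 4-token vocabulary with one-hot embeddings.
EMB = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]


def toy_step(inputs: List[Vector]) -> List[Vector]:
    """Hypothetical decoder: reads one-hot id k and scores id (k + 1) % 4 highest."""
    n = len(inputs[0])
    return [[1.0 if j == (argmax(x) + 1) % n else 0.0 for j in range(n)] for x in inputs]


result = greedy_decode(toy_step, EMB, start_tokens=[0, 2], end_token=3, max_steps=10)
```

Here `result` is `[[1, 2, 3], [3]]`: the second entry hits the end token after one step and stops growing while the first keeps decoding.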
Args | |
---|---|
embedding | A callable that takes a vector tensor of ids (argmax ids), or the params argument for embedding_lookup. The returned tensor will be passed to the decoder input. |
start_tokens | int32 vector shaped [batch_size], the start tokens. |
end_token | int32 scalar, the token that marks end of decoding. |
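The embedding argument accepts two forms. This sketch shows one way to normalize both into a single lookup function, assuming lists stand in for tensors and that a params table maps id i to row i. The name `make_embedding_fn` is hypothetical; as we understand it the helper performs a similar dispatch internally, but this is not its code.

```python
from typing import Callable, List, Sequence, Union

Vector = List[float]
EmbeddingFn = Callable[[Sequence[int]], List[Vector]]


def make_embedding_fn(embedding: Union[EmbeddingFn, List[Vector]]) -> EmbeddingFn:
    """Accept either a callable or a params table and always return a callable."""
    if callable(embedding):
        return embedding  # already maps ids -> vectors
    table = embedding     # otherwise treat it as the params table: row i embeds id i
    return lambda ids: [table[i] for i in ids]


TABLE = [[0.0, 0.0], [1.0, 0.5], [2.0, 1.0]]  # hypothetical 3-token table
from_table = make_embedding_fn(TABLE)
from_callable = make_embedding_fn(lambda ids: [[float(i), i / 2] for i in ids])
```

Both `from_table([2, 1])` and `from_callable([2, 1])` return `[[2.0, 1.0], [1.0, 0.5]]`, so the rest of the decoding logic never needs to know which form was passed.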
Raises | |
---|---|
ValueError | if start_tokens is not a 1D tensor or end_token is not a scalar. |
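A sketch of the two rank checks listed above, assuming nested Python lists stand in for tensors so that rank is the list nesting depth. The helpers `ndims` and `check_args` are hypothetical.

```python
from typing import Any


def ndims(x: Any) -> int:
    """Rank of a nested list: 0 for a scalar, 1 for a flat list, and so on."""
    rank = 0
    while isinstance(x, (list, tuple)):
        rank += 1
        x = x[0] if x else None
    return rank


def check_args(start_tokens: Any, end_token: Any) -> None:
    """Raise ValueError unless start_tokens is 1D and end_token is a scalar."""
    if ndims(start_tokens) != 1:
        raise ValueError("start_tokens must be a vector")
    if ndims(end_token) != 0:
        raise ValueError("end_token must be a scalar")
```

For example, `check_args([1, 1, 1], 2)` passes, while `check_args([[1, 1]], 2)` and `check_args([1], [2])` both raise `ValueError`.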
Attributes | |
---|---|
batch_size | Batch size of the tensor returned by sample. Returns a scalar int32 tensor. |
sample_ids_dtype | DType of the tensor returned by sample. Returns a DType. |
sample_ids_shape | Shape of the tensor returned by sample, excluding the batch dimension. Returns a TensorShape. |
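A worked illustration of "excluding the batch dimension", assuming nested lists stand in for tensors and that sample returns one integer id per batch entry. The values and the helper `shape_of` are hypothetical.

```python
from typing import Any, Tuple


def shape_of(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list, e.g. [[1, 2, 3]] -> (1, 3)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


start_tokens = [1, 1, 1]           # hypothetical batch of three start tokens
sample_ids = [4, 0, 2]             # hypothetical result of one call to sample
batch_size = len(start_tokens)     # one batch entry per start token
full_shape = shape_of(sample_ids)  # (3,), i.e. [batch_size]
sample_ids_shape = full_shape[1:]  # (), once the batch dimension is excluded
sample_ids_dtype = "int32"         # ids are integer vocabulary indices
```

Because each batch entry gets a single scalar id, the per-entry shape is empty even though the full tensor is a vector.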
Methods
initialize
initialize( name=None )
Returns (initial_finished, initial_inputs).
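A sketch of what initialize produces, assuming lists stand in for tensors: no entry has finished yet, and the first decoder inputs are the embedded start tokens. The standalone function `initialize` and the table values are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def initialize(
    embedding_fn: Callable[[Sequence[int]], List[Vector]],
    start_tokens: Sequence[int],
) -> Tuple[List[bool], List[Vector]]:
    """Return (initial_finished, initial_inputs) for a greedy embedding helper."""
    initial_finished = [False] * len(start_tokens)  # nothing has finished yet
    initial_inputs = embedding_fn(start_tokens)     # embedded start tokens
    return initial_finished, initial_inputs


TABLE = [[0.0, 1.0], [1.0, 0.0]]  # hypothetical 2-token embedding table
finished, inputs = initialize(lambda ids: [TABLE[i] for i in ids], [1, 1])
```

With start tokens `[1, 1]`, this gives `finished == [False, False]` and two copies of row 1 of the table as inputs.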
next_inputs
next_inputs( time, outputs, state, sample_ids, name=None )
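next_inputs_fn for GreedyEmbeddingHelper. Returns (finished, next_inputs, next_state), where finished marks the batch entries whose sample id equals end_token and next_inputs are the embedded sample_ids.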
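A sketch of one next_inputs step, assuming lists stand in for tensors. The function signature, `TABLE`, `lookup` and `START` are hypothetical. Reusing the start inputs once every entry has finished is an assumption of this sketch: at that point the value fed forward is never used, so any placeholder would do.

```python
from typing import Any, Callable, List, Sequence, Tuple

Vector = List[float]


def next_inputs(
    embedding_fn: Callable[[Sequence[int]], List[Vector]],
    start_inputs: List[Vector],
    state: Any,
    sample_ids: Sequence[int],
    end_token: int,
) -> Tuple[List[bool], List[Vector], Any]:
    """Return (finished, next_inputs, next_state) for one greedy step."""
    finished = [i == end_token for i in sample_ids]  # did this entry emit end_token?
    if all(finished):
        # Every entry is done, so the next input is never consumed;
        # skip the lookup and reuse start_inputs as a placeholder.
        inputs = start_inputs
    else:
        inputs = embedding_fn(sample_ids)  # feed the greedy choices back in
    return finished, inputs, state         # state passes through unchanged


TABLE = [[0.0], [1.0], [2.0]]  # hypothetical 3-token embedding table


def lookup(ids: Sequence[int]) -> List[Vector]:
    return [TABLE[i] for i in ids]


START = lookup([0, 0])
```

For example, `next_inputs(lookup, START, "s", [1, 2], 2)` returns `([False, True], [[1.0], [2.0]], "s")`.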
sample
sample( time, outputs, state, name=None )
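sample for GreedyEmbeddingHelper. Returns sample_ids: the argmax of outputs (treated as logits) over the last axis, as an int32 tensor shaped [batch_size].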
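A sketch of the greedy sampling rule, assuming each row of outputs is a plain list of logits for one batch entry. The function `sample` here is a hypothetical stand-in, and its tie-breaking (lowest index wins) is a property of Python's `max`, not a claim about TensorFlow's argmax.

```python
from typing import List, Sequence


def sample(outputs: Sequence[Sequence[float]]) -> List[int]:
    """Greedy sample: treat each row as logits and return its argmax index."""
    # max() keeps the first maximal index it sees, so ties go to the lowest id.
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


ids = sample([[0.1, 2.0, -1.0], [3.0, 0.0, 3.0]])
```

Here `ids` is `[1, 0]`. No softmax is needed first: softmax is monotonic, so the argmax of the logits is the same id as the argmax of the probabilities.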
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/seq2seq/GreedyEmbeddingHelper