tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper

A training helper that adds scheduled sampling.

Inherits From: TrainingHelper

sample_ids contains -1 wherever no sampling took place and a valid sampled id everywhere else.
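The rule can be sketched in plain Python. This is a hypothetical, stdlib-only illustration, not the TensorFlow implementation. The function name `scheduled_sample_ids` and the list-of-lists layout are assumptions, and the decoder outputs for one time step are assumed to be per-example logits. A coin with probability `sampling_probability`, seeded by `scheduling_seed`, decides whether to sample. Where it does, an id is drawn from the softmax of the logits using `seed`.

```python
import math
import random
from typing import List, Sequence


def scheduled_sample_ids(
    logits: Sequence[Sequence[float]],
    sampling_probability: float,
    seed: int = 0,
    scheduling_seed: int = 0,
) -> List[int]:
    """Hypothetical sketch of one step of scheduled sampling.

    Returns one id per batch element: a categorically sampled id where the
    scheduling coin came up "sample", and -1 where no sampling took place.
    """
    schedule_rng = random.Random(scheduling_seed)  # decides *whether* to sample
    sample_rng = random.Random(seed)                # decides *which* id to sample

    sample_ids = []
    for row in logits:
        if schedule_rng.random() < sampling_probability:
            # Softmax weights, shifted by the row maximum for numerical stability.
            top = max(row)
            weights = [math.exp(x - top) for x in row]
            sample_ids.append(sample_rng.choices(range(len(row)), weights=weights)[0])
        else:
            # No sampling: the ground-truth input is read instead.
            sample_ids.append(-1)
    return sample_ids


logits = [[2.0, 0.5, -1.0], [0.1, 0.1, 3.0], [1.0, 1.0, 1.0]]
print(scheduled_sample_ids(logits, sampling_probability=0.5, seed=1, scheduling_seed=2))
```

Keeping two separate generators mirrors the two seed arguments listed under Args: one controls the schedule, the other controls which ids are sampled.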

Args
inputs: A (structure of) input tensors.
sequence_length: An int32 vector tensor.
embedding: A callable that takes a vector tensor of ids (the sampled ids), or the params argument for embedding_lookup.
sampling_probability: A 0D float32 tensor: the probability of sampling categorically from the output ids instead of reading directly from the inputs.
time_major: Python bool. Whether the tensors in inputs are time major. If False (default), they are assumed to be batch major.
seed: The seed used when sampling ids.
scheduling_seed: The seed used for the scheduling decision, i.e. whether to sample or to read from the inputs.
name: Name scope for any created operations.
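The embedding argument accepts either a function or an embedding table. A minimal sketch of how a helper might normalize both forms into a single function, assuming a plain list of vectors stands in for the params tensor and a list comprehension stands in for embedding_lookup (the name `make_embedding_fn` is hypothetical):

```python
from typing import Callable, List, Sequence, Union

Vector = List[float]
EmbeddingFn = Callable[[Sequence[int]], List[Vector]]


def make_embedding_fn(embedding: Union[EmbeddingFn, Sequence[Vector]]) -> EmbeddingFn:
    """Return a function mapping a vector of ids to their embedding vectors."""
    if callable(embedding):
        # Already a function of ids: use it unchanged.
        return embedding
    params = embedding
    # Otherwise treat it as a table, mimicking a lookup of rows params[id].
    return lambda ids: [list(params[i]) for i in ids]


table = [[0.0, 0.0], [1.0, 0.5], [0.2, 0.9]]
lookup = make_embedding_fn(table)
print(lookup([2, 1]))
```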
Raises
ValueError: if sampling_probability is not a scalar or vector.
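The rank check above can be sketched with Python values standing in for tensors: a plain number plays the role of a rank-0 tensor and a flat list of numbers a rank-1 tensor. The function name `check_sampling_probability` is hypothetical, and, like the rule it illustrates, the sketch checks only the rank, not whether the values lie in [0, 1].

```python
from numbers import Real
from typing import List, Sequence, Union


def check_sampling_probability(p: Union[float, Sequence[float]]) -> Union[float, List[float]]:
    """Accept a scalar (rank 0) or a flat vector (rank 1); reject anything else."""
    if isinstance(p, Real):
        return float(p)
    if isinstance(p, (list, tuple)) and all(isinstance(x, Real) for x in p):
        return [float(x) for x in p]
    # Nested lists (rank 2 or higher) and non-numeric values are rejected.
    raise ValueError("sampling_probability must be a scalar or a vector, got %r" % (p,))


print(check_sampling_probability(0.25))
```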
Attributes
batch_size: Batch size of the tensor returned by sample. Returns a scalar int32 tensor.
inputs
sample_ids_dtype: DType of the tensor returned by sample. Returns a DType.
sample_ids_shape: Shape of the tensor returned by sample, excluding the batch dimension. Returns a TensorShape.
sequence_length

Methods

initialize

Returns (initial_finished, initial_inputs).
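A minimal sketch of what initialize computes, assuming it behaves like the inherited TrainingHelper version and using batch-major nested lists in place of tensors (the standalone function `initialize` is hypothetical): an example starts out finished only if its sequence_length is 0, and the initial input is each example's ground-truth input at time step 0.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def initialize(
    inputs: Sequence[Sequence[Vector]],  # batch-major: inputs[batch][time] -> feature vector
    sequence_length: Sequence[int],
) -> Tuple[List[bool], List[Vector]]:
    """Hypothetical sketch: returns (initial_finished, initial_inputs)."""
    # An example is finished before decoding starts only if it has length 0.
    initial_finished = [length == 0 for length in sequence_length]
    # Assumes every example is padded to at least one time step.
    depth = len(inputs[0][0])
    if all(initial_finished):
        # Nothing to decode: feed zeros instead of reading time step 0.
        initial_inputs = [[0.0] * depth for _ in inputs]
    else:
        # Otherwise the first input is each example's ground-truth step 0.
        initial_inputs = [list(example[0]) for example in inputs]
    return initial_finished, initial_inputs


print(initialize([[[1.0, 0.0], [0.0, 1.0]], [[0.5, 0.5], [0.0, 0.0]]], [2, 0]))
```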

next_inputs

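Returns (finished, next_inputs, next_state). Overrides the TrainingHelper version: where a sample id is valid, the next input is the embedding of that id; where it is -1, the next input is read from inputs.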
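The scheduled-sampling step can be sketched the same way. This hypothetical standalone function assumes batch-major nested lists padded to a common length and an `embedding_fn` that maps one id to a vector. It also omits the decoder state, which is passed through unchanged. An example is finished once time + 1 reaches its sequence_length; a valid sample id selects its embedding as the next input, and -1 selects the ground-truth input at time + 1.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def next_inputs(
    time: int,
    inputs: Sequence[Sequence[Vector]],  # batch-major ground truth, padded: inputs[batch][time]
    sequence_length: Sequence[int],
    sample_ids: Sequence[int],           # -1 where no sampling took place
    embedding_fn: Callable[[int], Vector],
) -> Tuple[List[bool], List[Vector]]:
    """Hypothetical sketch: returns (finished, next_inputs) for step time + 1."""
    next_time = time + 1
    finished = [next_time >= length for length in sequence_length]
    depth = len(inputs[0][0])
    if all(finished):
        # Every sequence is done: feed zeros and skip sampling entirely.
        return finished, [[0.0] * depth for _ in inputs]
    result = []
    for example, sample_id in zip(inputs, sample_ids):
        if sample_id > -1:
            # Sampled: feed the embedding of the model's own sampled id.
            result.append(list(embedding_fn(sample_id)))
        else:
            # Not sampled: read the ground-truth input for the next step.
            result.append(list(example[next_time]))
    return finished, result


ground_truth = [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]]
print(next_inputs(0, ground_truth, [3, 3], [-1, 7], lambda i: [100.0 + i]))
```

Because every example is padded to the same length, reading example[next_time] is safe whenever at least one sequence is still running.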

sample

Returns sample_ids.

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTrainingHelper