tf.contrib.training.rejection_sample
Stochastically creates batches by rejection sampling.
tf.contrib.training.rejection_sample(tensors, accept_prob_fn, batch_size, queue_threads=1, enqueue_many=False, prebatch_capacity=16, prebatch_threads=1, runtime_checks=False, name=None)
Each list of non-batched tensors is evaluated by accept_prob_fn to produce a scalar tensor between 0 and 1. This tensor corresponds to the probability of being accepted. When batch_size tensor groups have been accepted, the batch queue will return a mini-batch.
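
The accept/reject loop described above can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not the library's implementation: `rejection_sample_batches` and its arguments are hypothetical names, and the sketch ignores the queues and threads that the real function uses.

```python
import random


def rejection_sample_batches(examples, accept_prob_fn, batch_size, seed=0):
    """Yield lists of `batch_size` accepted examples (illustrative only)."""
    rng = random.Random(seed)
    batch = []
    for example in examples:
        prob = accept_prob_fn(example)  # acceptance probability in [0, 1]
        if rng.random() < prob:         # accept with probability `prob`
            batch.append(example)
        if len(batch) == batch_size:    # enough accepted examples: emit a batch
            yield batch
            batch = []


# Hypothetical data: even numbers are always accepted, odd numbers never.
batches = list(rejection_sample_batches(
    range(10), lambda x: 1.0 if x % 2 == 0 else 0.0, batch_size=2))
print(batches)  # [[0, 2], [4, 6]]
```

The accepted example 8 never appears because it does not fill a complete batch. A partially filled batch is simply not emitted in this sketch.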
Args | |
---|---|
tensors | List of tensors for data. All tensors are either one item or a batch, according to enqueue_many. |
accept_prob_fn | A Python callable (for example, a lambda) that takes a list containing one non-batch tensor for each item in tensors, and returns a scalar tensor between 0 and 1: the acceptance probability. |
batch_size | Size of batch to be returned. |
queue_threads | The number of threads for the queue that will hold the final batch. |
enqueue_many | Bool. If true, interpret input tensors as having a batch dimension. |
prebatch_capacity | Capacity for the large queue that is used to convert batched tensors to single examples. |
prebatch_threads | Number of threads for the large queue that is used to convert batched tensors to single examples. |
runtime_checks | Bool. If true, insert runtime checks on the output of accept_prob_fn. Setting this to True may reduce performance. |
name | Optional prefix for ops created by this function. |
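
The effect of enqueue_many and the idea behind runtime_checks can be sketched in plain Python. This is an illustration under stated assumptions, not TensorFlow code: `to_single_examples` and `checked_prob` are hypothetical helpers, nested lists stand in for tensors, and the real function performs these steps with queues and graph ops.

```python
def to_single_examples(tensors, enqueue_many):
    """Return a list of per-example tuples (illustrative only)."""
    if enqueue_many:
        # Each input has a leading batch dimension; split along it.
        sizes = {len(t) for t in tensors}
        if len(sizes) != 1:
            raise ValueError("batch dimensions do not match")
        return list(zip(*tensors))
    # Each input is already a single example.
    return [tuple(tensors)]


def checked_prob(prob):
    """Mimic the idea of runtime_checks: reject values outside [0, 1]."""
    if not 0.0 <= prob <= 1.0:
        raise ValueError("acceptance probability must be in [0, 1]")
    return prob


# Hypothetical batch of three (data, label) examples.
data = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
labels = [0, 1, 0]
singles = to_single_examples([data, labels], enqueue_many=True)
print(len(singles))  # 3
```

With `enqueue_many=False`, the same inputs would be treated as one example whose first element is the whole `data` list.
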
Raises | |
---|---|
ValueError | enqueue_many is True and labels doesn't have a batch dimension, or if enqueue_many is False and labels isn't a scalar. |
ValueError | enqueue_many is True, and batch dimension on data and labels don't match. |
ValueError | if a zero initial probability class has a nonzero target probability. |
Returns | |
---|---|
A list of tensors of the same length as tensors, with batch dimension batch_size. |
Example:
Get tensor for a single data and label example.
data, label = data_provider.Get(['data', 'label'])
Get a rejection-sampled batch, with the acceptance probability computed from the data tensor.
accept_prob_fn = lambda x: (tf.tanh(x[0]) + 1) / 2
data_batch = tf.contrib.training.rejection_sample([data, label], accept_prob_fn, 16)
Run batch through network.
...
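
The acceptance function in the example, (tanh(x) + 1) / 2, maps any real input into the open interval (0, 1), so larger data values are accepted more often. A small sketch using the standard library in place of TensorFlow shows this for scalar inputs; `accept_prob` is a hypothetical stand-in for accept_prob_fn.

```python
import math


def accept_prob(x):
    """Scalar version of (tanh(x) + 1) / 2 from the example above."""
    return (math.tanh(x) + 1) / 2


for x in (-2.0, 0.0, 2.0):
    # Probabilities rise monotonically with x; x = 0.0 gives exactly 0.5.
    print(x, round(accept_prob(x), 3))
```

This expression is equal to the logistic sigmoid of 2x, which is why its output always qualifies as a probability.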
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/training/rejection_sample