tf.keras.layers.AbstractRNNCell

View source on GitHub

Abstract object representing an RNN cell.

Inherits From: Layer

This is the base class for implementing RNN cells with custom behavior.

Every RNNCell must have the properties below and implement call with the signature (output, next_state) = call(input, state).

Examples:

class MinimalRNNCell(AbstractRNNCell):

  def __init__(self, units, **kwargs):
    self.units = units
    super(MinimalRNNCell, self).__init__(**kwargs)

  @property
  def state_size(self):
    return self.units

  def build(self, input_shape):
    self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                  initializer='uniform',
                                  name='kernel')
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units),
        initializer='uniform',
        name='recurrent_kernel')
    self.built = True

  def call(self, inputs, states):
    prev_output = states[0]
    h = K.dot(inputs, self.kernel)
    output = h + K.dot(prev_output, self.recurrent_kernel)
    return output, output

This definition of cell differs from the definition used in the literature. In the literature, 'cell' refers to an object with a single scalar output. This definition refers to a horizontal array of such units.

An RNN cell, in the most abstract setting, is anything that has a state and performs some operation that takes a matrix of inputs. This operation results in an output matrix with self.output_size columns. If self.state_size is an integer, this operation also results in a new state matrix with self.state_size columns. If self.state_size is a (possibly nested tuple of) TensorShape object(s), then it should return a matching structure of Tensors having shape [batch_size].concatenate(s) for each s in self.batch_size.

Attributes
output_size Integer or TensorShape: size of outputs produced by this cell.
state_size size(s) of state(s) used by this cell.

It can be represented by an Integer, a TensorShape or a tuple of Integers or TensorShapes.

Methods

get_initial_state

View source

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/keras/layers/AbstractRNNCell