tf.keras.initializers.Initializer

Initializer base class: all Keras initializers inherit from this class.

Initializers should implement a __call__ method with the following signature:

def __call__(self, shape, dtype=None, **kwargs):
  # returns a tensor of shape `shape` and dtype `dtype`
  # containing values drawn from a distribution of your choice.

Optionally, you can also implement the method get_config and the class method from_config in order to support serialization, just like with any Keras object.

Here's a simple example: a random normal initializer.

import tensorflow as tf

class ExampleRandomNormal(tf.keras.initializers.Initializer):

  def __init__(self, mean, stddev):
    self.mean = mean
    self.stddev = stddev

  def __call__(self, shape, dtype=None, **kwargs):
    return tf.random.normal(
        shape, mean=self.mean, stddev=self.stddev, dtype=dtype)

  def get_config(self):  # To support serialization
    return {"mean": self.mean, "stddev": self.stddev}

Note that we don't have to implement from_config in the example above, since the constructor arguments of the class and the keys in the config returned by get_config are the same. In this case, the default from_config works fine.
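
To make the note above concrete, here is a minimal pure-Python sketch of the mechanism. It assumes the default from_config behaves like `cls(**config)`. The classes `SketchInitializer`, `MatchingKeys`, and `RenamedKeys` are hypothetical stand-ins written for illustration; they are not part of the Keras API. When the config keys differ from the constructor argument names, from_config has to be overridden to translate them.

```python
class SketchInitializer:
    """Hypothetical stand-in for the base class (not the Keras implementation)."""

    def get_config(self):
        return {}

    @classmethod
    def from_config(cls, config):
        # Assumed default behavior: pass config entries as constructor kwargs.
        return cls(**config)


class MatchingKeys(SketchInitializer):
    """Config keys equal the constructor argument names: no override needed."""

    def __init__(self, mean, stddev):
        self.mean = mean
        self.stddev = stddev

    def get_config(self):
        return {"mean": self.mean, "stddev": self.stddev}


class RenamedKeys(SketchInitializer):
    """Config uses the key 'sigma' while the constructor expects 'stddev'."""

    def __init__(self, mean, stddev):
        self.mean = mean
        self.stddev = stddev

    def get_config(self):
        return {"mean": self.mean, "sigma": self.stddev}

    @classmethod
    def from_config(cls, config):
        # Translate the config key back to the constructor argument name.
        return cls(mean=config["mean"], stddev=config["sigma"])


# Both round trips rebuild an equivalent object from its own config.
restored = MatchingKeys.from_config(MatchingKeys(0.0, 0.05).get_config())
renamed = RenamedKeys.from_config(RenamedKeys(1.0, 2.0).get_config())
```

Without the override, `RenamedKeys` would receive an unexpected `sigma` keyword argument and fail to construct.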

Methods

from_config

Instantiates an initializer from a configuration dictionary.

Example:

import tensorflow as tf

initializer = tf.keras.initializers.RandomUniform(-1, 1)
config = initializer.get_config()
initializer = tf.keras.initializers.RandomUniform.from_config(config)

Args
config: A Python dictionary, the output of get_config.

Returns
A tf.keras.initializers.Initializer instance.

get_config

Returns the configuration of the initializer as a JSON-serializable dict.

Returns
A JSON-serializable Python dict.
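
As a rough illustration of what "JSON-serializable" implies, the sketch below round-trips a config dictionary through Python's json module. The dictionary literal is a hypothetical config with the same keys as the ExampleRandomNormal example, not output captured from a real initializer. The point is that config values should be plain Python types (numbers, strings, lists, dicts, None) rather than arbitrary objects.

```python
import json

# Hypothetical config, shaped like what ExampleRandomNormal.get_config returns.
config = {"mean": 0.0, "stddev": 0.05}

# A JSON-serializable dict survives a dump/load round trip unchanged.
restored = json.loads(json.dumps(config))

# Values that are not plain Python types make json.dumps raise TypeError.
try:
    json.dumps({"mean": object()})
    serializable = True
except TypeError:
    serializable = False
```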

__call__

Returns a tensor object initialized as specified by the initializer.

Args
shape: Shape of the tensor.
dtype: Optional dtype of the tensor.
**kwargs: Additional keyword arguments.
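
A short usage sketch of calling an initializer directly, assuming TensorFlow 2.x is installed. The built-in RandomUniform is used here only as a convenient example, and the shape and bounds are arbitrary illustrative values. The dtype is passed explicitly rather than left as None.

```python
import tensorflow as tf

# Built-in initializer used purely as an example; the bounds are arbitrary.
initializer = tf.keras.initializers.RandomUniform(minval=-1.0, maxval=1.0)

# Calling the initializer returns a tensor with the requested shape and dtype.
values = initializer(shape=(2, 3), dtype=tf.float32)

# Rebuild the initializer from its config and call the copy the same way.
clone = tf.keras.initializers.RandomUniform.from_config(initializer.get_config())
clone_values = clone(shape=(2, 3), dtype=tf.float32)
```

In normal use the initializer is not called by hand: it is passed to a layer (for example through a `kernel_initializer` argument), and the layer calls it with the shape of the weight it needs to create.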

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.4/api_docs/python/tf/keras/initializers/Initializer