tf.contrib.eager.metrics.Metric

A metric holds state for aggregating statistics over an evaluation run.

Inherits From: CheckpointableBase

Example use with eager execution:

m = SomeMetric(...)
for input in ...:
  m(input)
print(m.result())

Example use with graph execution:

m = SomeMetric(...)
inputs = ... # Some tensors to compute the metric on.
m_update = m(inputs)
# Variables defined in first call, so get the initialization op afterwards.
m_init = m.init_variables()  # or tf.compat.v1.global_variables_initializer()
m_result = m.result()
with tf.compat.v1.Session() as sess:
  sess.run(m_init)
  for input in ...:
    sess.run(m_update)
  print(sess.run(m_result))

Example use with graph execution with placeholders and feed_dict:

m = SomeMetric(...)
m_placeholder = tf.compat.v1.placeholder(...)
m_update = m(m_placeholder)
# Variables defined in first call, so get the initialization op afterwards.
m_init = m.init_variables()  # or tf.compat.v1.global_variables_initializer()
m_result = m.result()
with tf.compat.v1.Session() as sess:
  sess.run(m_init)
  for input in ...:
    sess.run(m_update, feed_dict={m_placeholder: input})
  print(sess.run(m_result))

Descendants will implement:

  • build(): All variables should be created in this method by calling self.add_variable(), as in: self.var = self.add_variable(...). build() is called on the first invocation of __call__(), with the same arguments that are passed to call().
  • call(): Performs all updates to the variables, as in: self.var.assign_add(...)
  • result(): Computes and returns a final value for the metric from the variables in self.

Descendants may override aggregate(), but usually won't need to. It adds in the state from a list of metrics of the same type as self. (The default is to sum all the variables.) Note that users should not call aggregate() themselves; it is for use by TensorFlow infrastructure.
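The build()/call()/result() contract described above can be modelled in plain Python. This is only a sketch of the lifecycle, not TensorFlow code: ToyMetric and ToyMean are hypothetical names, the two-argument add_variable() here is a simplified stand-in, and state is kept in a dict rather than in TensorFlow variables.

```python
# Plain-Python stand-in for the Metric lifecycle. ToyMetric and ToyMean are
# hypothetical names; nothing here uses or reproduces the TensorFlow classes.


class ToyMetric:
    """Mimics the contract: build() runs once, call() runs on every invocation."""

    def __init__(self):
        self._built = False
        self.variables = {}

    def add_variable(self, name, initial_value):
        # Simplified stand-in for Metric.add_variable(): registers named state.
        self.variables[name] = initial_value
        return name

    def __call__(self, *args, **kwargs):
        if not self._built:
            # First invocation: create state, using the same arguments
            # that are about to be passed to call().
            self.build(*args, **kwargs)
            self._built = True
        return self.call(*args, **kwargs)

    def build(self, *args, **kwargs):
        raise NotImplementedError

    def call(self, *args, **kwargs):
        raise NotImplementedError

    def result(self):
        raise NotImplementedError


class ToyMean(ToyMetric):
    """Running mean over all values seen so far."""

    def build(self, values):
        # All state is created here, and only here.
        self.add_variable("numer", 0.0)
        self.add_variable("denom", 0.0)

    def call(self, values):
        # All updates to state happen here.
        self.variables["numer"] += sum(values)
        self.variables["denom"] += len(values)

    def result(self):
        # The final value is derived from the accumulated state.
        return self.variables["numer"] / self.variables["denom"]


m = ToyMean()
for batch in ([1.0, 2.0], [3.0], [4.0, 5.0, 6.0]):
    m(batch)
print(m.result())
```

Keeping variable creation in build() and updates in call() is what lets the first invocation double as the point where the state is defined.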

Attributes
name
variables

Methods

add_variable

Only for use by descendants of Metric.

aggregate

Adds in the state from a list of metrics.

Default implementation sums all the metric variables.

Args
metrics A list of metrics with the same type as self.
Raises
ValueError If metrics contains invalid data.
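
To illustrate why summing is a sensible default, the following plain-Python sketch merges the state of two copies of a mean-style metric. The dict layout, the variable names numer and denom, and the helper aggregate_by_sum are assumptions made for this illustration and are not part of the TensorFlow API.

```python
# Hypothetical per-copy state for a mean-style metric. Each dict stands in
# for the variables of one metric instance; this is not TensorFlow code.
replica_states = [
    {"numer": 6.0, "denom": 3.0},   # copy 0 saw 3 values summing to 6
    {"numer": 15.0, "denom": 2.0},  # copy 1 saw 2 values summing to 15
]


def aggregate_by_sum(states):
    """Sums each named variable across all states."""
    total = {}
    for state in states:
        for name, value in state.items():
            total[name] = total.get(name, 0.0) + value
    return total


merged = aggregate_by_sum(replica_states)
overall_mean = merged["numer"] / merged["denom"]
print(merged, overall_mean)
```

Summing works here because both the numerator and the denominator are additive across copies. A metric whose state is not additive, such as a running maximum, is the kind of case where overriding aggregate() would be needed.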

build

Method to create variables.

Called by __call__() before call() for the first time.

Args
*args, **kwargs The arguments to the first invocation of __call__(). build() may use the shape and/or dtype of these arguments when deciding how to create variables.
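
As a rough illustration of build() sizing its state from the first input, the plain-Python sketch below allocates one accumulator per column of the first row it sees. ToyPerColumnSum is a hypothetical class that does not use TensorFlow, and the length of a list stands in for a tensor shape.

```python
class ToyPerColumnSum:
    """Hypothetical stand-in: sizes its state from the first input's length."""

    def __init__(self):
        self._built = False

    def build(self, row):
        # The "shape" of the first input decides how much state to create.
        self.totals = [0.0] * len(row)

    def call(self, row):
        if len(row) != len(self.totals):
            raise ValueError("row length differs from the first row seen")
        for i, value in enumerate(row):
            self.totals[i] += value

    def __call__(self, row):
        if not self._built:
            # build() receives the arguments of the first invocation.
            self.build(row)
            self._built = True
        self.call(row)

    def result(self):
        return list(self.totals)


m = ToyPerColumnSum()
for row in ([1.0, 2.0, 3.0], [10.0, 20.0, 30.0]):
    m(row)
column_totals = m.result()
print(column_totals)
```

Because the state is sized once, later inputs with a different length are rejected in this sketch rather than silently resizing the accumulators.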

call

Accumulates statistics for the metric. Users should use __call__() instead.

Note: This function is executed as a graph function in graph mode. This means:
  a) Operations on the same resource are executed in textual order. This should make it easier to do things like add the updated value of a variable to another, for example.
  b) You don't need to worry about collecting the update ops to execute. All update ops added to the graph by this function will be executed.
As a result, code should generally work the same way with graph or eager execution.

Args
*args, **kwargs A mini-batch of inputs to the Metric, as passed to __call__().

init_variables

Initializes this Metric's variables.

Should be called after variables are created in the first execution of __call__(). If using graph execution, the returned op should be run in a session before running the op returned by __call__(). (See the graph execution examples above.)

Returns
If using graph execution, this returns an op to perform the initialization. Under eager execution, the variables are reset to their initial values as a side effect and this function returns None.

result

Computes and returns a final value for the metric.

value

In graph mode, returns the result Tensor; under eager execution, returns the callable.

__call__

Returns op to execute to update this metric for these inputs.

Returns None if eager execution is enabled. Returns a graph-mode function if graph execution is enabled.

Args
*args, **kwargs A mini-batch of inputs to the Metric, passed on to call().

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/eager/metrics/Metric