torch.onnx

Example: End-to-end AlexNet from PyTorch to ONNX

Here is a simple script which exports a pretrained AlexNet as defined in torchvision into ONNX. It runs a single round of inference and then saves the resulting traced model to alexnet.onnx:

import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
model = torchvision.models.alexnet(pretrained=True).cuda()

# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)

The resulting alexnet.onnx is a binary protobuf file which contains both the network structure and parameters of the model you exported (in this case, AlexNet). The keyword argument verbose=True causes the exporter to print out a human-readable representation of the network:

# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
      %learned_0 : Float(64, 3, 11, 11)
      %learned_1 : Float(64)
      %learned_2 : Float(192, 64, 5, 5)
      %learned_3 : Float(192)
      # ---- omitted for brevity ----
      %learned_14 : Float(1000, 4096)
      %learned_15 : Float(1000)) {
  # Every statement consists of some output tensors (and their types),
  # the operator to be run (with its attributes, e.g., kernels, strides,
  # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
  %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
  %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
  %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
  # ---- omitted for brevity ----
  %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
  # Dynamic means that the shape is not known. This may be because of a
  # limitation of our implementation (which we would like to fix in a
  # future release) or shapes which are truly dynamic.
  %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
  %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
  %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
  %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
  # ---- omitted for brevity ----
  %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
  return (%output1);
}
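The shapes printed in the graph can be checked by hand. The sketch below is not part of the exporter: it assumes the usual floor-mode output-size formula for convolution and pooling layers, and the helper name `out_size` is made up for illustration.

```python
def out_size(size, kernel, stride, pad, dilation=1):
    """Spatial output size of a conv/pool layer (floor mode, assumed formula)."""
    return (size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

# First convolution: 224x224 input, 11x11 kernel, stride 4, padding 2.
conv1 = out_size(224, kernel=11, stride=4, pad=2)
# First max pool: 3x3 kernel, stride 2, no padding.
pool1 = out_size(conv1, kernel=3, stride=2, pad=0)
# Flattened feature size after the last pool: 256 channels of 6x6.
flat = 256 * 6 * 6

print(conv1, pool1, flat)  # 55 27 9216
```

These values match the Float(10, 64, 55, 55) and Float(10, 64, 27, 27) types in the listing, and the constant 9216 that appears before the classifier.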

You can also verify the protobuf using the ONNX library. You can install ONNX with conda:

conda install -c conda-forge onnx

Then, you can run:

import onnx

# Load the ONNX model
model = onnx.load("alexnet.onnx")

# Check that the IR is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

To run the exported model with Caffe2, you will need to install Caffe2. If you do not have it already, please follow the install instructions.

Once these are installed, you can use the backend for Caffe2:

# ...continuing from above
import caffe2.python.onnx.backend as backend
import numpy as np

rep = backend.prepare(model, device="CUDA:0") # or "CPU"
# For the Caffe2 backend:
#     rep.predict_net is the Caffe2 protobuf for the network
#     rep.workspace is the Caffe2 workspace for the network
#       (see the class caffe2.python.onnx.backend.Workspace)
outputs = rep.run(np.random.randn(10, 3, 224, 224).astype(np.float32))
# To run networks with more than one input, pass a tuple
# rather than a single numpy ndarray.
print(outputs[0])

You can also run the exported model with ONNX Runtime. To do so, you will need to install ONNX Runtime: please follow these instructions.

Once it is installed, you can use the backend for ONNX Runtime:

# ...continuing from above
import numpy as np
import onnxruntime as ort

ort_session = ort.InferenceSession('alexnet.onnx')

outputs = ort_session.run(None, {'actual_input_1': np.random.randn(10, 3, 224, 224).astype(np.float32)})

print(outputs[0])

Here is another tutorial on exporting the SuperResolution model to ONNX.

In the future, there will be backends for other frameworks as well.

Tracing vs Scripting

The ONNX exporter can work as either a trace-based or a script-based exporter.

  • trace-based means that it operates by executing your model once and exporting the operators which were actually run during this run. This means that if your model is dynamic, e.g., changes behavior depending on input data, the export won’t be accurate. Similarly, a trace is likely to be valid only for a specific input size (which is one reason why we require explicit inputs on tracing). We recommend examining the model trace and making sure the traced operators look reasonable. If your model contains control flow such as for loops and if conditions, the trace-based exporter will unroll the loops and if conditions, exporting a static graph that matches exactly what was executed during this run. If you want to export your model with dynamic control flow, you will need to use the script-based exporter.
  • script-based means that the model you are trying to export is a ScriptModule. ScriptModule is the core data structure in TorchScript, and TorchScript is a subset of the Python language that creates serializable and optimizable models from PyTorch code.

We allow mixing tracing and scripting. You can compose tracing and scripting to suit the particular requirements of a part of a model. Check out this example:

import torch

# Trace-based only

class LoopModel(torch.nn.Module):
    def forward(self, x, y):
        for i in range(y):
            x = x + i
        return x

model = LoopModel()
dummy_input = torch.ones(2, 3, dtype=torch.long)
loop_count = torch.tensor(5, dtype=torch.long)

torch.onnx.export(model, (dummy_input, loop_count), 'loop.onnx', verbose=True)

With the trace-based exporter, the resulting ONNX graph unrolls the for loop:

graph(%0 : Long(2, 3),
      %1 : Long()):
  %2 : Tensor = onnx::Constant[value={1}]()
  %3 : Tensor = onnx::Add(%0, %2)
  %4 : Tensor = onnx::Constant[value={2}]()
  %5 : Tensor = onnx::Add(%3, %4)
  %6 : Tensor = onnx::Constant[value={3}]()
  %7 : Tensor = onnx::Add(%5, %6)
  %8 : Tensor = onnx::Constant[value={4}]()
  %9 : Tensor = onnx::Add(%7, %8)
  return (%9)
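Why tracing bakes in the loop count can be seen with a toy recorder. The following sketch is purely illustrative and does not use the real tracer: the `Traced` class and its log format are invented here, and the loop bound is passed as a plain Python int.

```python
class Traced:
    """Toy stand-in for a traced value: it logs every addition applied to it."""
    def __init__(self, name, log):
        self.name, self.log = name, log

    def __add__(self, other):
        out = Traced("%" + str(len(self.log) + 1), self.log)
        self.log.append("%s = Add(%s, Constant[%s])" % (out.name, self.name, other))
        return out

def forward(x, y):
    # Same structure as LoopModel.forward: the loop itself is ordinary Python.
    for i in range(y):
        x = x + i
    return x

log = []
forward(Traced("%0", log), 5)  # "trace" once with a loop count of 5
print("\n".join(log))
```

The log contains one Add per executed iteration and no reference to the loop count, so replaying it always performs the same five additions regardless of what count is supplied later. Unlike the exported graph above, this toy recorder keeps the `+ 0` step of the first iteration.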

To use the script-based exporter to capture the dynamic loop, we can write the loop in script and call it from the regular nn.Module:

# Mixing tracing and scripting

@torch.jit.script
def loop(x, y):
    for i in range(int(y)):
        x = x + i
    return x

class LoopModel2(torch.nn.Module):
    def forward(self, x, y):
        return loop(x, y)

model = LoopModel2()
dummy_input = torch.ones(2, 3, dtype=torch.long)
loop_count = torch.tensor(5, dtype=torch.long)
torch.onnx.export(model, (dummy_input, loop_count), 'loop.onnx', verbose=True,
                  input_names=['input_data', 'loop_range'])

Now the exported ONNX graph becomes:

graph(%input_data : Long(2, 3),
      %loop_range : Long()):
  %2 : Long() = onnx::Constant[value={1}](), scope: LoopModel2/loop
  %3 : Tensor = onnx::Cast[to=9](%2)
  %4 : Long(2, 3) = onnx::Loop(%loop_range, %3, %input_data), scope: LoopModel2/loop # custom_loop.py:240:5
    block0(%i.1 : Long(), %cond : bool, %x.6 : Long(2, 3)):
      %8 : Long(2, 3) = onnx::Add(%x.6, %i.1), scope: LoopModel2/loop # custom_loop.py:241:13
      %9 : Tensor = onnx::Cast[to=9](%2)
      -> (%9, %8)
  return (%4)

The dynamic control flow is captured correctly. We can verify this in the backends with a different loop range:

import caffe2.python.onnx.backend as backend
import numpy as np
import onnx
model = onnx.load('loop.onnx')

rep = backend.prepare(model)
outputs = rep.run((dummy_input.numpy(), np.array(9).astype(np.int64)))
print(outputs[0])
#[[37 37 37]
# [37 37 37]]


import onnxruntime as ort
ort_sess = ort.InferenceSession('loop.onnx')
outputs = ort_sess.run(None, {'input_data': dummy_input.numpy(),
                              'loop_range': np.array(9).astype(np.int64)})
print(outputs)
#[array([[37, 37, 37],
#       [37, 37, 37]], dtype=int64)]
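The expected value 37 can be confirmed without any backend. The snippet below is a plain-Python reference for one element of the tensor; `loop_reference` is a name made up for this check.

```python
def loop_reference(x, y):
    """Plain-Python version of the scripted `loop` function, applied to one element."""
    for i in range(int(y)):
        x = x + i
    return x

print(loop_reference(1, 9))  # 37, i.e. 1 + (0 + 1 + ... + 8)
print(loop_reference(1, 5))  # 11, the value for the loop count used at export time
```

A backend result of 37 for a loop range of 9 therefore shows that the loop count is read from the input at run time rather than fixed to the value 5 used during export.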

To avoid exporting a variable scalar tensor as a fixed-value constant in the ONNX model, avoid using torch.Tensor.item(). Torch supports implicit casting of single-element tensors to numbers. E.g.:

class LoopModel(torch.nn.Module):
    def forward(self, x, y):
        res = []
        arr = x.split(2, 0)
        for i in range(int(y)):
            res += [arr[i].sum(0, False)]
        return torch.stack(res)

model = torch.jit.script(LoopModel())
inputs = (torch.randn(16), torch.tensor(8))

out = model(*inputs)
torch.onnx.export(model, inputs, 'loop_and_list.onnx', opset_version=11, example_outputs=out)

Write PyTorch models the Torch way

PyTorch models can be written using numpy manipulations, but this is not appropriate when converting to ONNX. The trace-based exporter treats numpy values as constant nodes, so the exported model computes the wrong result if the input changes. The PyTorch model therefore needs to be implemented using torch operators. For example, do not use numpy operators on numpy tensors:

np.concatenate((x, y, z), axis=1)

Do not convert to numpy types:

y = x.astype(np.int)

Always use torch tensors and torch operators: torch.cat, etc. In addition, a Dropout layer needs to be defined in the __init__ function so that inference can handle it properly, i.e.,

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.dropout(x)
        return x

Using dictionaries to handle Named Arguments as model inputs

There are two ways to handle models that take named parameters or keyword arguments as inputs:

  • The first method is to pass all the inputs in the same order as required by the model, and pass None for the keyword arguments that do not require a value.
  • The second and more intuitive method is to represent the keyword arguments as key-value pairs, where the key is the name of the argument in the model signature and the value is the value of the argument to be passed.

For example, in the model:

class Model(torch.nn.Module):
  def forward(self, x, y=None, z=None):
    if y is not None:
      return x + y
    if z is not None:
      return x + z
    return x
m = Model()
x = torch.randn(2, 3)
z = torch.randn(2, 3)

There are two ways of exporting the model:

  • Not using a dictionary for the keyword arguments and passing all the inputs in the same order as required by the model

    torch.onnx.export(m, (x, None, z), 'test.onnx')
    
  • Using a dictionary to represent the keyword arguments. This dictionary is always passed in addition to the non-keyword arguments and is always the last argument in the args tuple.

    torch.onnx.export(m, (x, {'y': None, 'z': z}), 'test.onnx')
    

For cases in which there are no keyword arguments, models can be exported with either an empty or no dictionary. For example,

torch.onnx.export(m, (x, {}), 'test.onnx')
or
torch.onnx.export(m, (x, ), 'test.onnx')

An exception to this rule is the case in which the last input is also of a dictionary type. In this case it is mandatory to have an empty dictionary as the last argument in the args tuple. For example,

class Model(torch.nn.Module):
  def forward(self, k, x):
    ...
    return x
m = Model()
k = torch.randn(2, 3)
x = {torch.tensor(1.): torch.randn(2, 3)}

Without the empty dictionary, the export call assumes that the 'x' input is intended to represent the optional dictionary of named arguments. To prevent this, an empty dictionary must be provided as the last input in the args tuple in such cases. The new call would look like this:

torch.onnx.export(m, (k, x, {}), 'test.onnx')
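The trailing-dictionary rule can be summarised in a few lines of plain Python. The helper `split_export_args` below is hypothetical: it only mirrors the convention described in this section and is not how the exporter is implemented. Strings stand in for tensors.

```python
def split_export_args(args):
    """Hypothetical helper mirroring the rule above: a trailing dict in `args`
    is read as keyword arguments, everything before it as positional inputs."""
    if not isinstance(args, tuple):
        args = (args,)
    if args and isinstance(args[-1], dict):
        return args[:-1], args[-1]
    return args, {}

x, z, k = "x", "z", "k"        # placeholders standing in for tensors
dict_input = {"key": "value"}  # a model input that is itself a dictionary

print(split_export_args((x, {"y": None, "z": z})))  # (('x',), {'y': None, 'z': 'z'})
print(split_export_args((x,)))                      # (('x',), {})
print(split_export_args((k, dict_input)))           # (('k',), {'key': 'value'})
print(split_export_args((k, dict_input, {})))       # (('k', {'key': 'value'}), {})
```

The third call shows the ambiguity: the dictionary input is consumed as keyword arguments. Appending an empty dictionary, as in the fourth call, keeps it among the positional inputs.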

Indexing

Tensor indexing in PyTorch is very flexible and complicated. There are two categories of indexing. Both are largely supported in exporting today. If you are experiencing issues exporting indexing that belongs to the supported patterns below, please double check that you are exporting with the latest opset (opset_version=12).

Getter

This type of indexing occurs on the RHS. Export is supported for ONNX opset version >= 9. E.g.:

data = torch.randn(3, 4)
index = torch.tensor([1, 2])

# RHS indexing is supported in ONNX opset >= 9.
class RHSIndexing(torch.nn.Module):
    def forward(self, data, index):
        return data[index]

out = RHSIndexing()(data, index)

torch.onnx.export(RHSIndexing(), (data, index), 'indexing.onnx', opset_version=9)

# onnxruntime
import onnxruntime
sess = onnxruntime.InferenceSession('indexing.onnx')
out_ort = sess.run(None, {
    sess.get_inputs()[0].name: data.numpy(),
    sess.get_inputs()[1].name: index.numpy(),
})

assert torch.all(torch.eq(out, torch.tensor(out_ort)))

Below is the list of supported patterns for RHS indexing.

# Scalar indices
data[0, 1]

# Slice indices
data[:3]

# Tensor indices
data[torch.tensor([[1, 2], [2, 3]])]
data[torch.tensor([2, 3]), torch.tensor([1, 2])]
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])]
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])]

# Ellipsis
# Not supported in scripting
# i.e. torch.jit.script(model) will fail if model contains this pattern.
# Export is supported under tracing
# i.e. torch.onnx.export(model)
data[...]

# The combination of above
data[2, ..., torch.tensor([2, 1, 3]), 2:4, torch.tensor([[1], [2]])]

# Boolean mask (supported for ONNX opset version >= 11)
data[data != 1]

And below is the list of unsupported patterns for RHS indexing.

# Tensor indices that include negative values.
data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]

Setter

In code, this type of indexing occurs on the LHS. Export is supported for ONNX opset version >= 11. E.g.:

data = torch.zeros(3, 4)
new_data = torch.arange(4).to(torch.float32)

# LHS indexing is supported in ONNX opset >= 11.
class LHSIndexing(torch.nn.Module):
    def forward(self, data, new_data):
        data[1] = new_data
        return data

out = LHSIndexing()(data, new_data)

data = torch.zeros(3, 4)
new_data = torch.arange(4).to(torch.float32)
torch.onnx.export(LHSIndexing(), (data, new_data), 'inplace_assign.onnx', opset_version=11)

# onnxruntime
import onnxruntime
sess = onnxruntime.InferenceSession('inplace_assign.onnx')
out_ort = sess.run(None, {
    sess.get_inputs()[0].name: torch.zeros(3, 4).numpy(),
    sess.get_inputs()[1].name: new_data.numpy(),
})

assert torch.all(torch.eq(out, torch.tensor(out_ort)))

Below is the list of supported patterns for LHS indexing.

# Scalar indices
data[0, 1] = new_data

# Slice indices
data[:3] = new_data

# Tensor indices
# If more than one tensor is used as an index, only consecutive 1-d tensor indices are supported.
data[torch.tensor([[1, 2], [2, 3]])] = new_data
data[torch.tensor([2, 3]), torch.tensor([1, 2])] = new_data

# Ellipsis
# Not supported to export in script modules
# i.e. torch.onnx.export(torch.jit.script(model)) will fail if model contains this pattern.
# Export is supported under tracing
# i.e. torch.onnx.export(model)
data[...] = new_data

# The combination of above
data[2, ..., torch.tensor([2, 1, 3]), 2:4] += update

# Boolean mask
data[data != 1] = new_data

And below is the list of unsupported patterns for LHS indexing.

# Multiple tensor indices if any has rank >= 2
data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data

# Multiple tensor indices that are not consecutive
data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data

# Tensor indices that include negative values.
data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data

If you are experiencing issues exporting indexing that belongs to the above supported patterns, please double check that you are exporting with the latest opset (opset_version=12).

TorchVision support

All TorchVision models, except for quantized versions, are exportable to ONNX. More details can be found in TorchVision.

Limitations

  • Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. Users need to verify their dict inputs carefully, and keep in mind that dynamic lookups are not available.
  • PyTorch and ONNX backends (Caffe2, ONNX Runtime, etc.) often have implementations of operators with some numeric differences. Depending on model structure, these differences may be negligible, but they can also cause major divergences in behavior (especially on untrained models). We allow Caffe2 to call directly to Torch implementations of operators, to help you smooth over these differences when precision is important, and to also document these differences.

Supported operators

The following operators are supported:

  • BatchNorm
  • ConstantPadNd
  • Conv
  • Dropout
  • Embedding (no optional arguments supported)
  • EmbeddingBag
  • FeatureDropout (training mode not supported)
  • Index
  • MaxPool1d
  • MaxPool2d
  • MaxPool3d
  • RNN
  • abs
  • absolute
  • acos
  • adaptive_avg_pool1d
  • adaptive_avg_pool2d
  • adaptive_avg_pool3d
  • adaptive_max_pool1d
  • adaptive_max_pool2d
  • adaptive_max_pool3d
  • add (nonzero alpha not supported)
  • addmm
  • and
  • arange
  • argmax
  • argmin
  • asin
  • atan
  • avg_pool1d
  • avg_pool2d
  • avg_pool3d
  • as_strided
  • baddbmm
  • bitshift
  • cat
  • ceil
  • celu
  • clamp
  • clamp_max
  • clamp_min
  • concat
  • copy
  • cos
  • cumsum
  • det
  • dim_arange
  • div
  • dropout
  • einsum
  • elu
  • empty
  • empty_like
  • eq
  • erf
  • exp
  • expand
  • expand_as
  • eye
  • flatten
  • floor
  • floor_divide
  • frobenius_norm
  • full
  • full_like
  • gather
  • ge
  • gelu
  • glu
  • group_norm
  • gt
  • hardswish
  • hardtanh
  • im2col
  • index_copy
  • index_fill
  • index_put
  • index_select
  • instance_norm
  • interpolate
  • isnan
  • KLDivLoss
  • layer_norm
  • le
  • leaky_relu
  • len
  • log
  • log1p
  • log2
  • log_sigmoid
  • log_softmax
  • logdet
  • logsumexp
  • lt
  • masked_fill
  • masked_scatter
  • masked_select
  • max
  • mean
  • min
  • mm
  • mul
  • multinomial
  • narrow
  • ne
  • neg
  • new_empty
  • new_full
  • new_zeros
  • nll_loss
  • nonzero
  • norm
  • ones
  • ones_like
  • or
  • permute
  • pixel_shuffle
  • pow
  • prelu (single weight shared among input channels not supported)
  • prod
  • rand
  • randn
  • randn_like
  • reciprocal
  • reflection_pad
  • relu
  • repeat
  • replication_pad
  • reshape
  • reshape_as
  • round
  • rrelu
  • rsqrt
  • rsub
  • scalar_tensor
  • scatter
  • scatter_add
  • select
  • selu
  • sigmoid
  • sign
  • sin
  • size
  • slice
  • softmax
  • softplus
  • sort
  • split
  • sqrt
  • squeeze
  • stack
  • std
  • sub (nonzero alpha not supported)
  • sum
  • t
  • tan
  • tanh
  • threshold (non-zero threshold/non-zero value not supported)
  • to
  • topk
  • transpose
  • true_divide
  • type_as
  • unbind
  • unfold (experimental support with ATen-Caffe2 integration)
  • unique
  • unsqueeze
  • upsample_nearest1d
  • upsample_nearest2d
  • upsample_nearest3d
  • view
  • weight_norm
  • where
  • zeros
  • zeros_like

The operator set above is sufficient to export the following models:

  • AlexNet
  • DCGAN
  • DenseNet
  • Inception (warning: this model is highly sensitive to changes in operator implementation)
  • ResNet
  • SuperResolution
  • VGG
  • word_language_model

Adding support for operators

Adding export support for operators is an advanced usage.

To achieve this, developers need to touch the source code of PyTorch. Please follow the instructions for installing PyTorch from source. If the wanted operator is standardized in ONNX, it should be easy to add support for exporting it (by adding a symbolic function for the operator). To confirm whether the operator is standardized, please check the ONNX operator list.

ATen operators

If the operator is an ATen operator, which means you can find the declaration of the function in torch/csrc/autograd/generated/VariableType.h (available in generated code in PyTorch install dir), you should add the symbolic function in torch/onnx/symbolic_opset<version>.py and follow the instructions listed below:

  • Define the symbolic function in torch/onnx/symbolic_opset<version>.py, for example torch/onnx/symbolic_opset9.py. Make sure the function has the same name as the ATen operator/function defined in VariableType.h.
  • The first parameter is always the exported ONNX graph. Parameter names must EXACTLY match the names in VariableType.h, because dispatch is done with keyword arguments.
  • Parameter ordering does NOT necessarily match what is in VariableType.h: tensors (inputs) are always first, followed by non-tensor arguments.
  • In the symbolic function, if the operator is already standardized in ONNX, we only need to create a node to represent the ONNX operator in the graph.
  • If the input argument is a tensor, but ONNX asks for a scalar, we have to explicitly do the conversion. The helper function _scalar can convert a scalar tensor into a Python scalar, and _if_scalar_type_as can turn a Python scalar into a PyTorch tensor.

Non-ATen operators

If the operator is a non-ATen operator, the symbolic function has to be added in the corresponding PyTorch Function class. Please read the following instructions:

  • Create a symbolic function named symbolic in the corresponding Function class.
  • The first parameter is always the exported ONNX graph.
  • Parameter names except the first must EXACTLY match the names in forward.
  • The output tuple size must match the outputs of forward.
  • In the symbolic function, if the operator is already standardized in ONNX, we just need to create a node to represent the ONNX operator in the graph.

Symbolic functions should be implemented in Python. All of these functions interact with Python methods which are implemented via C++-Python bindings, but intuitively the interface they provide looks like this:

def operator/symbolic(g, *inputs):
  """
  Modifies Graph (e.g., using "op"), adding the ONNX operations representing
  this PyTorch function, and returning a Value or tuple of Values specifying the
  ONNX outputs whose values correspond to the original PyTorch return values
  of the autograd Function (or None if an output is not supported by ONNX).

  Args:
    g (Graph): graph to write the ONNX representation into
    inputs (Value...): list of values representing the variables which contain
        the inputs for this function
  """

class Value(object):
  """Represents an intermediate tensor value computed in ONNX."""
  def type(self):
    """Returns the Type of the value."""

class Type(object):
  def sizes(self):
    """Returns a tuple of ints representing the shape of a tensor this describes."""

class Graph(object):
  def op(self, opname, *args, **kwargs):
    """
    Create an ONNX operator 'opname', taking 'args' as inputs
    and attributes 'kwargs' and add it as a node to the current graph,
    returning the value representing the single output of this
    operator (see the `outputs` keyword argument for multi-return
    nodes).

    The set of operators and the inputs/attributes they take
    is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

    Args:
        opname (string): The ONNX operator name, e.g., `Abs` or `Add`.
        args (Value...): The inputs to the operator; usually provided
            as arguments to the `symbolic` definition.
        kwargs: The attributes of the ONNX operator, with keys named
            according to the following convention: `alpha_f` indicates
            the `alpha` attribute with type `f`.  The valid type specifiers are
            `f` (float), `i` (int), `s` (string) or `t` (Tensor).  An attribute
            specified with type float accepts either a single float, or a
            list of floats (e.g., you would say `dims_i` for a `dims` attribute
            that takes a list of integers).
        outputs (int, optional):  The number of outputs this operator returns;
            by default an operator is assumed to return a single output.
            If `outputs` is greater than one, this function returns a tuple
            of output `Value`s, representing each output of the ONNX operator
            in positional order.
    """

The ONNX graph C++ definition is in torch/csrc/jit/ir/ir.h.
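The attribute naming convention from the `op` docstring (`alpha_f` means attribute `alpha` with type `f`) can be illustrated with a small parser. This is a sketch only: `split_attr` and `ATTR_TYPES` are names invented here and are not part of torch.onnx.

```python
# Type specifiers listed in the docstring above.
ATTR_TYPES = {"f": "float", "i": "int", "s": "string", "t": "Tensor"}

def split_attr(key):
    """Split an attribute keyword such as 'alpha_f' into (name, type)."""
    name, _, suffix = key.rpartition("_")
    if not name or suffix not in ATTR_TYPES:
        raise ValueError("%r does not end in one of _f, _i, _s, _t" % key)
    return name, ATTR_TYPES[suffix]

print(split_attr("alpha_f"))         # ('alpha', 'float')
print(split_attr("dims_i"))          # ('dims', 'int')
print(split_attr("kernel_shape_i"))  # ('kernel_shape', 'int')
```

Only the last underscore separates the name from the type specifier, so attribute names that themselves contain underscores, such as `kernel_shape`, are handled as expected in this sketch.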

Here is an example of handling a missing symbolic function for the elu operator. We try to export the model and see the following error message:

UserWarning: ONNX export failed on elu because torch.onnx.symbolic_opset9.elu does not exist
RuntimeError: ONNX export failed: Couldn't export operator elu

The export fails because PyTorch does not support exporting the elu operator. We find virtual Tensor elu(const Tensor & input, Scalar alpha, bool inplace) const override; in VariableType.h. This means elu is an ATen operator. We check the ONNX operator list and confirm that Elu is standardized in ONNX. We add the following lines to symbolic_opset9.py:

def elu(g, input, alpha, inplace=False):
    return g.op("Elu", input, alpha_f=_scalar(alpha))

Now PyTorch is able to export the elu operator.

There are more examples in symbolic_opset9.py and symbolic_opset10.py.

The interface for specifying operator definitions is experimental; adventurous users should note that the APIs will probably change in a future release.

Custom operators

Following the tutorial Extending TorchScript with Custom C++ Operators, you can create and register your own custom op implementations in PyTorch. Here is how to export such a model to ONNX:

# Create custom symbolic function
from torch.onnx.symbolic_helper import parse_args
@parse_args('v', 'v', 'f', 'i')
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    return g.op("Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)

# Register custom symbolic function
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic('custom_ops::foo_forward', symbolic_foo_forward, 9)

class FooModel(torch.nn.Module):
    def __init__(self, attr1, attr2):
        super(FooModel, self).__init__()
        self.attr1 = attr1
        self.attr2 = attr2

    def forward(self, input1, input2):
        # Calling custom op
        return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)

model = FooModel(attr1, attr2)
torch.onnx.export(model, (dummy_input1, dummy_input2), 'model.onnx', custom_opsets={"custom_domain": 2})

Depending on the custom operator, you can export it as one or a combination of existing ONNX ops. You can also export it as a custom op in ONNX. In that case, you can specify the custom domain and version (custom opset) using the custom_opsets dictionary at export. If not explicitly specified, the custom opset version is set to 1 by default. When using custom ONNX ops, you will need to extend the backend of your choice with a matching custom op implementation, e.g. Caffe2 custom ops, ONNX Runtime custom ops.

Operator Export Type

Exporting models with unsupported ONNX operators can be achieved using the operator_export_type flag in the export API. This flag is useful when users try to export ATen and non-ATen operators that are not registered and supported in ONNX.

ONNX

This mode is used to export all operators as regular ONNX operators. This is the default operator_export_type mode.

Example torch ir graph:

  graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])):
    %3 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::exp(%0)
    %4 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::div(%0, %3)
    return (%4)

Is exported as:

  graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])):
    %1 : Float(2, 3, 4, strides=[12, 4, 1]) = onnx::Exp(%0)
    %2 : Float(2, 3, 4, strides=[12, 4, 1]) = onnx::Div(%0, %1)
    return (%2)

ONNX_ATEN

This mode is used to export all operators as ATen ops, and avoid conversion to ONNX.

Example torch ir graph:

  graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])):
    %3 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::exp(%0)
    %4 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::div(%0, %3)
    return (%4)

Is exported as:

  graph(%0 : Float(2, 3, 4, strides=[12, 4, 1])):
    %1 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::ATen[operator="exp"](%0)
    %2 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::ATen[operator="div"](%0, %1)
    return (%2)

ONNX_ATEN_FALLBACK

This mode falls back to ATen for operators that are not supported in ONNX. Supported operators are exported to ONNX as usual. In the following example, aten::triu is not supported in ONNX, so the exporter falls back to ATen for this operator.

Example torch ir graph:

  graph(%0 : Float):
    %3 : int = prim::Constant[value=0]()
    %4 : Float = aten::triu(%0, %3) # unsupported op
    %5 : Float = aten::mul(%4, %0) # registered op
    return (%5)

is exported as:

  graph(%0 : Float):
    %1 : Long() = onnx::Constant[value={0}]()
    %2 : Float = aten::ATen[operator="triu"](%0, %1) # unsupported op
    %3 : Float = onnx::Mul(%2, %0) # registered op
    return (%3)

RAW

This mode exports the raw IR unchanged.

Example torch ir graph:

  graph(%x.1 : Float(1, strides=[1])):
    %1 : Tensor = aten::exp(%x.1)
    %2 : Tensor = aten::div(%x.1, %1)
    %y.1 : Tensor[] = prim::ListConstruct(%2)
    return (%y.1)

is exported as:

  graph(%x.1 : Float(1, strides=[1])):
    %1 : Tensor = aten::exp(%x.1)
    %2 : Tensor = aten::div(%x.1, %1)
    %y.1 : Tensor[] = prim::ListConstruct(%2)
    return (%y.1)

ONNX_FALLTHROUGH

This mode can be used to export any operator (ATen or non-ATen) that is not registered and supported in ONNX. The exporter falls through and exports the operator as is, as a custom op. Exporting custom operators enables users to register and implement the operator as part of their runtime backend.

Example torch ir graph:

  graph(%0 : Float(2, 3, 4, strides=[12, 4, 1]),
        %1 : Float(2, 3, 4, strides=[12, 4, 1])):
    %6 : Float(2, 3, 4, strides=[12, 4, 1]) = foo_namespace::bar(%0, %1) # custom op
    %7 : Float(2, 3, 4, strides=[12, 4, 1]) = aten::div(%6, %0) # registered op
    return (%7)

is exported as:

  graph(%0 : Float(2, 3, 4, strides=[12, 4, 1]),
        %1 : Float(2, 3, 4, strides=[12, 4, 1])):
    %2 : Float(2, 3, 4, strides=[12, 4, 1]) = foo_namespace::bar(%0, %1) # custom op
    %3 : Float(2, 3, 4, strides=[12, 4, 1]) = onnx::Div(%2, %0) # registered op
    return (%3)
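The five modes can be compared side by side with a toy dispatcher. This is a rough sketch of the behaviour described in this section, not the exporter's logic: the `REGISTERED` table is an illustrative subset, `export_op` and `as_aten` are invented names, and combinations the text does not describe simply raise an error here.

```python
# Toy registry: ATen ops that have an ONNX counterpart (illustrative subset).
REGISTERED = {"aten::exp": "onnx::Exp", "aten::div": "onnx::Div", "aten::mul": "onnx::Mul"}

def as_aten(op):
    """Wrap an ATen op the way the listings above print it."""
    return 'aten::ATen[operator="%s"]' % op.split("::", 1)[1]

def export_op(op, mode):
    is_aten = op.startswith("aten::")
    if mode == "RAW":
        return op                      # graph is left untouched
    if mode == "ONNX_ATEN" and is_aten:
        return as_aten(op)             # never converted to ONNX
    if mode != "ONNX_ATEN" and op in REGISTERED:
        return REGISTERED[op]          # regular ONNX operator
    if mode == "ONNX_ATEN_FALLBACK" and is_aten:
        return as_aten(op)             # unsupported ATen op
    if mode == "ONNX_FALLTHROUGH":
        return op                      # emitted as is, as a custom op
    # Cases not covered by the text above are treated as failures in this toy.
    raise ValueError("cannot export %s in mode %s" % (op, mode))

for mode in ("ONNX_ATEN_FALLBACK", "ONNX_FALLTHROUGH"):
    print(mode, [export_op(op, mode) for op in ("aten::triu", "aten::mul")])
```

In this sketch, `aten::triu` becomes an ATen node under ONNX_ATEN_FALLBACK and is passed through unchanged under ONNX_FALLTHROUGH, while `aten::mul` maps to `onnx::Mul` in both modes.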

Frequently Asked Questions

Q: I have exported my LSTM model, but its input size seems to be fixed?

The tracer records the shapes of the example inputs in the graph. If the model should accept inputs of dynamic shape, you can use the dynamic_axes parameter of the export API.

import onnx
import torch
import torch.nn as nn

layer_count = 4

model = nn.LSTM(10, 20, num_layers=layer_count, bidirectional=True)
model.eval()

with torch.no_grad():
    input = torch.randn(5, 3, 10)
    h0 = torch.randn(layer_count * 2, 3, 20)
    c0 = torch.randn(layer_count * 2, 3, 20)
    output, (hn, cn) = model(input, (h0, c0))

    # default export
    torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx')
    onnx_model = onnx.load('lstm.onnx')
    # input shape [5, 3, 10]
    print(onnx_model.graph.input[0])

    # export with `dynamic_axes`
    torch.onnx.export(model, (input, (h0, c0)), 'lstm.onnx',
                    input_names=['input', 'h0', 'c0'],
                    output_names=['output', 'hn', 'cn'],
                    dynamic_axes={'input': {0: 'sequence'}, 'output': {0: 'sequence'}})
    onnx_model = onnx.load('lstm.onnx')
    # input shape ['sequence', 3, 10]
    print(onnx_model.graph.input[0])
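
dynamic_axes accepts either a list of axis indices or a dictionary from axis index to name for each input or output. The plain-Python sketch below shows how both forms could be normalized into index-to-name dictionaries and applied to a recorded shape. The helpers normalize_dynamic_axes and symbolic_shape are hypothetical, and the naming scheme for auto-generated names is an assumption made for illustration; the exporter's actual generated names may differ.

```python
from typing import Dict, List, Union

AxesSpec = Union[List[int], Dict[int, str]]


def normalize_dynamic_axes(dynamic_axes: Dict[str, AxesSpec]) -> Dict[str, Dict[int, str]]:
    """Turn every list-of-indices entry into an {index: name} dictionary."""
    normalized = {}
    for tensor_name, axes in dynamic_axes.items():
        if isinstance(axes, dict):
            normalized[tensor_name] = dict(axes)
        else:
            # Assumed naming scheme, for illustration only.
            normalized[tensor_name] = {
                axis: f"{tensor_name}_dynamic_axes_{i + 1}" for i, axis in enumerate(axes)
            }
    return normalized


def symbolic_shape(shape, axes: Dict[int, str]):
    """Replace the sizes of dynamic axes with their names."""
    return [axes.get(i, size) for i, size in enumerate(shape)]


spec = normalize_dynamic_axes({"input": {0: "sequence"}, "h0": [1]})
print(symbolic_shape([5, 3, 10], spec["input"]))  # ['sequence', 3, 10]
print(spec["h0"])  # {1: 'h0_dynamic_axes_1'}
```

Axes that are not listed keep the size recorded from the example input, which is why only the first dimension of the LSTM input becomes symbolic in the export above.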

Q: How do I export models that contain loops?

Please check out Tracing vs Scripting.

Q: Does ONNX support implicit scalar datatype casting?

No, but the exporter will try to handle that part. Scalars are converted to constant tensors in ONNX, and the exporter tries to infer the right datatype for them. When it fails to do so, you need to provide the datatype information manually. This often happens with scripted models, where the datatypes are not recorded. We are working on improving datatype propagation in the exporter so that manual changes are not required in the future.

import torch

class ImplicitCastType(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        # Exporter knows x is float32, will export '2' as float32 as well.
        y = x + 2
        # Without type propagation, exporter doesn't know the datatype of y.
        # Thus '3' is exported as int64 by default.
        return y + 3
        # The following will export correctly.
        # return y + torch.tensor([3], dtype=torch.float32)

x = torch.tensor([1.0], dtype=torch.float32)
torch.onnx.export(ImplicitCastType(), x, 'models/implicit_cast.onnx',
                  example_outputs=ImplicitCastType()(x))

Q: Is tensor in-place indexed assignment like data[index] = new_data supported?

Yes, this is supported for ONNX opset version >= 11. Please check out Indexing.

Q: Is tensor list exportable to ONNX?

Yes, this is supported for ONNX opset version >= 11. ONNX introduced the concept of Sequence in opset 11. Similar to a list, Sequence is a data type that contains an arbitrary number of Tensors. Associated operators were also introduced in ONNX, such as SequenceInsert, SequenceAt, etc. However, in-place list append within loops is not exportable to ONNX. To implement this, please use the in-place add operator (+=) instead. E.g.:

class ListLoopModel(torch.nn.Module):
    def forward(self, x):
        res = []
        res1 = []
        arr = x.split(2, 0)
        res2 = torch.zeros(3, 4, dtype=torch.long)
        for i in range(len(arr)):
            res += [arr[i].sum(0, False)]
            res1 += [arr[-1 - i].sum(0, False)]
            res2 += 1
        return torch.stack(res), torch.stack(res1), res2

model = torch.jit.script(ListLoopModel())
inputs = torch.randn(16)

out = model(inputs)
torch.onnx.export(model, (inputs, ), 'loop_and_list.onnx', opset_version=11, example_outputs=out)

# onnxruntime
import onnxruntime
sess = onnxruntime.InferenceSession('loop_and_list.onnx')
out_ort = sess.run(None, {
    sess.get_inputs()[0].name: inputs.numpy(),
})

# all() is required here: asserting a non-empty list would always pass.
assert all(torch.allclose(o, torch.tensor(o_ort)) for o, o_ort in zip(out, out_ort))

Use external data format

The use_external_data_format argument in the export API enables export of models in the ONNX external data format. With this option enabled, the exporter stores some model parameters in external binary files, rather than in the ONNX file itself. These external binary files are stored in the same location as the ONNX file. The argument ‘f’ must be a string specifying the location of the model.

import torch
import torchvision

model = torchvision.models.mobilenet_v2(pretrained=True)
input = torch.randn(2, 3, 224, 224, requires_grad=True)
torch.onnx.export(model, (input, ), './large_model.onnx', use_external_data_format=True)

This argument enables export of large models to ONNX. Models larger than 2GB cannot be exported in one file because of the protobuf size limit. Users should set use_external_data_format to True to successfully export such models.
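
Whether a model is likely to hit the size limit can be estimated from its parameter shapes alone. The sketch below is plain arithmetic under stated assumptions: the helper names are hypothetical, the limit is taken as 2 GiB, every tensor is assumed to use the same element size (4 bytes for float32), and the size of the graph structure itself is ignored.

```python
from math import prod
from typing import Dict, Tuple

# The 2GB protobuf limit mentioned above, taken as 2 GiB for illustration.
PROTOBUF_LIMIT_BYTES = 2 * 1024 ** 3


def parameter_bytes(shapes: Dict[str, Tuple[int, ...]], bytes_per_element: int = 4) -> int:
    """Total raw size of the given tensors, assuming one element size for all."""
    return sum(prod(shape) * bytes_per_element for shape in shapes.values())


def needs_external_data(shapes: Dict[str, Tuple[int, ...]], bytes_per_element: int = 4) -> bool:
    """True if the parameters alone would already reach the single-file limit."""
    return parameter_bytes(shapes, bytes_per_element) >= PROTOBUF_LIMIT_BYTES


# Two float32 tensors shaped like the last AlexNet layer: far below the limit.
small = {"weight": (1000, 4096), "bias": (1000,)}
print(parameter_bytes(small))       # 16388000
print(needs_external_data(small))   # False

# A hypothetical float32 embedding table with 1,000,000 rows of width 1024.
large = {"embedding": (1_000_000, 1024)}
print(needs_external_data(large))   # True
```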

Training

The training argument in the export API allows users to export models in a training-friendly mode. TrainingMode.TRAINING exports the model in a training-friendly mode that avoids certain model optimizations which might interfere with model parameter training. TrainingMode.PRESERVE exports the model in inference mode if model.training is False; otherwise, it exports the model in a training-friendly mode. The default is TrainingMode.EVAL, which exports the model in inference mode.
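
The three modes reduce to a small decision table. The sketch below writes it out in plain Python; the Mode enum is a stand-in for torch.onnx.TrainingMode, and apart from EVAL, whose value 0 appears in the export signature, its numeric values are illustrative. The function name is hypothetical.

```python
from enum import Enum


class Mode(Enum):
    """Stand-in for torch.onnx.TrainingMode; values other than EVAL are illustrative."""
    EVAL = 0
    PRESERVE = 1
    TRAINING = 2


def exports_in_training_mode(mode: Mode, model_training: bool) -> bool:
    """Return True if the export would be training-friendly, False for inference mode."""
    if mode is Mode.TRAINING:
        return True
    if mode is Mode.PRESERVE:
        return model_training  # follow whatever model.training currently says
    return False  # Mode.EVAL: always inference mode


print(exports_in_training_mode(Mode.PRESERVE, model_training=True))   # True
print(exports_in_training_mode(Mode.PRESERVE, model_training=False))  # False
print(exports_in_training_mode(Mode.EVAL, model_training=True))       # False
```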

Functions

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, aten=False, export_raw_ir=False, operator_export_type=None, opset_version=None, _retain_param_name=True, do_constant_folding=True, example_outputs=None, strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=True, use_external_data_format=False) [source]

Export a model into ONNX format. This exporter runs your model once in order to get a trace of its execution to be exported; at the moment, it supports a limited set of dynamic models (e.g., RNNs.)

Parameters
  • model (torch.nn.Module) – the model to be exported.
  • args (tuple of arguments or torch.Tensor, optionally ending with a dictionary of named arguments) –

    The optional dictionary specifies the input to each named parameter: - KEY: str, the name of the parameter - VALUE: the corresponding input.

    args can be structured either as:

    1. ONLY A TUPLE OF ARGUMENTS or torch.Tensor:

      args = (x, y, z)
      

    The inputs to the model, e.g., such that model(*args) is a valid invocation of the model. Any non-Tensor arguments will be hard-coded into the exported model; any Tensor arguments will become inputs of the exported model, in the order they occur in args. If args is a Tensor, this is equivalent to having called it with a 1-ary tuple of that Tensor.

    2. A TUPLE OF ARGUMENTS WITH A DICTIONARY OF NAMED PARAMETERS:

      args = (x,
              {
              'y': input_y,
              'z': input_z
              })
      

    The inputs to the model are structured as a tuple of non-keyword arguments whose last value is a dictionary mapping named parameters to their inputs. If a named argument is not present in the dictionary, it is assigned its default value, or None if no default value is provided.

    A dictionary passed as the last positional input of the args tuple conflicts with this convention, because it cannot be told apart from a dictionary of named parameters. The model below provides such an example.

    class Model(torch.nn.Module):
        def forward(self, k, x):
            ...
            return x

    model = Model()
    k = torch.randn(2, 3)
    x = {torch.tensor(1.): torch.randn(2, 3)}

    Previously, the call to the export API would look like

    torch.onnx.export(model, (k, x), 'test.onnx')

    and would work as intended. However, the export function now assumes that the 'x' input is the optional dictionary of named arguments. To avoid this, an empty dictionary must be provided as the last input in the args tuple in such cases. The new call looks like this:

    torch.onnx.export(model, (k, x, {}), 'test.onnx')

  • f – a file-like object (has to implement fileno that returns a file descriptor) or a string containing a file name. A binary Protobuf will be written to this file.
  • export_params (bool, default True) – if True, all parameters will be exported. Set this to False if you want to export an untrained model. In this case, the exported model will first take all of its parameters as arguments, in the order specified by model.state_dict().values()
  • verbose (bool, default False) – if True, we will print out a debug description of the trace being exported.
  • training (enum, default TrainingMode.EVAL) – TrainingMode.EVAL: export the model in inference mode. TrainingMode.PRESERVE: export the model in inference mode if model.training is False and to a training friendly mode if model.training is True. TrainingMode.TRAINING: export the model in a training friendly mode.
  • input_names (list of strings, default empty list) – names to assign to the input nodes of the graph, in order
  • output_names (list of strings, default empty list) – names to assign to the output nodes of the graph, in order
  • aten (bool, default False) – [DEPRECATED. use operator_export_type] export the model in aten mode. If using aten mode, all the ops originally exported by the functions in symbolic_opset<version>.py are exported as ATen ops.
  • export_raw_ir (bool, default False) – [DEPRECATED. use operator_export_type] export the internal IR directly instead of converting it to ONNX ops.
  • operator_export_type (enum, default OperatorExportTypes.ONNX) –

    OperatorExportTypes.ONNX: All ops are exported as regular ONNX ops (with ONNX namespace).

    OperatorExportTypes.ONNX_ATEN: All ops are exported as ATen ops (with aten namespace).

    OperatorExportTypes.ONNX_ATEN_FALLBACK: If an ATen op is not supported in ONNX or its symbolic is missing, fall back on the ATen op. Registered ops are exported to ONNX regularly. Example graph:

    graph(%0 : Float):
      %3 : int = prim::Constant[value=0]()
      %4 : Float = aten::triu(%0, %3) # missing op
      %5 : Float = aten::mul(%4, %0) # registered op
      return (%5)
    

    is exported as:

    graph(%0 : Float):
      %1 : Long() = onnx::Constant[value={0}]()
      %2 : Float = aten::ATen[operator="triu"](%0, %1)  # missing op
      %3 : Float = onnx::Mul(%2, %0) # registered op
      return (%3)
    

    In the above example, aten::triu is not supported in ONNX, hence the exporter falls back on this op.

    OperatorExportTypes.RAW: Export the raw IR.

    OperatorExportTypes.ONNX_FALLTHROUGH: If an op is not supported in ONNX, fall through and export the operator as is, as a custom ONNX op. Using this mode, the op can be exported and implemented by the user for their runtime backend. Example graph:

    graph(%x.1 : Long(1, strides=[1])):
      %1 : None = prim::Constant()
      %2 : Tensor = aten::sum(%x.1, %1)
      %y.1 : Tensor[] = prim::ListConstruct(%2)
      return (%y.1)
    

    is exported as:

    graph(%x.1 : Long(1, strides=[1])):
      %1 : Tensor = onnx::ReduceSum[keepdims=0](%x.1)
      %y.1 : Long() = prim::ListConstruct(%1)
      return (%y.1)
    

    In the above example, prim::ListConstruct is not supported, hence the exporter falls through.

  • opset_version (int, default is 9) – by default we export the model to the opset version of the onnx submodule. Since ONNX’s latest opset may evolve before the next stable release, by default we export to one stable opset version. Right now, the supported stable opset version is 9. The opset_version must be _onnx_main_opset or in _onnx_stable_opsets, which are defined in torch/onnx/symbolic_helper.py
  • do_constant_folding (bool, default True) – If True, the constant-folding optimization is applied to the model during export. Constant-folding optimization will replace some of the ops that have all constant inputs with pre-computed constant nodes.
  • example_outputs (tuple of Tensors, default None) – Model’s example outputs being exported. example_outputs must be provided when exporting a ScriptModule or TorchScript Function.
  • strip_doc_string (bool, default True) – if True, strips the field “doc_string” from the exported model, which contains information about the stack trace.
  • dynamic_axes (dict<string, dict<python:int, string>> or dict<string, list(int)>, default empty dict) –

    a dictionary to specify dynamic axes of input/output, such that:

    - KEY: input and/or output names
    - VALUE: index of dynamic axes for the given key and potentially the name to be used for exported dynamic axes.

    In general the value is defined in one of the following ways, or a combination of both:

    (1). A list of integers specifying the dynamic axes of the provided input. In this scenario automated names will be generated and applied to the dynamic axes of the provided input/output during export.

    OR (2). An inner dictionary that specifies a mapping FROM the index of a dynamic axis in the corresponding input/output TO the name to be applied to that axis during export.

    Example. if we have the following shape for inputs and outputs:

    shape(input_1) = ('b', 3, 'w', 'h')
    and shape(input_2) = ('b', 4)
    and shape(output)  = ('b', 'd', 5)
    

    Then dynamic axes can be defined either as:

    1. ONLY INDICES:

      dynamic_axes = {'input_1':[0, 2, 3],
                      'input_2':[0],
                      'output':[0, 1]}
      where automatic names will be generated for exported dynamic axes
      
    2. INDICES WITH CORRESPONDING NAMES:

      dynamic_axes = {'input_1':{0:'batch',
                                 2:'width',
                                 3:'height'},
                      'input_2':{0:'batch'},
                      'output':{0:'batch',
                                1:'detections'}}
      where provided names will be applied to exported dynamic axes
      
    3. MIXED MODE OF (1) and (2):

      dynamic_axes = {'input_1':[0, 2, 3],
                      'input_2':{0:'batch'},
                      'output':[0,1]}
      
  • keep_initializers_as_inputs (bool, default None) –

    If True, all the initializers (typically corresponding to parameters) in the exported graph will also be added as inputs to the graph. If False, then initializers are not added as inputs to the graph, and only the non-parameter inputs are added as inputs.

    This may allow for better optimizations (such as constant folding etc.) by backends/runtimes that execute these graphs. If unspecified (default None), then the behavior is chosen automatically as follows. If operator_export_type is OperatorExportTypes.ONNX, the behavior is equivalent to setting this argument to False. For other values of operator_export_type, the behavior is equivalent to setting this argument to True. Note that for ONNX opset version < 9, initializers MUST be part of graph inputs. Therefore, if the opset_version argument is set to 8 or lower, this argument will be ignored.

  • custom_opsets (dict<string, int>, default empty dict) – A dictionary to indicate custom opset domain and version at export. If model contains a custom opset, it is optional to specify the domain and opset version in the dictionary: - KEY: opset domain name - VALUE: opset version If the custom opset is not provided in this dictionary, opset version is set to 1 by default.
  • enable_onnx_checker (bool, default True) – If True the onnx model checker will be run as part of the export, to ensure the exported model is a valid ONNX model.
  • use_external_data_format (bool, default False) – If True, then the model is exported in ONNX external data format, in which case some of the model parameters are stored in external binary files and not in the ONNX model file itself. See link for format details: https://github.com/onnx/onnx/blob/8b3f7e2e7a0f2aba0e629e23d89f07c7fc0e6a5e/onnx/onnx.proto#L423 Also, in this case, argument ‘f’ must be a string specifying the location of the model. The external binary files will be stored in the same location specified by the model location ‘f’. If False, then the model is stored in regular format, i.e. model and parameters are all in one file. This argument is ignored for all export types other than ONNX.
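
The rule for a trailing dictionary in args, described under the args parameter above, can be sketched in plain Python. This is an illustration of the documented convention rather than the exporter's code: split_export_args is a hypothetical helper, and strings stand in for tensors so the example runs without PyTorch.

```python
from typing import Any, Dict, Tuple


def split_export_args(args: Any) -> Tuple[tuple, Dict[str, Any]]:
    """Split `args` into positional inputs and named inputs."""
    if not isinstance(args, tuple):
        args = (args,)  # a single Tensor is treated as a 1-ary tuple
    if args and isinstance(args[-1], dict):
        # The last value is taken to be the dictionary of named parameters.
        return args[:-1], dict(args[-1])
    return args, {}


k = "k_tensor"                # stand-in for torch.randn(2, 3)
x = {"key": "value_tensor"}   # a dictionary that is a genuine model input

# Without the trailing {}, x is mistaken for the named-argument dictionary.
print(split_export_args((k, x)))      # (('k_tensor',), {'key': 'value_tensor'})

# With the trailing {}, x stays a positional input.
print(split_export_args((k, x, {})))  # (('k_tensor', {'key': 'value_tensor'}), {})
```

The second call shows why the empty dictionary is required: it occupies the last slot, so the dictionary before it is left among the positional inputs.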
torch.onnx.export_to_pretty_string(*args, **kwargs) [source]
torch.onnx.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version) [source]
torch.onnx.operators.shape_as_tensor(x) [source]
torch.onnx.select_model_mode_for_export(model, mode) [source]

A context manager to temporarily set the training mode of ‘model’ to ‘mode’, resetting it when we exit the with-block. A no-op if mode is None.

Changed in version 1.6: Renamed from set_training.
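
The set-then-restore behavior of such a context manager can be sketched without PyTorch. The code below is a minimal illustration, not the library's implementation: ToyModel is a hypothetical stand-in that exposes only the training flag and train() method of a module, temporary_training_mode is a hypothetical name, and mode is assumed to be a plain boolean rather than a TrainingMode value.

```python
from contextlib import contextmanager


class ToyModel:
    """Minimal stand-in exposing the `training` flag and `train()` of a module."""

    def __init__(self):
        self.training = True

    def train(self, mode: bool = True):
        self.training = mode
        return self


@contextmanager
def temporary_training_mode(model, mode):
    """Set model.training to `mode` inside the with-block, then restore it."""
    if mode is None:
        yield  # documented behavior: a no-op when mode is None
        return
    original = model.training
    model.train(mode)
    try:
        yield
    finally:
        model.train(original)  # restored even if the block raises


model = ToyModel()
with temporary_training_mode(model, False):
    print(model.training)  # False
print(model.training)      # True
```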

torch.onnx.is_in_onnx_export() [source]

Check whether an ONNX export is in progress. This function returns True in the middle of torch.onnx.export(). torch.onnx.export should be executed with a single thread.

© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/onnx.html