torch.onnx
Example: End-to-end AlexNet from PyTorch to Caffe2
Here is a simple script which exports a pretrained AlexNet as defined in
torchvision into ONNX. It runs a single round of inference and then
saves the resulting traced model to alexnet.onnx:
import torch
import torchvision
dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
model = torchvision.models.alexnet(pretrained=True).cuda()
# Providing input and output names sets the display names for values
# within the model's graph. Setting these does not change the semantics
# of the graph; it is only for readability.
#
# The inputs to the network consist of the flat list of inputs (i.e.
# the values you would pass to the forward() method) followed by the
# flat list of parameters. You can partially specify names, i.e. provide
# a list here shorter than the number of inputs to the model, and we will
# only set that subset of names, starting from the beginning.
input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]
torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names, output_names=output_names)
The resulting alexnet.onnx is a binary protobuf file which contains both
the network structure and parameters of the model you exported
(in this case, AlexNet). The keyword argument verbose=True causes the
exporter to print out a human-readable representation of the network:
# These are the inputs and parameters to the network, which have taken on
# the names we specified earlier.
graph(%actual_input_1 : Float(10, 3, 224, 224)
%learned_0 : Float(64, 3, 11, 11)
%learned_1 : Float(64)
%learned_2 : Float(192, 64, 5, 5)
%learned_3 : Float(192)
# ---- omitted for brevity ----
%learned_14 : Float(1000, 4096)
%learned_15 : Float(1000)) {
# Every statement consists of some output tensors (and their types),
# the operator to be run (with its attributes, e.g., kernels, strides,
# etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
%17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
%18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
%19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
# ---- omitted for brevity ----
%29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
# Dynamic means that the shape is not known. This may be because of a
# limitation of our implementation (which we would like to fix in a
# future release) or shapes which are truly dynamic.
%30 : Dynamic = onnx::Shape(%29), scope: AlexNet
%31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
%32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
%33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
# ---- omitted for brevity ----
%output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
return (%output1);
}
You can also verify the protobuf using the onnx library.
You can install onnx with conda:
conda install -c conda-forge onnx
Then, you can run:
import onnx
# Load the ONNX model
model = onnx.load("alexnet.onnx")
# Check that the IR is well formed
onnx.checker.check_model(model)
# Print a human readable representation of the graph
onnx.helper.printable_graph(model.graph)
To run the exported script with caffe2, you will need to install caffe2. If you don't have it already, please follow the install instructions.
Once these are installed, you can use the backend for Caffe2:
# ...continuing from above
import caffe2.python.onnx.backend as backend
import numpy as np
rep = backend.prepare(model, device="CUDA:0") # or "CPU"
# For the Caffe2 backend:
# rep.predict_net is the Caffe2 protobuf for the network
# rep.workspace is the Caffe2 workspace for the network
# (see the class caffe2.python.onnx.backend.Workspace)
outputs = rep.run(np.random.randn(10, 3, 224, 224).astype(np.float32))
# To run networks with more than one input, pass a tuple
# rather than a single numpy ndarray.
print(outputs[0])
In the future, there will be backends for other frameworks as well.
Limitations
The ONNX exporter is a trace-based exporter, which means that it operates by executing your model once, and exporting the operators which were actually run during this run. This means that if your model is dynamic, e.g., changes behavior depending on input data, the export won’t be accurate. Similarly, a trace is likely to be valid only for a specific input size (which is one reason why we require explicit inputs on tracing.) We recommend examining the model trace and making sure the traced operators look reasonable.
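To see why data-dependent control flow is lost, consider a toy tracer written in plain Python. This is only an illustration of the tracing idea under simplified assumptions, not PyTorch's implementation; the names Traced, model, trace, and replay are hypothetical.

```python
# Toy illustration of trace-based export (hypothetical names, not PyTorch code).
class Traced:
    """Wraps a number and records every arithmetic op applied to it."""

    def __init__(self, value, tape):
        self.value = value
        self.tape = tape

    def __mul__(self, other):
        self.tape.append(("mul", other))
        return Traced(self.value * other, self.tape)

    def __add__(self, other):
        self.tape.append(("add", other))
        return Traced(self.value + other, self.tape)

    def __gt__(self, other):
        # The comparison yields a plain bool, so the branch taken never reaches the tape.
        return self.value > other


def model(x):
    # Data-dependent control flow: behavior changes with the input value.
    if x > 0:
        return x * 2
    return x + 100


def trace(fn, example_input):
    """Run fn once and return the operations that actually executed."""
    tape = []
    fn(Traced(example_input, tape))
    return tape


def replay(tape, value):
    """Re-run the recorded operations on a new input."""
    for op, operand in tape:
        value = value * operand if op == "mul" else value + operand
    return value


tape = trace(model, 3)              # only the x > 0 branch is recorded
print(tape)                         # [('mul', 2)]
print(replay(tape, 3), model(3))    # the trace agrees with the model here
print(replay(tape, -1), model(-1))  # and disagrees once the other branch is needed
```

The recorded tape contains no trace of the `if`, which is why inspecting the exported graph for a dynamic model is worthwhile.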
PyTorch and Caffe2 often have implementations of operators with some numeric differences. Depending on model structure, these differences may be negligible, but they can also cause major divergences in behavior (especially on untrained models.) In a future release, we plan to allow Caffe2 to call directly to Torch implementations of operators, to help you smooth over these differences when precision is important, and to also document these differences.
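One way to check whether such differences matter for a given model is to compare the two backends' outputs element by element within a tolerance. The sketch below uses plain Python lists in place of real backend outputs; the helper name outputs_match, the sample values, and the tolerance values are all assumptions chosen for illustration.

```python
import math


def outputs_match(reference, candidate, rel_tol=1e-3, abs_tol=1e-5):
    """Return True if two flat output lists agree within the given tolerances."""
    if len(reference) != len(candidate):
        return False
    return all(
        math.isclose(r, c, rel_tol=rel_tol, abs_tol=abs_tol)
        for r, c in zip(reference, candidate)
    )


# Made-up values standing in for the same output from two backends.
pytorch_out = [0.1234, -2.5, 10.0]
caffe2_out = [0.12341, -2.5001, 10.0005]
diverged_out = [0.1234, -2.5, 12.0]

print(outputs_match(pytorch_out, caffe2_out))    # small differences pass
print(outputs_match(pytorch_out, diverged_out))  # a large divergence is flagged
```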
Supported operators
The following operators are supported:
add (nonzero alpha not supported)
sub (nonzero alpha not supported)
mul
div
cat
mm
addmm
neg
sqrt
tanh
sigmoid
mean
sum
prod
t
expand (only when used before a broadcasting ONNX operator; e.g., add)
transpose
view
split
squeeze
prelu (single weight shared among input channels not supported)
threshold (non-zero threshold/non-zero value not supported)
leaky_relu
glu
softmax (only dim=-1 supported)
avg_pool2d (ceil_mode not supported)
log_softmax
unfold (experimental support with ATen-Caffe2 integration)
elu
concat
abs
index_select
pow
clamp
max
min
eq
gt
lt
ge
le
exp
sin
cos
tan
asin
acos
atan
permute
Conv
BatchNorm
MaxPool1d (ceil_mode not supported)
MaxPool2d (ceil_mode not supported)
MaxPool3d (ceil_mode not supported)
Embedding (no optional arguments supported)
RNN
ConstantPadNd
Dropout
FeatureDropout (training mode not supported)
Index (constant integer and tuple indices supported)
The operator set above is sufficient to export the following models:
AlexNet
DCGAN
DenseNet
Inception (warning: this model is highly sensitive to changes in operator implementation)
ResNet
SuperResolution
VGG
Adding export support for operators is an advanced usage. To achieve this, developers need to touch the source code of PyTorch. Please follow the instructions for installing PyTorch from source. If the wanted operator is standardized in ONNX, it should be easy to add support for exporting it (by adding a symbolic function for the operator). To confirm whether the operator is standardized or not, please check the ONNX operator list.
If the operator is an ATen operator, which means you can find the declaration of the function in torch/csrc/autograd/generated/VariableType.h (available in generated code in the PyTorch install dir), you should add the symbolic function in torch/onnx/symbolic.py and follow the instructions listed below:
Define the symbolic function in torch/onnx/symbolic.py. Make sure the function has the same name as the ATen operator/function defined in VariableType.h.
The first parameter is always the exported ONNX graph. Parameter names must EXACTLY match the names in VariableType.h, because dispatch is done with keyword arguments.
Parameter ordering does NOT necessarily match what is in VariableType.h: tensors (inputs) are always first, then non-tensor arguments.
In the symbolic function, if the operator is already standardized in ONNX, we only need to create a node to represent the ONNX operator in the graph.
If the input argument is a tensor, but ONNX asks for a scalar, we have to explicitly do the conversion. The helper function _scalar can convert a scalar tensor into a Python scalar, and _if_scalar_type_as can turn a Python scalar into a PyTorch tensor.
If the operator is a non-ATen operator, the symbolic function has to be added in the corresponding PyTorch Function class. Please read the following instructions:
Create a symbolic function named symbolic in the corresponding Function class.
The first parameter is always the exported ONNX graph.
Parameter names except the first must EXACTLY match the names in forward.
The output tuple size must match the outputs of forward.
In the symbolic function, if the operator is already standardized in ONNX, we just need to create a node to represent the ONNX operator in the graph.
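As a rough sketch of these rules, the class below is laid out like a Function with a symbolic method and is exercised against a stand-in graph object. Everything here is hypothetical: RecordingGraph and MyRelu are illustrative names, the stub only records what g.op is asked to create, and real code would subclass torch.autograd.Function and receive the exporter's own graph.

```python
class RecordingGraph:
    """Hypothetical stand-in for the exporter's graph; it only records requested nodes."""

    def __init__(self):
        self.nodes = []

    def op(self, opname, *inputs, **attrs):
        self.nodes.append((opname, inputs, attrs))
        return "%{}".format(len(self.nodes))  # placeholder standing in for a Value


class MyRelu:
    """Laid out like a Function class; real code would subclass torch.autograd.Function."""

    @staticmethod
    def forward(ctx, input):
        # Plain-Python ReLU over a list of numbers, producing one output.
        return [max(v, 0.0) for v in input]

    @staticmethod
    def symbolic(g, input):
        # The first parameter is the graph; `input` matches the name used in forward.
        # Relu is a standardized ONNX operator, so creating a single node is enough,
        # and the single return value matches the single output of forward.
        return g.op("Relu", input)


g = RecordingGraph()
out = MyRelu.symbolic(g, "%input")
print(g.nodes)  # [('Relu', ('%input',), {})]
print(out)      # %1
```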
Symbolic functions should be implemented in Python. All of these functions interact with Python methods which are implemented via C++-Python bindings, but intuitively the interface they provide looks like this:
# Named after the ATen operator in symbolic.py, or `symbolic` in a Function class.
def symbolic(g, *inputs):
    """
    Modifies Graph (e.g., using "op"), adding the ONNX operations representing
    this PyTorch function, and returning a Value or tuple of Values specifying the
    ONNX outputs whose values correspond to the original PyTorch return values
    of the autograd Function (or None if an output is not supported by ONNX).

    Arguments:
      g (Graph): graph to write the ONNX representation into
      inputs (Value...): list of values representing the variables which contain
          the inputs for this function
    """

class Value(object):
    """Represents an intermediate tensor value computed in ONNX."""
    def type(self):
        """Returns the Type of the value."""

class Type(object):
    def sizes(self):
        """Returns a tuple of ints representing the shape of a tensor this describes."""

class Graph(object):
    def op(self, opname, *inputs, **attrs):
        """
        Create an ONNX operator 'opname', taking 'inputs' as inputs
        and attributes 'attrs' and add it as a node to the current graph,
        returning the value representing the single output of this
        operator (see the `outputs` keyword argument for multi-return
        nodes).

        The set of operators and the inputs/attributes they take
        is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

        Arguments:
            opname (string): The ONNX operator name, e.g., `Abs` or `Add`.
            inputs (Value...): The inputs to the operator; usually provided
                as arguments to the `symbolic` definition.
            attrs: The attributes of the ONNX operator, with keys named
                according to the following convention: `alpha_f` indicates
                the `alpha` attribute with type `f`. The valid type specifiers are
                `f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute
                specified with type float accepts either a single float, or a
                list of floats (e.g., you would say `dims_i` for a `dims` attribute
                that takes a list of integers).
            outputs (int, optional): The number of outputs this operator returns;
                by default an operator is assumed to return a single output.
                If `outputs` is greater than one, this function returns a tuple
                of output `Value`, representing each output of the ONNX operator
                in positional order.
        """
The ONNX graph C++ definition is in torch/csrc/jit/ir.h.
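The attribute naming convention described in the op docstring (a trailing _f, _i, _s or _t on each keyword) can be illustrated with a small parser. This is a sketch of the convention only; split_attr_key and ATTR_KINDS are hypothetical names and are not part of torch.onnx.

```python
# Hypothetical helper illustrating the `name_suffix` convention for op attributes.
ATTR_KINDS = {"f": "float", "i": "int", "s": "string", "t": "Tensor"}


def split_attr_key(key):
    """Split a keyword such as 'alpha_f' into its attribute name and type."""
    name, sep, suffix = key.rpartition("_")
    if not sep or not name or suffix not in ATTR_KINDS:
        raise ValueError("attribute key %r needs a _f, _i, _s or _t suffix" % key)
    return name, ATTR_KINDS[suffix]


print(split_attr_key("alpha_f"))         # ('alpha', 'float')
print(split_attr_key("kernel_shape_i"))  # ('kernel_shape', 'int')
```

Only the last underscore separates the name from the type specifier, so attribute names that themselves contain underscores, such as kernel_shape, are unaffected.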
Here is an example of handling a missing symbolic function for the elu operator. We try to export the model and see the error message below:
UserWarning: ONNX export failed on elu because torch.onnx.symbolic.elu does not exist
RuntimeError: ONNX export failed: Couldn't export operator elu
The export fails because PyTorch does not support exporting the elu operator. We find virtual Tensor elu(const Tensor & input, Scalar alpha, bool inplace) const override; in VariableType.h. This means elu is an ATen operator. We check the ONNX operator list, and confirm that Elu is standardized in ONNX. We add the following lines to symbolic.py:
def elu(g, input, alpha, inplace=False):
    return g.op("Elu", input, alpha_f=_scalar(alpha))
Now PyTorch is able to export the elu operator.
There are more examples in symbolic.py, tensor.py, padding.py.
The interface for specifying operator definitions is experimental; adventurous users should note that the APIs will probably change in a future release.
Functions
torch.onnx.export(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, aten=False, export_raw_ir=False, operator_export_type=None, opset_version=None, _retain_param_name=True, do_constant_folding=True, example_outputs=None, strip_doc_string=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=True, use_external_data_format=False)
Export a model into ONNX format. This exporter runs your model once in order to get a trace of its execution to be exported; at the moment, it supports a limited set of dynamic models (e.g., RNNs).
Parameters
model (torch.nn.Module) – the model to be exported.
args (tuple of arguments or torch.Tensor, optionally ending with a dictionary of named arguments) –
args can be structured either as:
ONLY A TUPLE OF ARGUMENTS or torch.Tensor:
args = (x, y, z)
The inputs to the model, e.g., such that model(*args) is a valid invocation of the model. Any non-Tensor arguments will be hard-coded into the exported model; any Tensor arguments will become inputs of the exported model, in the order they occur in args. If args is a Tensor, this is equivalent to having called it with a 1-ary tuple of that Tensor.
A TUPLE OF ARGUMENTS WITH A DICTIONARY OF NAMED PARAMETERS:
args = (x, {'y': input_y, 'z': input_z})
The inputs to the model are structured as a tuple consisting of non-keyword arguments, with the last value of this tuple being a dictionary that maps each named parameter (KEY: str) to its corresponding input (VALUE). If a named argument is not present in the dictionary, it is assigned the default value, or None if no default value is provided.
A dictionary that is itself the last model input in the args tuple conflicts with this convention. The model below provides such an example.
class Model(torch.nn.Module):
    def forward(self, k, x):
        ...
        return x
model = Model()
k = torch.randn(2, 3)
x = {torch.tensor(1.): torch.randn(2, 3)}
Previously, the call to the export API would look like
torch.onnx.export(model, (k, x), 'test.onnx')
and would work as intended. However, the export function now assumes that the x input is the optional dictionary of named arguments. To prevent this, provide an empty dictionary as the last input in the args tuple in such cases. The new call looks like this:
torch.onnx.export(model, (k, x, {}), 'test.onnx')
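The trailing-dictionary rule can be pictured with a few lines of plain Python. This is a simplified sketch of the rule as described, not the exporter's code; split_export_args is a hypothetical name, and strings stand in for tensors.

```python
def split_export_args(args):
    """Split `args` into (positional, named) following the rule described above."""
    if not isinstance(args, tuple):
        args = (args,)  # a lone value is treated as a 1-ary tuple
    if args and isinstance(args[-1], dict):
        return args[:-1], args[-1]  # a trailing dict is read as named arguments
    return args, {}


x = {"some_key": "some_tensor"}  # a dictionary that is a genuine model input

# Without the empty dict, x is mistaken for the named-argument dictionary.
print(split_export_args(("k", x)))      # (('k',), {'some_key': 'some_tensor'})

# With the empty dict appended, x stays a positional input.
print(split_export_args(("k", x, {})))  # (('k', {'some_key': 'some_tensor'}), {})
```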
f – a file-like object (has to implement fileno that returns a file descriptor) or a string containing a file name. A binary Protobuf will be written to this file.
export_params (bool, default True) – if True, all parameters will be exported. Set this to False if you want to export an untrained model. In this case, the exported model will first take all of its parameters as arguments, in the order specified by model.state_dict().values()
verbose (bool, default False) – if specified, we will print out a debug description of the trace being exported.
training (enum, default TrainingMode.EVAL) –
TrainingMode.EVAL: export the model in inference mode.
TrainingMode.PRESERVE: export the model in inference mode if model.training is False and to a training friendly mode if model.training is True.
TrainingMode.TRAINING: export the model in a training friendly mode.
input_names (list of strings, default empty list) – names to assign to the input nodes of the graph, in order
output_names (list of strings, default empty list) – names to assign to the output nodes of the graph, in order
aten (bool, default False) – [DEPRECATED. use operator_export_type] export the model in aten mode. If using aten mode, all the ops originally exported by the functions in symbolic_opset<version>.py are exported as ATen ops.
export_raw_ir (bool, default False) – [DEPRECATED. use operator_export_type] export the internal IR directly instead of converting it to ONNX ops.
operator_export_type (enum, default OperatorExportTypes.ONNX) –
OperatorExportTypes.ONNX: All ops are exported as regular ONNX ops (with ONNX namespace).
OperatorExportTypes.ONNX_ATEN: All ops are exported as ATen ops (with aten namespace).
OperatorExportTypes.ONNX_ATEN_FALLBACK: If an ATen op is not supported in ONNX or its symbolic is missing, fall back on ATen op. Registered ops are exported to ONNX regularly. Example graph:
graph(%0 : Float):
  %3 : int = prim::Constant[value=0]()
  %4 : Float = aten::triu(%0, %3) # missing op
  %5 : Float = aten::mul(%4, %0) # registered op
  return (%5)
is exported as:
graph(%0 : Float):
  %1 : Long() = onnx::Constant[value={0}]()
  %2 : Float = aten::ATen[operator="triu"](%0, %1) # missing op
  %3 : Float = onnx::Mul(%2, %0) # registered op
  return (%3)
In the above example, aten::triu is not supported in ONNX, hence the exporter falls back on this op.
OperatorExportTypes.RAW: Export raw ir.
OperatorExportTypes.ONNX_FALLTHROUGH: If an op is not supported in ONNX, fall through and export the operator as is, as a custom ONNX op. Using this mode, the op can be exported and implemented by the user for their runtime backend. Example graph:
graph(%x.1 : Long(1, strides=[1])):
  %1 : None = prim::Constant()
  %2 : Tensor = aten::sum(%x.1, %1)
  %y.1 : Tensor[] = prim::ListConstruct(%2)
  return (%y.1)
is exported as:
graph(%x.1 : Long(1, strides=[1])):
  %1 : Tensor = onnx::ReduceSum[keepdims=0](%x.1)
  %y.1 : Long() = prim::ListConstruct(%1)
  return (%y.1)
In the above example, prim::ListConstruct is not supported, hence the exporter falls through.
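The difference between these modes can be summarized as a per-op decision. The sketch below models only the ONNX, ONNX_ATEN_FALLBACK and ONNX_FALLTHROUGH modes, uses plain strings in place of the OperatorExportTypes enum, and relies on a made-up table of registered symbolics; exported_name and REGISTERED_SYMBOLICS are hypothetical names, not exporter internals.

```python
# Hypothetical table of ops that have a symbolic function registered.
REGISTERED_SYMBOLICS = {"aten::mul": "onnx::Mul", "aten::sum": "onnx::ReduceSum"}


def exported_name(op, mode):
    """Return the node kind an op would be exported as under the given mode."""
    if op in REGISTERED_SYMBOLICS:
        return REGISTERED_SYMBOLICS[op]  # registered ops are exported to ONNX regularly
    short_name = op.split("::", 1)[1]
    if mode == "ONNX_ATEN_FALLBACK" and op.startswith("aten::"):
        return 'aten::ATen[operator="%s"]' % short_name  # fall back on the ATen op
    if mode == "ONNX_FALLTHROUGH":
        return op  # exported as is, to be implemented by the runtime backend
    raise RuntimeError("ONNX export failed: Couldn't export operator %s" % short_name)


print(exported_name("aten::mul", "ONNX_ATEN_FALLBACK"))
print(exported_name("aten::triu", "ONNX_ATEN_FALLBACK"))
print(exported_name("prim::ListConstruct", "ONNX_FALLTHROUGH"))
```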
opset_version (int, default is 9) – by default we export the model to the opset version of the onnx submodule. Since ONNX’s latest opset may evolve before next stable release, by default we export to one stable opset version. Right now, supported stable opset version is 9. The opset_version must be _onnx_main_opset or in _onnx_stable_opsets which are defined in torch/onnx/symbolic_helper.py
do_constant_folding (bool, default True) – If True, the constant-folding optimization is applied to the model during export. Constant-folding optimization will replace some of the ops that have all constant inputs with pre-computed constant nodes.
example_outputs (tuple of Tensors, default None) – Model’s example outputs being exported. example_outputs must be provided when exporting a ScriptModule or TorchScript Function.
strip_doc_string (bool, default True) – if True, strips the field "doc_string" from the exported model, which contains information about the stack trace.
dynamic_axes (dict<string, dict<python:int, string>> or dict<string, list(int)>, default empty dict) –
a dictionary to specify dynamic axes of input/output, such that:
- KEY: input and/or output names
- VALUE: index of dynamic axes for the given key and potentially the name to be used for exported dynamic axes. In general the value is defined in one of the following ways, or a combination of both:
(1). A list of integers specifying the dynamic axes of the provided input. In this scenario automated names will be generated and applied to dynamic axes of the provided input/output during export.
OR
(2). An inner dictionary that specifies a mapping FROM the index of a dynamic axis in the corresponding input/output TO the name to be applied to that axis during export.
Example. If we have the following shape for inputs and outputs:
shape(input_1) = ('b', 3, 'w', 'h') and shape(input_2) = ('b', 4) and shape(output) = ('b', 'd', 5)
Then dynamic axes can be defined either as:
ONLY INDICES:
dynamic_axes = {'input_1':[0, 2, 3], 'input_2':[0], 'output':[0, 1]}
where automatic names will be generated for exported dynamic axes
INDICES WITH CORRESPONDING NAMES:
dynamic_axes = {'input_1':{0:'batch', 2:'width', 3:'height'}, 'input_2':{0:'batch'}, 'output':{0:'batch', 1:'detections'}}
where provided names will be applied to exported dynamic axes
MIXED MODE OF (1) and (2):
dynamic_axes = {'input_1':[0, 2, 3], 'input_2':{0:'batch'}, 'output':[0,1]}
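The list form and the dictionary form carry the same information once names are filled in. The sketch below converts a mixed specification into the dictionary form; normalize_dynamic_axes is a hypothetical helper, and the generated-name pattern is a placeholder assumed for illustration, not necessarily the one the exporter produces.

```python
def normalize_dynamic_axes(dynamic_axes):
    """Convert list-style entries into the {axis_index: name} form."""
    normalized = {}
    for tensor_name, axes in dynamic_axes.items():
        if isinstance(axes, dict):
            normalized[tensor_name] = dict(axes)  # names were provided explicitly
        else:
            # Placeholder naming scheme, assumed for illustration only.
            normalized[tensor_name] = {
                axis: "%s_dynamic_axis_%d" % (tensor_name, axis) for axis in axes
            }
    return normalized


mixed = {'input_1': [0, 2, 3], 'input_2': {0: 'batch'}, 'output': [0, 1]}
for name, axes in normalize_dynamic_axes(mixed).items():
    print(name, axes)
```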
keep_initializers_as_inputs (bool, default None) –
If True, all the initializers (typically corresponding to parameters) in the exported graph will also be added as inputs to the graph. If False, then initializers are not added as inputs to the graph, and only the non-parameter inputs are added as inputs.
This may allow for better optimizations (such as constant folding etc.) by backends/runtimes that execute these graphs. If unspecified (default None), then the behavior is chosen automatically as follows. If operator_export_type is OperatorExportTypes.ONNX, the behavior is equivalent to setting this argument to False. For other values of operator_export_type, the behavior is equivalent to setting this argument to True. Note that for ONNX opset version < 9, initializers MUST be part of graph inputs. Therefore, if the opset_version argument is set to 8 or lower, this argument will be ignored.
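The defaulting rules above amount to a small decision function. The sketch below restates them in plain Python, with strings standing in for the OperatorExportTypes enum; resolve_keep_initializers_as_inputs is a hypothetical name and the code is a reading of the description, not the exporter's implementation.

```python
def resolve_keep_initializers_as_inputs(keep_initializers_as_inputs,
                                        operator_export_type,
                                        opset_version):
    """Return whether initializers end up as graph inputs, per the rules above."""
    if opset_version < 9:
        # Initializers must be graph inputs for old opsets; the argument is ignored.
        return True
    if keep_initializers_as_inputs is None:
        # Automatic choice: False for plain ONNX export, True for other export types.
        return operator_export_type != "ONNX"
    return keep_initializers_as_inputs


print(resolve_keep_initializers_as_inputs(None, "ONNX", 9))                # False
print(resolve_keep_initializers_as_inputs(None, "ONNX_ATEN_FALLBACK", 9))  # True
print(resolve_keep_initializers_as_inputs(False, "ONNX", 8))               # True
```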
custom_opsets (dict<string, int>, default empty dict) – A dictionary to indicate custom opset domain and version at export. If the model contains a custom opset, it is optional to specify the domain and opset version in the dictionary:
- KEY: opset domain name
- VALUE: opset version
If the custom opset is not provided in this dictionary, the opset version is set to 1 by default.
enable_onnx_checker (bool, default True) – If True the onnx model checker will be run as part of the export, to ensure the exported model is a valid ONNX model.
use_external_data_format (bool, default False) – If True, then the model is exported in ONNX external data format, in which case some of the model parameters are stored in external binary files and not in the ONNX model file itself. See link for format details: https://github.com/onnx/onnx/blob/8b3f7e2e7a0f2aba0e629e23d89f07c7fc0e6a5e/onnx/onnx.proto#L423 Also, in this case, argument 'f' must be a string specifying the location of the model. The external binary files will be stored in the same location specified by the model location 'f'. If False, then the model is stored in regular format, i.e. model and parameters are all in one file. This argument is ignored for all export types other than ONNX.