Exporting to ONNX using the torch.onnx API


Overview

ONNX is an open standard for representing machine learning models that reduces friction between the many available frameworks. It benefits machine learning developers by helping keep their models usable in the long run. This article is a tutorial on integrating PyTorch with ONNX, in which we learn to export PyTorch models to ONNX using the torch.onnx API.

Introduction

PyTorch is one of the most popular deep learning libraries, used heavily by researchers and by industry developers to build deep learning applications.

The final result of developing and saving neural networks in PyTorch is a .pt or .pth file that can be loaded back for inference.

However, taking the .pt file of a model trained in a Python environment into production can be challenging. To deploy such models with other tools or frameworks, or in high-performance environments like C++, we need to convert the .pt model files into a format that the vast majority of machine learning tools can consume.

ONNX provides us with such a format; we will look at it next.

What is ONNX?

The AI developer community is growing faster than ever, with new tools being developed to make AI more accessible to users coming from all domains with varied use cases.

Developers worldwide are building tools and frameworks that make implementing and working with neural networks as easy as possible. However, as the ecosystem expands with popular libraries like PyTorch, CNTK, MXNet, Caffe2, and so on, it often brings interoperability issues: friction between toolchains leads ML practitioners to lock themselves into a single framework or tool.

ONNX stands for Open Neural Network Exchange and is an open format built to represent machine learning models. It works to solve this by defining a common set of operators that are the building blocks of machine learning and deep learning models and a common file format, enabling more of these tools to work together by enforcing portability.

Because models can be shared in this common format, developers can use them with a variety of frameworks, tools, runtimes, and compilers, which promotes collaborative innovation and development in the AI sector.

ONNX therefore enhances interoperability by letting developers smoothly combine the right tools for their project, so that the transition from research to implementation is as quick as possible, without bottlenecks from toolchains.

Apart from interoperability, ONNX models can be run by multiple accelerated runtimes, which provide optimized execution of neural networks during inference and deployment.

Let us now have a brief walkthrough of the technical design of ONNX.

Technical Design of ONNX

ONNX defines an extensible computation graph model (computation graphs are at the core of building neural networks), along with definitions of built-in operators and standard data types.

Each computation graph within ONNX is structured as a list of nodes that form an acyclic graph. Nodes have one or more inputs and one or more outputs, and each node is a call to an operator.

Operators are implemented externally to the graph, and the built-in operators defined within ONNX are portable (usable) across different frameworks and libraries.

Every ML framework that supports ONNX provides implementations of these operators for the applicable data types.

The graph also contains metadata to document its purpose, author, etc.

Exporting PyTorch Models to ONNX using the torch.onnx API

We will now learn about the torch.onnx API, which is used to export PyTorch models to ONNX. The exported model can then be used with any runtime that supports ONNX. A full list of ONNX-supporting runtimes is available here.

Create a Network

We will begin by creating a neural network in PyTorch by subclassing the torch.nn.Module class.

The network consists of three linear layers, with the ReLU activation function between them to introduce non-linearity.
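A minimal sketch of such a network follows. The class name MLP and the hidden size of 50 are assumptions, since the original listing is not shown; the 4 input features and 3 output classes match the Iris data set used in the next step.

```python
import torch
import torch.nn as nn


class MLP(nn.Module):
    """Three linear layers with ReLU activations in between (hidden size 50 is assumed)."""

    def __init__(self, in_features: int = 4, hidden: int = 50, num_classes: int = 3):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        # Raw scores (logits): no activation after the last layer.
        return self.fc3(x)


model = MLP()
print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 3])
```

The last layer returns raw scores (logits) rather than probabilities, which is what loss functions such as nn.CrossEntropyLoss expect.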

Prepare a Data Set

Let us now prepare the data set to be fed into the network for training. We load the Iris data set from scikit-learn's datasets module, one-hot encode the labels using np.zeros, and then split the data set.

We will use 25% of the data set as test data to measure prediction accuracy on unseen data. The remaining 75% will be used to train the neural network model.

We will also convert the features X and the corresponding true labels y into PyTorch tensors. Older code wrapped tensors in autograd.Variable for this, but Variable has been deprecated since PyTorch 0.4, and plain tensors now support autograd directly.
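The preparation steps can be sketched as follows, assuming scikit-learn is installed. The random_state value is an arbitrary choice for reproducibility, and the variable names are illustrative.

```python
import numpy as np
import torch
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
X, labels = iris.data, iris.target  # X: (150, 4) features, labels: (150,) ints in 0..2

# One-hot encode with np.zeros: row i gets a 1 in column labels[i].
y = np.zeros((labels.shape[0], 3), dtype=np.float32)
y[np.arange(labels.shape[0]), labels] = 1.0

# Hold out 25% of the samples as test data (random_state=42 is an arbitrary choice).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Plain float32 tensors are all PyTorch needs; no autograd.Variable wrapper.
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32)

print(X_train.shape, y_train.shape)  # torch.Size([112, 4]) torch.Size([112, 3])
```

With 150 Iris samples, a 25% test split leaves 112 training samples and 38 test samples.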

Instantiate the Network

Let us now create an instance of our custom model class and define the loss function (criterion) for backpropagating the errors and the optimization algorithm for updating the learnable model parameters.

We will be using the Adam optimizer here from the torch.optim package.
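A sketch of this step, assuming cross-entropy loss and a learning rate of 0.01 (the article does not show its exact choices). To keep the block runnable on its own, the three-layer network is rebuilt compactly with nn.Sequential, again with an assumed hidden size of 50.

```python
import torch
import torch.nn as nn

# Three linear layers with ReLU in between; hidden size 50 is an assumption.
model = nn.Sequential(
    nn.Linear(4, 50), nn.ReLU(),
    nn.Linear(50, 50), nn.ReLU(),
    nn.Linear(50, 3),
)

# Assumed choices: cross-entropy loss on the logits, and Adam with lr=0.01.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
```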

Loop to Update Parameters

Let us now train the network by feeding it the training data for a fixed number of epochs.
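The loop below is a minimal, self-contained sketch of full-batch training on Iris. The number of epochs, the learning rate, the random seeds, and the use of integer class labels are assumptions: nn.CrossEntropyLoss accepts class indices, and one-hot targets such as those built with np.zeros can be converted back with y.argmax(dim=1).

```python
import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

torch.manual_seed(0)

# Iris features and integer class labels, with 25% held out for testing.
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.25, random_state=42
)
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

# Assumed architecture (hidden size 50), loss, and optimizer settings.
model = nn.Sequential(
    nn.Linear(4, 50), nn.ReLU(),
    nn.Linear(50, 50), nn.ReLU(),
    nn.Linear(50, 3),
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

num_epochs = 200  # assumed value
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()                      # clear gradients from the previous step
    loss = criterion(model(X_train), y_train)  # forward pass on the whole training set
    loss.backward()                            # backpropagate the error
    optimizer.step()                           # update the learnable parameters
    if epoch % 50 == 0:
        print(f"epoch {epoch}: loss {loss.item():.4f}")

# Evaluate on the held-out test data.
model.eval()
with torch.no_grad():
    accuracy = (model(X_test).argmax(dim=1) == y_test).float().mean().item()
print(f"test accuracy: {accuracy:.2f}")
```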

Export the Model to ONNX & Compress

We will now use the torch.onnx API to export the PyTorch model into an ONNX file that can be run with ONNX Runtime or any other runtime that supports the ONNX format. After the export and compression steps, two new files should be in the current working directory: pytorch_mlp.onnx and pytorch_mlp.onnx.tgz.

Benefits of Using ONNX

Let us now glance through the benefits of having a torch.onnx API for exporting PyTorch models to ONNX format -

  • Deploy anywhere - Converting PyTorch models to ONNX lets us take them out of Python notebooks and deploy them in the cloud, on desktop, mobile, and IoT devices, and even in the browser.
  • Lower latency, higher throughput - ONNX runtimes apply graph optimizations and hardware acceleration that can reduce execution costs and improve the user experience.
  • Python not required - PyTorch models are usually trained through the Python API, and requiring Python at inference time is a significant obstacle to deploying them in many production environments, especially on Android and iOS mobile devices. ONNX Runtime is optimized for production environments and provides APIs in C/C++, C#, Java, and Objective-C, bridging PyTorch models trained in Python to deployment in high-performance environments.

Avoiding Pitfalls with torch.onnx

  • Avoid NumPy and built-in Python types - PyTorch models can be written using NumPy or Python types and functions, but during the torch.onnx.export call, any variable that is not a torch.Tensor (such as a NumPy array or a Python number) is converted to a constant. This produces wrong results if the value should change between inference runs or depends on the value of the input tensor.

For example, an operation on NumPy arrays inside the model should be replaced by its PyTorch equivalent to avoid this pitfall.
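The classic torch.onnx exporter records the model by tracing, the same mechanism behind torch.jit.trace, so the effect can be sketched without writing an ONNX file. The two normalization functions below are hypothetical examples: one computes the scale in NumPy, the other with tensor operations.

```python
import numpy as np
import torch


def normalize_numpy(x: torch.Tensor) -> torch.Tensor:
    # Bad: the maximum is computed in NumPy, outside the tracer's view,
    # so the value from the example input is recorded as a constant.
    scale = float(np.abs(x.numpy()).max())
    return x / scale


def normalize_torch(x: torch.Tensor) -> torch.Tensor:
    # Good: abs() and max() are tensor operations recorded in the graph.
    return x / x.abs().max()


example = torch.tensor([2.0, 4.0])
traced_bad = torch.jit.trace(normalize_numpy, example)
traced_good = torch.jit.trace(normalize_torch, example)

new_input = torch.tensor([1.0, 2.0])
print(traced_bad(new_input))   # tensor([0.2500, 0.5000]): still divides by 4
print(traced_good(new_input))  # tensor([0.5000, 1.0000])
```

The NumPy version keeps dividing by 4, the maximum of the example input, no matter what input it receives later.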

Similarly, calling .item() on a single-element tensor causes the same problem, because it returns a Python number that is recorded as a constant. Instead, it is recommended to rely on PyTorch's implicit casting of single-element tensors, which keeps the value as a variable in the traced graph.
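A sketch of the .item() problem, again using torch.jit.trace to stand in for the tracing done by the classic exporter. The reshape functions are hypothetical; rows is a single-element integer tensor giving the number of rows.

```python
import torch


def reshape_with_item(x: torch.Tensor, rows: torch.Tensor) -> torch.Tensor:
    # Bad: .item() returns a Python int, which the tracer records as a constant.
    return x.reshape(rows.item(), -1)


def reshape_with_tensor(x: torch.Tensor, rows: torch.Tensor) -> torch.Tensor:
    # Good: the single-element tensor is cast implicitly, so it stays a
    # variable in the traced graph.
    return x.reshape(rows, -1)


x = torch.arange(6.0)
traced_bad = torch.jit.trace(reshape_with_item, (x, torch.tensor(2)))
traced_good = torch.jit.trace(reshape_with_tensor, (x, torch.tensor(2)))

print(traced_bad(x, torch.tensor(3)).shape)   # torch.Size([2, 3]): rows frozen at 2
print(traced_good(x, torch.tensor(3)).shape)  # torch.Size([3, 2])
```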

  • Avoid Tensor.data - Using the Tensor.data field can produce an incorrect trace during the export call and, therefore, an incorrect ONNX graph. Use torch.Tensor.detach() instead.
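The preferred pattern can be sketched with a hypothetical module that multiplies its input by a detached copy of itself, so no gradient flows through the second factor. Because detach() is an ordinary tensor operation, it is recorded in the trace, and the traced module keeps computing from its actual input.

```python
import torch
import torch.nn as nn


class StopGradScale(nn.Module):
    """Hypothetical module: scales x by a copy of itself that carries no gradient."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Avoid: x * x.data, since reading .data can yield an incorrect trace.
        # Prefer: x.detach(), which is recorded as an operation in the graph.
        return x * x.detach()


module = StopGradScale()
traced = torch.jit.trace(module, torch.tensor([1.0, 2.0]))
print(traced(torch.tensor([3.0])))  # tensor([9.]): computed from the new input
```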

Limitations of ONNX

  • Data Types

    • Due to limited support for nested sequences in ONNX, certain operations involving tuples and lists are not supported. In particular, appending a tuple to a list is not supported.
    • Any computation that depends on the value of a dictionary or a string input will be replaced with the constant value seen during the single traced execution.
    • Any output in the form of a Python dictionary object will be silently replaced with a flattened sequence of its values (and the keys will be removed). So, for example {"a": 1, "b": 2} becomes (1, 2).
    • Any output that is str will be silently removed.
  • Differences in Operator Implementations - Because operators are implemented differently in different runtimes, the same exported model may produce slightly different results on different runtimes, or compared with PyTorch. These differences are normally numerically small and should not be a cause for concern.

  • Unsupported Tensor Indexing Patterns - Some indexing patterns for PyTorch tensors cannot be exported. A full list is available here.

  • Unsupported Operators - Exporting a model fails with an error if it contains an operator that has no built-in symbolic function (the function that translates a PyTorch operator into ONNX operators). Although such operators are not supported natively, they can still be exported by registering a custom symbolic function.

Conclusion

In this article, we learned about the utility of the torch.onnx API. In particular,

  • We first worked through an introduction highlighting the importance of having a common interface for machine learning models.
  • We were introduced to ONNX and looked at its technical design.
  • We then learned to use the torch.onnx API and understood the benefits of exporting PyTorch models to ONNX.
  • We also learned about the various pitfalls and limitations of the torch.onnx API.
