Converting PyTorch to TorchScript

Overview

Deploying machine learning models in production is one of the most challenging and crucial steps in developing machine learning-based applications. Production environments often demand optimized computation. To this end, this article explores the TorchScript mode, one of the two ways PyTorch can be used to develop deep neural networks. We will learn through code examples how to convert PyTorch models to TorchScript mode, along with a theoretical introduction to the components associated with the TorchScript mode in PyTorch.

Introduction to the PyTorch Ecosystem

The PyTorch ecosystem offers two modes for different purposes - the eager mode and the script mode.

The eager mode is suitable for research settings and is built for faster prototyping of model architectures, training, and experimentation.

The other mode is called the script mode, which is suitable for deployment purposes and consists of two components - PyTorch JIT and TorchScript. Before looking into these two components, let us examine why the script mode is important.

Importance of the Script Mode

The script mode serves two purposes that make PyTorch models easily deployable in production -

  1. Portability - Script mode converts our model code into a format that is independent of Python and therefore does not require a Python runtime to execute, removing the dependence on the Python interpreter in production.

  2. Performance - PyTorch JIT uses runtime information to optimize TorchScript modules, automating optimizations such as layer fusion, quantization, and sparsification.

Let us now understand what TorchScript and PyTorch JIT exactly are.

What is TorchScript?

TorchScript is a statically typed subset of Python optimized for machine learning models; it provides only the subset of types needed to express neural nets.

Code in TorchScript can either be written directly (using the @torch.jit.script decorator) or generated automatically from Python code via tracing (we will learn more about this shortly).

When writing code in TorchScript mode directly using the @torch.jit.script decorator, one must use only the subset of Python operators supported in TorchScript. Refer to Builtin Functions for a complete reference of the PyTorch tensor methods, modules, and functions available in TorchScript mode.
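For instance, a minimal sketch of a function written directly in TorchScript with the decorator might look as follows (the function and its logic are arbitrary illustrations) -

```python
import torch

# A standalone function compiled directly to TorchScript. Only the
# subset of Python supported by TorchScript may appear in its body.
@torch.jit.script
def scaled_relu(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Python control flow like this if-statement is preserved
    # in the compiled graph.
    if x.dim() > 1:
        x = x.flatten()
    return torch.relu(x) * scale

print(scaled_relu(torch.randn(2, 3), 2.0))
print(scaled_relu.code)  # the TorchScript-compiled version of the source
```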

Importance of TorchScript in Deep Learning

TorchScript is a way to create serializable and optimizable models from PyTorch code written in Python. Models can be saved as a TorchScript program from a Python process, and the saved models can be loaded back into a process without Python dependency.

  • Hence, TorchScript helps remove the Python dependency: a model written as a pure Python program can be transitioned to a TorchScript program that runs independently of Python, allowing us to use high-performance environments such as C++ in production.

  • This allows us to train our models in Python and export them using TorchScript to some other production environment.

  • Since Python programs are likely to cause performance and multi-threading issues in production setups, TorchScript becomes very useful in migrating our models from Python to another runtime for deployment.

  • Essentially, TorchScript does this by providing tools to capture the definition of our model, even though PyTorch relies on dynamic rather than static graph creation - that is, graphs in PyTorch are flexible and created on the fly as computations are performed.

  • Other than this, code in the TorchScript format can be invoked in its own restricted Python interpreter that does not acquire the Global Interpreter Lock (GIL), and so can process many requests on the same instance simultaneously.

Let us now understand what PyTorch JIT Compiler is.

What is PyTorch JIT Compiler?

The PyTorch JIT (just-in-time) compiler consumes TorchScript code and performs runtime optimization on our model’s computation. It is an optimized compiler that features the following -

  • It is a thread-safe interpreter.
  • It supports easy-to-write custom transformations.
  • It can be used for more than just inference; it also supports auto differentiation.

Hence, modules that are compatible with JIT can be compiled rather than interpreted, allowing various optimizations and improved performance during both model development (training) and model deployment (inference).

There are two broad ways to make our PyTorch modules compatible with JIT, that is, to convert them to the TorchScript mode - tracing via the torch.jit.trace API and scripting via the torch.jit.script API. Tracing is a little easier to use than scripting but comes at the cost of some limitations.

Methods to Convert PyTorch to TorchScript

Tracing

Tracing is an export technique that runs our model with certain inputs and records all executed operations into the model's graph.

The API can be simply used as torch.jit.trace(model, input).

A model is called "traceable" if torch.jit.trace(model, input) succeeds for typical inputs.

A simple example of tracing in PyTorch follows. We first define a custom model class and instantiate it, then trace the model instance by passing a sample input to it (the model class and layer sizes below are arbitrary choices for illustration), like so -
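```python
import torch
import torch.nn as nn

# An arbitrary model used purely for illustration.
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))

model = MyModel()

# A sample input whose operations will be recorded during tracing.
sample_input = torch.randn(1, 4)

# torch.jit.trace runs the model once and records the executed tensor ops.
traced_model = torch.jit.trace(model, sample_input)

# The traced module is callable just like the original.
print(traced_model(sample_input))
```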

During tracing, the Python code is automatically converted into the TorchScript subset of Python by recording only the actual operations on tensors; the surrounding Python code is simply executed and discarded.

torch.jit.trace invokes the Module, records the computations that occur when the Module is run on the inputs, and then creates an instance of torch.jit.ScriptModule - essentially, the plain Python code converted to TorchScript mode.

TorchScript also records the model definition in what is called an Intermediate Representation (IR), or graph, which we can access with the .graph property of the traced model, like so -
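(Continuing with the traced_model from the sketch above.)

```python
# The IR (graph) recorded during tracing - verbose and low-level.
print(traced_model.graph)
```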

A much cleaner, Python-syntax interpretation of the same code can be accessed via the .code property, like so -
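```python
# A human-readable, Python-like view of the same recorded computation.
print(traced_model.code)
```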

Generalizability

We also need to ensure that the traced model "generalizes" to inputs different from those given during tracing. We can verify this by comparing the output produced by the traced model with the output produced by the plain Python module on a new input, like so -
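(This check reuses the model and traced_model defined in the tracing sketch above.)

```python
# A new input that was not used during tracing.
new_input = torch.randn(5, 4)

with torch.no_grad():
    out_eager = model(new_input)
    out_traced = traced_model(new_input)

# If tracing generalized correctly, the two outputs should match.
print(torch.allclose(out_eager, out_traced))  # True
```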

Generalizability during tracing is crucial and needs to be ensured.

Let us now look at the other way of converting PyTorch modules to TorchScript format - the scripting technique.

Scripting

The second way of converting PyTorch modules to TorchScript format is scripting, which can be used with the torch.jit.script API.

With scripting, we write our code directly in TorchScript mode, which introduces a certain level of verbosity. In exchange, the support offered by the scripting technique is much wider than that offered by the tracing technique.

A simple example demonstrating scripting, and why scripting might be required over tracing, follows. The model below contains data-dependent control flow and is an arbitrary illustration -
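```python
import torch
import torch.nn as nn

# An illustrative module whose forward pass branches on the *values*
# of its input tensor.
class ControlFlowModel(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return -x

model = ControlFlowModel()

# Tracing runs the model once, so only the branch taken for this
# particular input gets recorded; PyTorch emits a TracerWarning
# about the tensor-to-bool conversion in the if-condition.
traced = torch.jit.trace(model, torch.randn(3))
print(traced.code)  # the if/else is gone - only one branch survives
```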

As the TracerWarning emitted during the tracing attempt above indicates, the control flow is erased in the traced model, and hence the IR is incorrect. This happens because of how the tracing technique functions - it runs the model code, records the operations that happen, and then constructs a ScriptModule that does just that, so operations like control flow are removed.

To get around this, we have scripting, which directly analyzes our model's Python source code to transform it into TorchScript mode.

The following code shows how scripting captures the model graph correctly, continuing with the ControlFlowModel instance from the sketch above -
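```python
# Scripting compiles the Python source itself, so the data-dependent
# branch is preserved in the compiled graph.
scripted = torch.jit.script(model)
print(scripted.code)  # both branches of the if/else are present
```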

To holistically understand the differences between tracing and scripting, move on to the next section, which explores them.

Also, there is another way to mix both approaches to leverage the best of both, as sketched below. The official tutorials cover a nice example here.
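A minimal, hypothetical sketch of such mixing: a traced submodule embedded inside a scripted module (the module names and shapes are arbitrary) -

```python
import torch
import torch.nn as nn

class Backbone(nn.Module):
    def forward(self, x):
        return torch.relu(x)

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        # Trace the control-flow-free submodule...
        self.backbone = torch.jit.trace(Backbone(), torch.randn(3))

    def forward(self, x):
        # ...while scripting the wrapper preserves this
        # data-dependent branch.
        out = self.backbone(x)
        if out.sum() > 0:
            return out
        return out + 1

mixed = torch.jit.script(Wrapper())
print(mixed.code)
```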

Difference Between Scripting and Tracing

In this section, we will detail the major points that differentiate the two techniques for converting PyTorch modules to the TorchScript format - tracing and scripting - and highlight the benefits of using one over the other.

  • Tracing lets us use dynamic Python code at trace time since it only records tensor operations; however, it cannot capture control flow, data structures, or other Python constructs in the exported graph.

  • On the other hand, scripting, with some code changes, supports all of the features that are compatible with the JIT compiler, a full list of which can be found here. In addition, it preserves the Python control flow and offers wider support for data structures like lists or dictionaries.

  • The generalizability of traced models needs to be ensured explicitly, while scripted models are always generalizable.

  • Although scripting is a good way to support advanced graphs containing control flow, there are plenty of things that the JIT compiler does not support, like classes, builtins like range and zip, dynamic types, etc. This limits our ability to use abstract types and advanced features of Python as a programming language, which means our code can get messy more often than not.

In any case, for scripting and tracing to work properly, the model must be a connected graph representable in the TorchScript format.

Saving and Loading TorchScript Modules

Saving and loading the scripted and traced models can be done using the torch.jit.save and torch.jit.load functions (or the module's own save method), like so -
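(Continuing with the traced_model and scripted modules from the earlier sketches; the file names are arbitrary.)

```python
# Serialize the compiled modules to disk.
traced_model.save("traced_model.pt")           # equivalent to torch.jit.save(...)
torch.jit.save(scripted, "scripted_model.pt")

# Load them back - no Python model definition is required.
loaded = torch.jit.load("traced_model.pt")
print(loaded(torch.randn(1, 4)))
```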

Conclusion

Let us now conclude with the key points we studied in this article -

  • We first understood the two modes of model execution offered by the PyTorch ecosystem - the eager mode and the script mode, and understood the importance of each in deep learning.
  • After this, we understood what TorchScript is, its importance in the PyTorch ecosystem, and how it can be used with the PyTorch JIT Compiler to deploy PyTorch models developed in Python to a high-performance environment.
  • After this, we thoroughly understood the two ways to convert models written in plain Python using PyTorch to the TorchScript mode - the tracing and the scripting techniques - and learned to implement each through code examples.
  • We also understood the difference between these two ways of converting models to TorchScript mode while highlighting the pros and cons of each.
  • Lastly, we learned how the TorchScript models could be saved and loaded back for later use.