TorchScript: Tracing vs. Scripting
PyTorch provides two methods to turn an `nn.Module` into a graph represented in TorchScript format: tracing and scripting.
This article will:

- Compare their pros and cons, with a focus on useful tips for tracing.
- Try to convince you that `torch.jit.trace` should be preferred over `torch.jit.script` for deployment of non-trivial models.
The second point might be an uncommon opinion: if I Google "tracing vs scripting", the first article that comes up recommends scripting as the default. But tracing has many advantages. In fact, by the time I left, "tracing as default, scripting only when necessary" was the strategy with which all detection & segmentation models in Facebook/Meta products were deployed.
Why is tracing better? TL;DR: (i) it does not damage code quality; (ii) its main limitations can be addressed by mixing in scripting.
Terminology
We start by disambiguating some common terminology:
- Export: refers to the process that turns a model written in eager-mode Python code into a graph that describes the computation.
- Tracing: an export method. It runs a model with certain inputs, and "traces / records" all the operations that are executed into a graph. `torch.jit.trace` is an export API that uses tracing, used like `torch.jit.trace(model, input)`. See its tutorial and API.
- Scripting: another export method. It parses the Python source code of the model, and compiles the code into a graph. `torch.jit.script` is an export API that uses scripting, used like `torch.jit.script(model)`. See its tutorial and API.
- TorchScript: an overloaded term.
  - It often refers to the representation / format of the exported graph.
  - But it sometimes refers to the scripting export method.

  To avoid confusion, I'll never use "TorchScript" alone in this article. I'll use "TS-format" to refer to the format, and "scripting" to refer to the export method. Because this term is used ambiguously, it may have created the impression that "scripting" is the "official / preferred" way to create a TS-format model. But that's not necessarily true.
- (Torch)Scriptable: a model is "scriptable" if `torch.jit.script(model)` succeeds, i.e. it can be exported by scripting.
- Traceable: a model is "traceable" if `torch.jit.trace(model, input)` succeeds for a typical input.
- Generalize: a traced model (the object returned by `trace()`) "generalizes" to other inputs (different from the inputs given during tracing) if it runs inference correctly on those other inputs. Scripted models always generalize.
- Dynamic control flow, or data-dependent control flow: control flow where the operators to be executed depend on the input data. For a `Tensor` x,

  ```python
  if x[0] == 4:
      x += 1
  ```

  is a dynamic control flow. Neither of the following is a dynamic control flow:

  ```python
  model: nn.Sequential = ...
  for m in model:  # loops over a fixed list of submodules, not over data
      x = m(x)

  class A(nn.Module):
      backbone: nn.Module
      head: Optional[nn.Module]

      def forward(self, x):
          x = self.backbone(x)
          if self.head is not None:  # depends on model structure, not on input data
              x = self.head(x)
          return x
  ```
The Cost of Scriptability
If anyone says "we'll make Python better by writing a compiler for it", you should immediately be alarmed and know that this is extremely difficult. Python is too big and too dynamic. A compiler can only support a subset of its syntax features and builtins, at best -- the scripting compiler in PyTorch is no exception.
What subset of Python does this compiler support? A rough answer is: the compiler has good support for the most basic syntax, but medium to no support for anything more complicated (classes, builtins like `range` and `zip`, dynamic types, etc.). But there is no clear answer: even the developers of the compiler usually need to run the code to see whether it can be compiled or not.
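For example, something as small as a lambda is rejected. A minimal sketch (the exact error type and message vary across PyTorch versions):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        f = lambda t: t * 2  # lambdas are not supported by the scripting compiler
        return f(x)

torch.jit.script(M())  # raises an UnsupportedNodeError: lambdas aren't supported
```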
The incomplete Python compiler limits how users can write code. Though there isn't a clear list of constraints, I can tell from my experience what impact they have had on large projects: code quality is the cost of scriptability.
Impact on Most Projects
To make their code scriptable / compilable by the scripting compiler, most projects choose to stay on the "safe side" and use only the basic syntax of Python: no/few custom structures, no builtins, no inheritance, no `Union`, no `**kwargs`, no lambdas, no dynamic types, etc.
This is because these "advanced" compiler features are either not supported at all, or have "partial support" that is not robust enough: they may work in some cases but fail in others. And because there is no clear spec of what is supported, users are unable to reason about or work around the failures. Therefore, users eventually move to, and stay on, the safe side.
The terrible consequence is that developers stop making abstractions and stop exploring useful language features due to concerns about scriptability.
A related hack that many projects do is to rewrite part of the code for scripting: create a separate, inference-only forward codepath that makes the compiler happy. This also makes the project harder to maintain.
Impact on Detectron2
Detectron2 supports scripting, but the story there was a bit different: its code quality, which we value a lot in research, did not go downhill. Instead, with some creativity and direct support from the PyTorch team (and some volunteered help from Alibaba engineers), we managed to make most models scriptable without removing any abstractions.
However, it was not an easy task: we had to add dozens of syntax fixes to the compiler, find creative workarounds, and develop some hacky patches in detectron2 that live in this file (which honestly could affect maintainability in the long term). I would not recommend that other large projects aim for "scriptability without losing abstractions" unless they are also closely supported by the PyTorch team.
Recommendation
If you think "scripting seems to work for my project, so let's embrace it", I would advise against it for the following reasons, based on my past experience with a few projects that support scripting:
- What "works" might be more brittle than you think (unless you limit yourself to the basic syntax): your code might happen to compile now, but one day you'll make a few innocent changes to your model and find that the compiler refuses them.
- Basic syntax is not enough: even if more complex abstractions don't appear necessary to your project at the moment, a project that is expected to grow will require more language features in the future.
  Take a multi-task detector for example:

  - There could be tens of inputs, so it's preferable to use some structures/classes.
  - The same data can have different representations (e.g. different ways to represent a segmentation mask), which demands `Union` or more dynamic types.
  - There are many architectural choices of a detector, which makes inheritance useful.

  Large, growing projects definitely need evolving abstractions to stay healthy.
- Code quality could severely deteriorate: ugly code starts to accumulate, because clean code sometimes just doesn't compile. And due to the syntax limitations of the compiler, abstractions cannot easily be built to clean up the ugliness. The health of the project gradually goes downhill.

Below is a complaint from the PyTorch issue tracker. The issue itself is just one small papercut of scripting, but similar complaints have been heard many times. The status quo is: scripting forces you to write ugly code, so use it only when necessary.
Make a Model Trace and Generalize
The Cost of Traceability
What it takes to make a model traceable is very clear, and has a much smaller impact on code health.
- First, neither scripting nor tracing works if the model is not even a proper single-device, connected graph representable in TS-format. For example, if the model has `DataParallel` submodules, or if the model converts tensors to numpy arrays and calls OpenCV functions, etc., you'll have to refactor it first; one possible refactor is sketched below.
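  As an illustration, a minimal sketch (the helper name `unwrap_data_parallel` is made up for this example) that strips `DataParallel` wrappers so the model becomes a single-device graph:

  ```python
  import torch
  from torch import nn

  def unwrap_data_parallel(module: nn.Module) -> nn.Module:
      # Recursively replace every DataParallel wrapper with the module it wraps,
      # so the model is a plain single-device graph before export.
      for name, child in module.named_children():
          if isinstance(child, nn.DataParallel):
              child = child.module
              setattr(module, name, child)
          unwrap_data_parallel(child)
      return module
  ```

  Apart from this obvious constraint, there are only two extra requirements for traceability.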
- Input/output format: the model's inputs/outputs have to be `Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]` or their nested combinations. Note that all values in a dict have to be of the same type.

  Similar constraints exist for scripting as well. However, in tracing the constraint does not apply to submodules: submodules can use any input/output format -- dicts of Any, classes, kwargs, anything that Python supports. Only the top-level model is required to use the constrained format. This makes the constraint very easy to satisfy: if a model uses richer formats, just create a simple wrapper around it that converts to/from `Tuple[Tensor]`. Detectron2 even automates this for all its models with a universal wrapper, like this:

  ```python
  outputs = model(inputs)  # inputs/outputs are rich structures, e.g. dicts or classes
  # torch.jit.trace(model, inputs)  # FAIL! unsupported format
  adapter = TracingAdapter(model, inputs)
  traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # can now trace the model
  # The traced model can only produce flattened outputs (a tuple of tensors):
  flattened_outputs = traced(*adapter.flattened_inputs)
  # The adapter knows how to convert them back to the rich structure (new_outputs == outputs):
  new_outputs = adapter.outputs_schema(flattened_outputs)
  ```

  "Automatically Flatten & Unflatten Nested Containers" has more details on how this adapter is implemented.
- Symbolic shapes: expressions like `tensor.size(0)`, `tensor.size()[1]`, `tensor.shape[2]` are integers in eager mode, but `Tensor`s in tracing mode. This difference is necessary so that during tracing, shape computations can be captured as symbolic operations in the graph. An example is given in the next section about generalization.

  Due to the different return types, a model may be untraceable if parts of it assume shapes are integers. This can usually be fixed quite easily by handling both types in the code. A helpful function is `torch.jit.is_tracing`, which checks whether the code is executed in tracing mode; a sketch of its use follows.
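  For example, a minimal sketch (the helper `padded_width` is made up for this illustration) that handles both the int case and the Tensor case:

  ```python
  import math
  import torch

  def padded_width(x: torch.Tensor, stride: int = 32):
      # Round the last dimension up to a multiple of `stride`.
      size = x.shape[-1]  # an int in eager mode, a Tensor during tracing
      if torch.jit.is_tracing():
          # Keep the arithmetic symbolic so it is recorded in the graph:
          return torch.div(size + stride - 1, stride, rounding_mode="floor") * stride
      # Plain Python arithmetic is fine in eager mode:
      return int(math.ceil(size / stride)) * stride
  ```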
That's all it takes for traceability. Most importantly, any Python syntax is allowed in the model's implementation, because tracing does not care about syntax at all.
Generalization Problem
Just being "traceable" is not sufficient, however. The biggest problem with tracing is that it may not generalize to other inputs. This happens in the following cases:
- Dynamic control flow:

  ```python
  >>> def f(x):
  ...     return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
  >>> m = torch.jit.trace(f, torch.tensor(3))
  >>> print(m.code)
  def f(x: Tensor) -> Tensor:
    return torch.sqrt(x)
  ```

  In this example, due to dynamic control flow, the trace only keeps one branch of the condition, and will not generalize to certain (negative) inputs.
- Capture variables as constants:

  ```python
  >>> a, b = torch.rand(1), torch.rand(2)
  >>> def f1(x): return torch.arange(x.shape[0])
  >>> def f2(x): return torch.arange(len(x))
  >>> # See if the two traces generalize from a to b:
  >>> torch.jit.trace(f1, a)(b)
  tensor([0, 1])
  >>> torch.jit.trace(f2, a)(b)
  tensor([0])  # WRONG!
  >>> # Why does f2 not generalize? Let's compare their code:
  >>> print(torch.jit.trace(f1, a).code, torch.jit.trace(f2, a).code)
  def f1(x: Tensor) -> Tensor:
    _0 = ops.prim.NumToTensor(torch.size(x, 0))
    _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
    return _1
  def f2(x: Tensor) -> Tensor:
    _0 = torch.arange(1, dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
    return _0
  ```

  Intermediate computation results of a non-Tensor type (in this case, an int) may be captured as constants, using the value observed during tracing. This causes the trace to not generalize. In addition to `len()`, this issue can also appear in:

  - `.item()`, which converts tensors to int/float.
  - Any other code that converts torch types to numpy/python primitives.
  - A few problematic operators, e.g. advanced indexing.
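  For instance, a minimal sketch of avoiding an `.item()` capture:

  ```python
  import torch

  def normalize(x: torch.Tensor) -> torch.Tensor:
      # BAD: `x.abs().max().item()` is a Python float, which tracing would
      # freeze into a constant taken from the example input:
      #   scale = x.abs().max().item()
      # GOOD: keep it a Tensor, so the division stays symbolic in the graph:
      scale = x.abs().max()
      return x / scale
  ```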
- Capture device:

  ```python
  >>> def f(x):
  ...     return torch.arange(x.shape[0], device=x.device)
  >>> m = torch.jit.trace(f, torch.tensor([3]))
  >>> print(m.code)
  def f(x: Tensor) -> Tensor:
    _0 = ops.prim.NumToTensor(torch.size(x, 0))
    _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
    return _1
  >>> m(torch.tensor([3]).cuda()).device
  device(type='cpu')  # WRONG!
  ```

  Similarly, operators that accept a `device` argument will remember the device used during tracing (as can be seen in `m.code`), so the trace may not generalize to inputs on a different device. Such generalization is almost never needed, because deployment usually has a fixed target device.
Let Tracing Generalize
The above problems are annoying and often silent (warnings, but no errors), but they can be successfully addressed by good practice and tools:
- Pay attention to `TracerWarning`: in the first two examples above, `torch.jit.trace` actually emits warnings. The first example prints:

  ```
  a.py:3: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
  We can't record the data flow of Python values, so this value will be treated as a constant in the future.
  This means that the trace might not generalize to other inputs!
    if x.sum() > 0:
  ```

  Paying attention to these warnings (or, even better, catching them, as sketched below) will expose most generalization problems of tracing. Note that the "capture device" case does not print warnings, because tracing was not designed to support such generalization at all.
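  A minimal sketch of "catching them": escalate `TracerWarning`s to errors during export (assuming `model` and `example_inputs` are defined) so that generalization problems fail loudly:

  ```python
  import warnings
  import torch

  with warnings.catch_warnings():
      warnings.filterwarnings("error", category=torch.jit.TracerWarning)
      traced = torch.jit.trace(model, example_inputs)  # raises on any TracerWarning
  ```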
- Unittests for parity: unittests should be run after export and before deployment, to verify that the exported model produces the same outputs as the original eager-mode model, i.e.

  ```python
  assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
  ```

  If generalization across shapes is needed (it is not always needed), `input2` should have different shapes from `input1`. Detectron2 has many such generalization tests, e.g. this and this. Once a gap is found, inspecting the code of the exported TS-format model can uncover the place where it fails to generalize. A fleshed-out version of such a test is sketched below.
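  For instance, a minimal self-contained parity test (a sketch; `make_model` and the input shapes are placeholders):

  ```python
  import torch

  def test_trace_parity():
      model = make_model().eval()          # placeholder model factory
      input1 = torch.rand(1, 3, 224, 224)  # shape used for tracing
      input2 = torch.rand(2, 3, 320, 320)  # a different shape, to test generalization
      traced = torch.jit.trace(model, (input1,))
      with torch.no_grad():
          expected = model(input2)
          actual = traced(input2)
      assert torch.allclose(expected, actual, atol=1e-6)
  ```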
- Avoid unnecessary "special case" conditions: avoid conditions like

  ```python
  if x.numel() > 0:
      output = self.layers(x)
  else:
      output = torch.zeros((0, C, H, W))  # create empty outputs
  ```

  that handle special cases such as empty inputs. Instead, improve `self.layers` or its underlying kernel so that it supports empty inputs. This results in cleaner code and also improves tracing. It is why I'm involved in many PyTorch issues that improve support for empty inputs, such as #12013, #36530, #56998. Most PyTorch operations work perfectly with empty inputs, so such branching is hardly ever needed.
Use symbolic shapes: As mentioned earlier,
tensor.size()
returnsTensor
during tracing, so that shape computations are captured in the graph. Users should avoid accidentally turning tensor shapes into constants:- Use
tensor.size(0)
instead oflen(tensor)
because the latter is an int. For custom classes, implement a.size
method or use.__len__()
instead oflen()
, e.g. like here. - Do not convert sizes by
int()
ortorch.as_tensor
because they will capture constants. This helper function is useful to convert sizes into a tensor, in a way that works in both tracing and eager mode.
- Use
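  A sketch of such a helper, modeled loosely on the one in detectron2 (the name `shapes_to_tensor` follows that project; details may differ):

  ```python
  from typing import List, Optional
  import torch

  def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> torch.Tensor:
      if torch.jit.is_tracing():
          # During tracing the "ints" are actually scalar Tensors; stacking them
          # keeps the shape computation symbolic in the graph.
          return torch.stack([torch.as_tensor(s, device=device) for s in x])
      # In eager mode they are plain ints, so building a constant tensor is correct.
      return torch.as_tensor(x, device=device)
  ```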
- Mix tracing and scripting: they can be mixed together, so you can use scripting for the small portion of code where tracing does not work correctly. This can fix almost all of tracing's problems. More on this below.
Mix Tracing and Scripting
Tracing and scripting both have their own problems, and the best solution is usually to mix them together. This gives us the best of both worlds.
To minimize the negative impact on code quality, we should use tracing for the majority of logic, and use scripting only when necessary.
- Use `@script_if_tracing`: inside `torch.jit.trace`, the `@torch.jit.script_if_tracing` decorator can compile functions by scripting. Typically, this only requires a small refactor of the forward logic to separate out the parts that need to be compiled (the parts with control flow):

  ```python
  def forward(self, ...):
      # ... some forward logic
      @torch.jit.script_if_tracing
      def _inner_impl(x, y, z, flag: bool):
          # ... use control flow, etc.
          return ...
      output = _inner_impl(x, y, z, flag)
      # ... other forward logic
  ```

  By scripting only the parts that need it, the damage to code quality is strictly smaller than making the entire model scriptable, and the module's forward interface is not affected at all.
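  As a concrete illustration, a sketch that reuses the earlier `sqrt` / `square` example: the decorator makes the data-dependent branch generalize under tracing:

  ```python
  import torch

  @torch.jit.script_if_tracing
  def sqrt_or_square(x: torch.Tensor) -> torch.Tensor:
      # Data-dependent control flow, compiled by scripting when traced:
      if bool(x.sum() > 0):
          return torch.sqrt(x)
      return torch.square(x)

  def f(x):
      return sqrt_or_square(x)

  m = torch.jit.trace(f, torch.tensor([3.0]))
  print(m(torch.tensor([-3.0])))  # tensor([9.]) -- both branches are preserved
  ```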
  The function decorated by `@script_if_tracing` has to be a pure function that does not contain modules, so sometimes a bit more refactoring is needed.

  Before refactoring:

  ```python
  # This branch cannot be compiled by @script_if_tracing,
  # because it refers to `self.layers`:
  if x.numel() > 0:
      x = preprocess(x)
      output = self.layers(x)
  else:
      # Create empty outputs
      output = torch.zeros(...)
  ```

  After refactoring:

  ```python
  # This branch can be compiled by @script_if_tracing:
  if x.numel() > 0:
      x = preprocess(x)
  else:
      # Create empty inputs
      x = torch.zeros(...)
  # Need to make sure self.layers accepts empty inputs.
  # If necessary, add such a condition branch into self.layers as well.
  output = self.layers(x)
  ```

  In fact, for most vision models, dynamic control flow is needed only in a few submodules that are easy to make scriptable. To show how rarely it is needed: the entire detectron2 has only two functions decorated with `@script_if_tracing` due to control flow, paste_masks and heatmaps_to_keypoints, both for post-processing only. A few other functions are also decorated in order to generalize across devices (a very rare requirement).
- Use scripted / traced submodules:

  ```python
  model.submodule = torch.jit.script(model.submodule)
  torch.jit.trace(model, inputs)
  ```

  In this example, supposing `submodule` cannot be traced correctly, we script it before tracing. However, I do not recommend this. If possible, use `@script_if_tracing` inside `submodule.forward` instead, so that scripting is limited to the internals of the submodule, without affecting the submodule's interface.

  Similarly,

  ```python
  model.submodule = torch.jit.trace(model.submodule, submodule_inputs)
  torch.jit.script(model)
  ```

  uses a traced submodule during scripting. This looks nice, but is not so useful in practice: it constrains the interface of `submodule`, requiring it to accept/return only `Tuple[Tensor]` -- a big constraint that might hurt code quality even more than scripting.

  One rare scenario where "tracing a submodule" is useful is this:

  ```python
  class A(nn.Module):
      def forward(self, x):
          # Dispatch to different submodules based on a dynamic, data-dependent condition:
          return self.submodule1(x) if x.sum() > 0 else self.submodule2(x)
  ```

  `@script_if_tracing` cannot compile such control flow, because it only supports pure functions. If `submodule{1,2}` are complex and cannot be scripted, using traced submodules in a scripted parent `A` is the best option.
- Merge multiple traces: scripted models support two more features that traced models don't:

  - Control flow conditioned on attributes: a scripted module can have mutable attributes (e.g. a boolean flag) that affect control flow. Traced modules have no control flow at all.
  - Multiple methods: a traced module only supports `forward()`, but a scripted module can have multiple methods.

  Both features are really doing the same thing: they allow an exported model to be used in different ways, i.e. to execute different sequences of operators as requested by the caller.

  Below is an example scenario where such a feature is useful: if `Detector` is scripted, the caller can mutate its `do_keypoint` attribute to control its behavior, or call the `predict_keypoint` method directly if needed.

  ```python
  class Detector(nn.Module):
      do_keypoint: bool

      def forward(self, img):
          box = self.predict_boxes(img)
          if self.do_keypoint:
              kpts = self.predict_keypoint(img, box)

      @torch.jit.export
      def predict_boxes(self, img): pass

      @torch.jit.export
      def predict_keypoint(self, img, box): pass
  ```

  This requirement is not seen very often, but if it's needed, how can tracing achieve it? I have a solution that's not very clean. Tracing can only capture one sequence of operators, so the natural way is to trace the model twice:

  ```python
  det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
  det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
  ```

  We can then alias their weights (to avoid duplicating the storage), and merge the two traces into one module to script:

  ```python
  det2.submodule.weight = det1.submodule.weight

  class Wrapper(nn.ModuleList):
      def forward(self, img, do_keypoint: bool):
          if do_keypoint:
              return self[0](img)
          else:
              return self[1](img)

  exported = torch.jit.script(Wrapper([det1, det2]))
  ```
Performance
If a model is both traceable and scriptable, tracing always generates the same or a simpler graph (and is therefore likely faster).

Why? Because scripting tries to faithfully represent your Python code, even the parts of it that are unnecessary. For example, it is not always smart enough to realize that some loops or data structures in the Python code are actually static and can be removed:
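Here is a minimal sketch of such a case (a toy module made up for illustration), comparing the exported code for a loop over a constant Python list:

```python
import torch
from torch import nn, Tensor

class M(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        ys = []
        for k in [1, 2, 3]:  # a static Python list
            ys.append(x * k)
        return ys[0] + ys[1] + ys[2]

m = M()
# Scripting keeps the list construction and the loop in the graph:
print(torch.jit.script(m).code)
# Tracing unrolls the loop into three plain multiplications:
print(torch.jit.trace(m, torch.rand(2)).code)
```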
This example is very simple, so it actually has workarounds for scripting (use tuple instead of list), or the loop might get optimized in a later optimization pass. But the point is: the graph compiler is not always smart enough. For complicated models, scripting might generate a graph with unnecessary complexity that's hard to optimize.
Concluding Thoughts
Tracing has clear limitations: I spent most of this article talking about the limitations of tracing and how to fix them. I actually think this is the advantage of tracing: it has clear limitations (and solutions), so you can reason about whether it works.
On the contrary, scripting is more like a black box: no one knows if it works before trying. I didn't mention a single trick about how to fix scripting: there are many of them, but it's not worth your time to probe and fix a black box.
Tracing has small blast radius: Both tracing and scripting affect how code can be written, but tracing has a much smaller blast radius, and causes much less damage:
- It limits the input/output format, but on the outer-most module only. (And this issue can be automatically solved as discussed above.)
- It needs some code changes to generalize (e.g. to mix scripting in tracing), but these changes only go into the internal implementation of the affected modules, not their interfaces.
On the other hand, scripting has an impact on:

- The interface of every module & submodule involved.
  - IMO, this is the biggest damage: advanced syntax features are needed in interfaces, and I'm not willing to compromise on interface design.
  - This may end up affecting training as well, because the interface is often shared between training and inference.
- Pretty much every line of code in the inference forward path.
Having a large blast radius is why scripting can do great harm to code quality.
Control flow vs. other Python syntax: PyTorch is loved by its users because they can "just write Python", and most importantly write Python control flow. But other Python syntax is important as well. If being able to write Python control flow (scripting) means losing other great syntax, I'd rather give up the ability to write Python control flow.
In fact, if PyTorch were less obsessed with Python control flow and instead offered me symbolic control flow such as a `torch.cond`, with an API similar to that of `tf.cond`:
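A sketch of what that could look like (a hypothetical `torch.cond(pred, true_fn, false_fn)` mirroring `tf.cond`, which PyTorch did not offer at the time):

```python
def f(x):
    return torch.cond(x.sum() > 0,
                      lambda: torch.sqrt(x),
                      lambda: torch.square(x))
```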
then `f` could be traced correctly, and I would be happy to use it, no longer having to worry about scripting.
TensorFlow AutoGraph is a great example that automates this idea.