Automatically Flatten & Unflatten Nested Containers
This post is about a small piece of functionality that turns out to be useful across TensorFlow / JAX / PyTorch.

Low-level components of these systems often use a plain list of values/tensors as inputs & outputs. However, end-users who develop models often want to work with more complicated data structures: Dict[str, Any], List[Any], custom classes, and their nested combinations. Therefore, we need bidirectional conversion between nested structures and a plain list of tensors. I found that different libraries invent similar approaches to solve this problem, and it's interesting to list them here.
Nested Containers Are Useful Abstractions
Though many simple deep learning models just need a few input/output tensors, nested containers are useful abstractions in advanced models. This is because many concepts are naturally represented by more than one tensor, e.g.:
- A sparse tensor consists of values + indices
- A masked tensor (common in transformers) is represented by a tensor + its binary mask
- A segmentation mask can be represented in different ways:
- single whole-image bitmask tensor
- shape + bounding box + mask within the box (aka "RoIMask")
- shape + list of polygons
- shape + run-length encoding
- Detected objects in an image are represented by boxes + scores + labels + many possible attributes
- A list of variable-length vectors may be represented by a concatenated vector + a length vector, i.e.:
[[1, 2, 3], [42], [6, 6]] --> [1, 2, 3, 42, 6, 6], [3, 1, 2]
When a frequently-used concept has natural complexity like the above, representing it in a flat structure (e.g. Dict[str, Tensor]) consisting of only regular tensors may result in ugly code. A multi-level nested structure sometimes becomes helpful. Take the sparse tensor as a simple example:
| | Use nested containers | Use a flat Dict[str, Tensor] |
|---|---|---|
| Representation | {"a": SparseTensor}, where SparseTensor can be a namedtuple/dataclass, or a new class | {"a_values": Tensor, "a_indices": Tensor} |
| Sanity check | The SparseTensor class can guarantee both tensors exist and follow certain contracts (e.g. their shapes match) | Need to check that a_values and a_indices co-exist in the dict |
| Pass to another function | Pass x["a"] directly | Extract x["a_values"], x["a_indices"] and pass both |
| Operations | The SparseTensor class can have methods that work like regular tensors, e.g. y = x["a"] + 1 | Need to implement many new functions, e.g. y = add_sparse(x["a_values"], x["a_indices"], 1) |
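As a concrete sketch of the left column (names and contracts here are purely illustrative, not any library's actual class), a minimal SparseTensor wrapper could look like this:

```python
from dataclasses import dataclass
import torch

@dataclass
class SparseTensor:
    values: torch.Tensor   # shape (N,)
    indices: torch.Tensor  # shape (N, ndim)

    def __post_init__(self):
        # The class itself enforces the contract between the two tensors.
        assert self.values.shape[0] == self.indices.shape[0]

    def __add__(self, other):
        # Operations can look like regular tensor math.
        return SparseTensor(self.values + other, self.indices)

x = {"a": SparseTensor(torch.rand(5), torch.randint(0, 10, (5, 2)))}
y = x["a"] + 1   # no need for an add_sparse(values, indices, 1) helper
```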
Bidirectional Conversion
Despite the benefits, lower-level stacks often ignore these abstractions and choose a "flat" interface: their inputs & outputs are a flat list of values / Tensors. This is because: (i) the abstraction may no longer be useful at a lower level; (ii) a simple structure simplifies their implementation; (iii) a flat list is a data structure available even in lower-level languages & systems.
Therefore, conversion from a nested structure to a plain list of values is important.
This is often referred to as "flatten".
It is pretty straightforward to flatten a container recursively -- like the following flatten function:
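A minimal sketch (covering only lists, tuples, and dicts; a real implementation handles more types):

```python
from typing import Any, List

def flatten(obj) -> List[Any]:
    # Recursively collect leaf values; dict leaves are ordered by key.
    if isinstance(obj, (list, tuple)):
        return [leaf for item in obj for leaf in flatten(item)]
    if isinstance(obj, dict):
        return [leaf for key in sorted(obj) for leaf in flatten(obj[key])]
    return [obj]  # a leaf value

x, y, z = 1, 2, 3
obj = [x, {"key": (y, z)}]
print(flatten(obj))  # [1, 2, 3]
```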
The inverse of flatten is also important: given new values [x2, y2, z2], we want the unflatten function below to construct an obj2 that has the same structure as obj.
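A sketch of the intended usage; the ??? placeholder (intentionally not valid code) stands for "some representation of obj's structure", which is the topic of the next section:

```python
x2, y2, z2 = 4.0, 5.0, 6.0
obj2 = unflatten(???, [x2, y2, z2])
# obj2 == [x2, {"key": (y2, z2)}] -- same structure as obj, new values
```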
unflatten is a very handy utility. For example, to create a clone of obj on a different device, we simply do this:
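A sketch, assuming the leaves are PyTorch tensors and using the same ??? structure placeholder as above:

```python
leaves = flatten(obj)
obj_cuda = unflatten(???, [t.to("cuda") for t in leaves])  # a clone of obj on GPU
```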
Without unflatten, every such functionality needs to be reimplemented as a recursive function, like PyTorch's pin_memory.
Implementation of unflatten
How do we implement unflatten? Clearly, we need to give it some representation of the structure (noted as the placeholder ??? in the code above).
There are two high-level approaches to solve this problem:
- Schema-based: when flattening a container, explicitly record its structure/schema to be used for unflatten. Its API may look like this:

  ```python
  >>> from jax.tree_util import tree_flatten, tree_unflatten
  >>> obj = [3, ([5, 6], {"name": [7, 9], "name2": 3})]
  >>> res, schema = tree_flatten(obj)
  >>> res  # Flattened results:
  [3, 5, 6, 7, 9, 3]
  >>> schema  # An explicit representation of the container's structure
  PyTreeDef([*, ([*, *], {'name': [*, *], 'name2': *})])
  >>> # Construct a nested container using the given values and the structure/schema:
  >>> tree_unflatten(schema, [1, 2, 3, 4, 5, 6])
  [1, ([2, 3], {'name': [4, 5], 'name2': 6})]
  ```

  Examples: Detectron2's flatten_to_tuple, TensorFlow's FetchMapper, JAX's pytree.
- Schema-less: use the entire nested container as an implicit representation of structure. Its interface looks like this:

  ```python
  >>> import tensorflow as tf
  >>> obj = [3, ([5, 6], {"name": [7, 9], "name2": 3})]
  >>> tf.nest.flatten(obj)  # Flattened results:
  [3, 5, 6, 7, 9, 3]
  >>> # Construct a nested container that has the same structure as obj, using the given list of values:
  >>> tf.nest.pack_sequence_as(obj, [1, 2, 3, 4, 5, 6])
  [1, ([2, 3], {'name': [4, 5], 'name2': 6})]
  ```

  Examples: TensorFlow's tf.nest, DeepMind's dm-tree.
The two approaches have some pros and cons:
- The schema-less approach has a simpler API and implementation.
- The schema-based approach likely has a more memory-efficient representation of the schema, compared to using an entire container as the schema.
- An explicit schema representation allows more functionality to be added by customizing the representation.
Applications

JAX Pytree
JAX's low-level components accept/return flat tensors, so functions can be transformed and optimized more easily. Since end-users need nested containers, JAX transformations support pytree containers, which by default include flattening & unflattening for common Python containers. It further allows users to register custom classes with register_pytree_node.
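For example, a custom class can be registered roughly like this (the SparseTensor class is only an illustration):

```python
from jax.tree_util import register_pytree_node

class SparseTensor:
    def __init__(self, values, indices):
        self.values, self.indices = values, indices

register_pytree_node(
    SparseTensor,
    lambda st: ((st.values, st.indices), None),     # flatten: (children, aux data)
    lambda aux, children: SparseTensor(*children),  # unflatten: rebuild from children
)
```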
Pytree uses a schema-based implementation that we already showcased above.
When we need to independently process each leaf of the container, JAX provides another handy function, tree_map:
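For instance (an illustrative sketch):

```python
import jax.numpy as jnp
from jax.tree_util import tree_map

obj = {"a": jnp.ones(3), "b": [jnp.zeros(2), 5.0]}
doubled = tree_map(lambda x: x * 2, obj)  # applied to every leaf
# roughly: {'a': Array([2., 2., 2.]), 'b': [Array([0., 0.]), 10.0]}
```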
PyTorch also has a similar implementation of pytree here, which is used in its FX tracing.
Detectron2 TracingAdapter
torch.jit.trace(model, inputs) executes the model with the given inputs, and returns a graph representation of the model's execution. This is one of the most common ways (and IMO the best) that PyTorch models are exported today.
However, it limits the model's input & output formats. In order to trace models with more complicated inputs & outputs, I created the TracingAdapter tool in detectron2, which flattens/unflattens a model's inputs and outputs into a simple Tuple[Tensor] to make the model traceable.
A minimal implementation of it may look like this:
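The sketch below is approximate rather than detectron2's actual code; flatten is assumed to be a schema-based helper like the one above, whose returned schema is callable for unflattening:

```python
import torch

class TracingAdapter(torch.nn.Module):
    def __init__(self, model, inputs):
        super().__init__()
        self.model = model
        # Flatten the example inputs; keep the schema to rebuild them in forward().
        self.flattened_inputs, self.inputs_schema = flatten(inputs)

    def forward(self, *args):
        # args is a flat tuple of tensors -> rebuild the nested inputs
        inputs = self.inputs_schema(args)
        outputs = self.model(inputs)
        # Flatten the nested outputs into a flat tuple, remembering their schema
        flattened_outputs, self.outputs_schema = flatten(outputs)
        return flattened_outputs

# Usage sketch:
#   adapter = TracingAdapter(model, example_inputs)
#   traced = torch.jit.trace(adapter, adapter.flattened_inputs)
```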
where flatten uses a schema-based implementation that can be found in this file. Coincidentally, its interface looks like JAX's pytree:
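Roughly (an approximate sketch of the interface, not verbatim code):

```python
# Schema-based flatten, analogous to jax.tree_util.tree_flatten:
res, schema = flatten_to_tuple(obj)
# The returned schema is callable, analogous to tree_unflatten(schema, values):
obj2 = schema(new_values)
```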
Perception models in Meta accept a wide range of input/output formats:
they may take any number of images plus auxiliary data as inputs, and
predict boxes, masks, keypoints or any other interesting attributes as outputs.
But deployment prefers a flat interface for optimizability and interoperability.
TracingAdapter's automatic flattening and unflattening mechanism has freed engineers from writing format-conversion glue code when deploying these models.
In addition to deployment, TracingAdapter is also useful in a few other places to smooth the experience of torch.jit.trace:

- Flop counting: fvcore's flop counter uses tracing to obtain a graph of operators. To let it support counting of complex models, wrapping the model with TracingAdapter is the easiest way.
- Tensorboard graph visualization: PyTorch's tensorboard writer has an add_graph method that visualizes the graph structure in tensorboard. The method requires flattened inputs, therefore TracingAdapter can be used like this.
- ONNX export: PyTorch's ONNX export is also based on tracing, so TracingAdapter is useful as well, e.g. here.
TensorFlow tf.nest
tf.nest.flatten and tf.nest.pack_sequence_as implement schema-less flattening and unflattening.
The unflatten function requires a container: it flattens this container on the fly while simultaneously "packing" the flat values into the structure of this container. Here is an example in the spirit of the official documentation (note that dict values are ordered by keys):
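```python
import tensorflow as tf

structure = {"key3": "", "key1": "", "key2": ""}   # values here are placeholders
flat_sequence = ["value1", "value2", "value3"]
packed = tf.nest.pack_sequence_as(structure, flat_sequence)
# Keys are processed in sorted order, so:
#   packed["key1"] == "value1", packed["key2"] == "value2", packed["key3"] == "value3"
```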
tf.nest.{flatten,pack_sequence_as} are widely used in TensorFlow because many low-level components have a flat interface, especially for interop with C APIs.
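The typical pattern looks roughly like this (run_flat is a hypothetical stand-in for a low-level call that only understands flat lists):

```python
import tensorflow as tf

def run_flat(flat_inputs):
    # Hypothetical low-level entry point: flat list in, flat list out.
    return [x * 2 for x in flat_inputs]

def run(nested_inputs):
    flat = tf.nest.flatten(nested_inputs)
    flat_outputs = run_flat(flat)
    # Pack the flat outputs back into the same structure as the inputs.
    return tf.nest.pack_sequence_as(nested_inputs, flat_outputs)

run({"a": 1, "b": [2, 3]})   # {'a': 2, 'b': [4, 6]}
```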
tf.nest.map_structure has the same functionality as JAX's tree_map.
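For example:

```python
import tensorflow as tf

tf.nest.map_structure(lambda x: x * 2, [3, {"a": (1, 2)}])
# [6, {'a': (2, 4)}]
```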
TensorFlow FetchMapper
TFv1's session.run(fetches) supports fetching nested containers, as demonstrated by this example adapted from the official documentation:
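```python
import collections
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # TFv1-style graph mode

a = tf.constant([10, 20])
b = tf.constant([1.0, 2.0])
MyData = collections.namedtuple("MyData", ["a", "b"])

with tf.Session() as sess:
    # fetches can be an arbitrarily nested structure of lists, tuples,
    # namedtuples and dicts; the result mirrors that structure:
    v = sess.run({"k1": MyData(a, b), "k2": [b, a]})
    # v["k1"].a -> array([10, 20]); v["k2"] -> [array([1., 2.]), array([10, 20])]
```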
This powerful interface exists only in TF's Python client. The client interacts with the C API's TF_SessionRun, which only accepts a plain array of inputs/outputs. Therefore, the client needs to:
- Flatten the container to a plain array of tensors
- Send this array to the C API to obtain an array of results
- Unflatten / reconstruct the container using the results
The flatten/unflatten logic uses a schema-based implementation in the client's FetchMapper. This implementation is a bit more complicated due to an extra guarantee that the flattened tensors are unique. (This is to ensure the client won't fetch the same tensor twice in one call; this cannot be done by using tf.nest.)
In addition to builtin Python containers, FetchMapper supports a few other TF containers (such as SparseTensor) and can be extended to new containers by registering conversion functions.
DeepMind tree library
DeepMind has a tree library as a standalone alternative to tf.nest:
| deepmind/tree | tf.nest | jax.tree_util |
|---|---|---|
| tree.flatten | tf.nest.flatten | jax.tree_util.tree_flatten |
| tree.unflatten_as | tf.nest.pack_sequence_as | jax.tree_util.tree_unflatten |
| tree.map_structure | tf.nest.map_structure | jax.tree_util.tree_map |