# Automatically Flatten & Unflatten Nested Containers

This post is about a small piece of functionality that I have found useful in TensorFlow / JAX / PyTorch.

Low-level components of these systems often use a plain list of values/tensors as inputs & outputs. However, end-users who develop models often want to work with more complicated data structures: Dict[str, Any], List[Any], custom classes, and their nested combinations. Therefore, we need bidirectional conversion between nested structures and a plain list of tensors. I found that different libraries have invented similar approaches to this problem, so it's interesting to list them here.

## Nested Containers Are Useful Abstractions

Though many simple deep learning models just need a few input/output tensors, nested containers are useful abstractions in advanced models. This is because many concepts are naturally represented by more than one tensor, e.g.:

• A sparse tensor consists of values + indices
• A masked tensor (common in transformers) is represented by a tensor + its binary mask
• A segmentation mask can be represented in different ways:
  • shape + bounding box + mask within the box (aka "RoIMask")
  • shape + list of polygons
  • shape + run-length encoding
• Detected objects in an image are represented by boxes + scores + labels + many possible attributes
• A list of variable-length vectors may be represented by a concatenated vector + a vector of lengths
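The last representation can be sketched in plain Python. This is only an illustration: lists stand in for tensors, and to_ragged and from_ragged are hypothetical helper names.

```python
from itertools import accumulate
from typing import List, Tuple


def to_ragged(rows: List[List[float]]) -> Tuple[List[float], List[int]]:
    """Concatenate the rows and record each row's length."""
    values = [x for row in rows for x in row]
    lengths = [len(row) for row in rows]
    return values, lengths


def from_ragged(values: List[float], lengths: List[int]) -> List[List[float]]:
    """Split the concatenated values back into rows using the lengths."""
    ends = list(accumulate(lengths))
    starts = [0] + ends[:-1]
    return [values[s:e] for s, e in zip(starts, ends)]


values, lengths = to_ragged([[1, 2], [3], [4, 5, 6]])
print(values, lengths)  # [1, 2, 3, 4, 5, 6] [2, 1, 3]
```

The two lists only make sense together. Keeping such pairs consistent is the burden that the nested containers discussed next take off the user.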

When a frequently used concept has natural complexity like the examples above, representing it in a flat structure of only regular tensors (e.g. Dict[str, Tensor]) may result in ugly code. A multi-level nested structure is sometimes more helpful. Take a sparse tensor as a simple example:

| | Use nested containers | Use a flat Dict[str, Tensor] |
| --- | --- | --- |
| Representation | {"a": SparseTensor, "b": Tensor}, where SparseTensor can be a namedtuple/dataclass, or a new class | {"a_values": Tensor, "a_indices": Tensor, "b": Tensor} |
| Sanity check | The SparseTensor class can guarantee that both tensors exist and follow certain contracts (e.g. their shapes match) | Need to check that a_{values,indices} co-exist in the dict |
| Pass to another function | Pass x["a"] directly | Extract x["a_values"] and x["a_indices"], and pass both |
| Operations | The SparseTensor class can have methods that work like regular tensors, e.g. y = x["a"] + 1 | Need to implement many new functions, e.g. y = add_sparse(x["a_values"], x["a_indices"], 1) |
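The left column of the table can be sketched as a small class. This is a minimal, pure-Python illustration with several assumptions. Tuples stand in for tensors, and size is a hypothetical field holding the dense length. The + operator applies only to the stored values, which is one possible convention; a real sparse library may define it differently.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class SparseTensor:
    """Toy sparse vector: values[i] sits at position indices[i] of a dense vector of length size."""
    values: Tuple[float, ...]
    indices: Tuple[int, ...]
    size: int

    def __post_init__(self) -> None:
        # The contract is checked once here, so code receiving a SparseTensor can rely on it.
        if len(self.values) != len(self.indices):
            raise ValueError("values and indices must have the same length")
        if any(not 0 <= i < self.size for i in self.indices):
            raise ValueError("index out of range")

    def __add__(self, scalar: float) -> "SparseTensor":
        # Assumed convention: apply the op to the stored values only.
        return SparseTensor(tuple(v + scalar for v in self.values), self.indices, self.size)

    def to_dense(self) -> List[float]:
        dense = [0.0] * self.size
        for v, i in zip(self.values, self.indices):
            dense[i] = v
        return dense


x = {"a": SparseTensor(values=(1.0, 2.0), indices=(0, 3), size=5), "b": [0.5, 0.5]}
y = x["a"] + 1          # the caller never handles values/indices separately
print(y.to_dense())     # [2.0, 0.0, 0.0, 3.0, 0.0]
```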

## Bidirectional Conversion

Despite these benefits, lower-level stacks often ignore such abstractions and choose a "flat" interface: their inputs & outputs are a flat list of values/tensors. This is because: (i) the abstraction may no longer be useful at lower levels; (ii) a simple structure simplifies their implementation; (iii) a flat list is a data structure available even in lower-level languages & systems.

Therefore, converting a nested structure to a plain list of values is important. This is often referred to as "flatten", and it is straightforward to implement by recursing into the container.
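A minimal recursive flatten might look like the following sketch. It makes two assumptions. First, only lists, tuples and dicts count as containers, and everything else is a leaf. Second, dict values are visited in sorted-key order, as tf.nest does, so that the result does not depend on insertion order.

```python
from typing import Any, List


def flatten(obj: Any) -> List[Any]:
    """Recursively collect the leaves of nested lists, tuples and dicts."""
    if isinstance(obj, (list, tuple)):
        return [leaf for item in obj for leaf in flatten(item)]
    if isinstance(obj, dict):
        # Sorted keys make the order independent of dict insertion order.
        return [leaf for key in sorted(obj) for leaf in flatten(obj[key])]
    return [obj]  # anything else (e.g. a tensor) is a leaf


print(flatten({"b": [1, 2], "a": (3, {"c": 4})}))  # [3, 4, 1, 2]
```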

The inverse of flatten is also important. Suppose flatten(obj) returns [x, y, z]. Given new values [x2, y2, z2], we want an unflatten function that constructs obj2 with the same structure as obj.

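unflatten is a very handy utility. For example, to create a clone of obj on a different device, we simply move every value in flatten(obj) to that device and unflatten the results. No code specific to the structure of obj is needed.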

Without unflatten, every such functionality needs to be reimplemented as a recursive function, like PyTorch's pin_memory.

## Implementation of unflatten

How do we implement unflatten? Clearly, besides the new values, it needs some representation of the structure to rebuild. There are two high-level approaches to this problem:

• Schema-based: when flattening a container, explicitly record its structure/schema to be used later by unflatten. Its API may look like values, schema = flatten(obj), followed by obj2 = unflatten(new_values, schema).

Examples: Detectron2's flatten_to_tuple, TensorFlow's FetchMapper, JAX's pytree.

• Schema-less: use the entire nested container as an implicit representation of structure. Its interface may look like values = flatten(obj), followed by obj2 = unflatten(new_values, obj).

Examples: TensorFlow's tf.nest, DeepMind's dm-tree.
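The following pure-Python sketch shows the schema-based option. The schema format is hypothetical and chosen for readability: None marks a leaf, and (type, children) marks a container. Real libraries such as JAX use their own representations. Dict values are visited in sorted-key order, and dicts are rebuilt as plain dicts.

```python
from typing import Any, List, Tuple

# Hypothetical schema format for this sketch:
#   None                                    -> a leaf
#   (list | tuple | namedtuple, [schemas])  -> a sequence
#   (dict, [(key, schema), ...])            -> a dict, keys in sorted order

def flatten(obj: Any) -> Tuple[List[Any], Any]:
    """Return the leaves of obj and an explicit schema of its structure."""
    if isinstance(obj, (list, tuple)):
        pairs = [flatten(item) for item in obj]
        return [v for vals, _ in pairs for v in vals], (type(obj), [s for _, s in pairs])
    if isinstance(obj, dict):
        keys = sorted(obj)
        pairs = [flatten(obj[k]) for k in keys]
        return ([v for vals, _ in pairs for v in vals],
                (dict, [(k, s) for k, (_, s) in zip(keys, pairs)]))
    return [obj], None


def unflatten(values: List[Any], schema: Any) -> Any:
    """Rebuild a container with the structure recorded in schema from new leaf values."""
    it = iter(values)

    def build(s: Any) -> Any:
        if s is None:
            try:
                return next(it)
            except StopIteration:
                raise ValueError("not enough values for this schema") from None
        kind, children = s
        if kind is dict:
            return {k: build(child) for k, child in children}
        items = [build(child) for child in children]
        # namedtuples take their fields as positional arguments
        return kind(*items) if hasattr(kind, "_fields") else kind(items)

    result = build(schema)
    end = object()
    if next(it, end) is not end:
        raise ValueError("too many values for this schema")
    return result


obj = {"boxes": [1.0, 2.0], "scores": (0.5, {"label": 3})}
values, schema = flatten(obj)           # values == [1.0, 2.0, 0.5, 3]
obj2 = unflatten([v * 10 for v in values], schema)
print(obj2)  # {'boxes': [10.0, 20.0], 'scores': (5.0, {'label': 30})}
```

The clone-on-another-device idea follows the same pattern: replace v * 10 with a device transfer, such as PyTorch's t.to(device). In the schema-less variant, obj itself would be passed to unflatten in place of schema.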

The two approaches have some pros and cons:

• The schema-less approach has simpler API and implementation.
• The schema-based approach likely has a more memory-efficient representation of schema, compared to using an entire container as schema.
• An explicit schema representation allows more functionalities to be added by customizing the representation.

## Applications

### JAX Pytree

JAX's low-level components accept/return flat tensors, so functions can be transformed and optimized more easily. Since end-users need nested containers, JAX transformations support pytree containers, with built-in flattening & unflattening for common Python containers. Users can also register custom classes with register_pytree_node.

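Pytree uses the schema-based approach described above. jax.tree_util.tree_flatten returns the leaves together with a treedef, and jax.tree_util.tree_unflatten(treedef, leaves) rebuilds the container.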
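Custom-class registration can be imitated with a small registry. The sketch below is hypothetical and much simpler than JAX: register_node, flatten and unflatten are names invented here. The callbacks do follow the same shape as register_pytree_node's: a node is split into children plus static auxiliary data, and later rebuilt from both. In this toy, even builtin containers go through the registry. Types are looked up exactly, so unregistered types (including subclasses and namedtuples) are treated as leaves.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple

# type -> (to_children, from_children); a toy in the spirit of register_pytree_node.
_REGISTRY: Dict[type, Tuple[Callable, Callable]] = {}


def register_node(cls: type, to_children: Callable, from_children: Callable) -> None:
    """to_children(obj) -> (children, aux); from_children(aux, children) -> obj."""
    _REGISTRY[cls] = (to_children, from_children)


# Builtin containers use the same mechanism as user classes.
register_node(list, lambda x: (list(x), None), lambda aux, ch: list(ch))
register_node(tuple, lambda x: (list(x), None), lambda aux, ch: tuple(ch))
register_node(dict, lambda d: ([d[k] for k in sorted(d)], tuple(sorted(d))),
              lambda keys, ch: dict(zip(keys, ch)))


def flatten(obj: Any) -> Tuple[List[Any], Any]:
    entry = _REGISTRY.get(type(obj))  # exact type lookup
    if entry is None:
        return [obj], None  # unregistered type: a leaf
    children, aux = entry[0](obj)
    values: List[Any] = []
    child_schemas = []
    for child in children:
        vals, child_schema = flatten(child)
        values.extend(vals)
        child_schemas.append(child_schema)
    return values, (type(obj), aux, child_schemas)


def unflatten(values: List[Any], schema: Any) -> Any:
    it = iter(values)

    def build(s: Any) -> Any:
        if s is None:
            return next(it)
        cls, aux, child_schemas = s
        return _REGISTRY[cls][1](aux, [build(c) for c in child_schemas])

    return build(schema)


@dataclass
class Boxes:
    coords: List[float]
    labels: List[int]
    fmt: str  # static metadata: kept as aux data, not flattened


register_node(Boxes, lambda b: ([b.coords, b.labels], b.fmt),
              lambda fmt, ch: Boxes(ch[0], ch[1], fmt))

values, schema = flatten({"dets": Boxes([1.0, 2.0], [7], "xyxy")})
print(values)                             # [1.0, 2.0, 7]
print(unflatten([0.0, 0.0, 1], schema))   # {'dets': Boxes(coords=[0.0, 0.0], labels=[1], fmt='xyxy')}
```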

When each leaf of a container needs to be processed independently, JAX provides another handy function, tree_map. It applies a function to every leaf and returns a container with the same structure.

PyTorch also has a similar pytree implementation (torch.utils._pytree), which is used in its FX tracing.

### Detectron2 TracingAdapter

torch.jit.trace(model, inputs) executes the model with the given inputs and returns a graph representation of the model's execution. This is one of the most common ways (and the best, IMO) to export PyTorch models today. However, it restricts the model's input & output formats.

To trace models with more complicated inputs & outputs, I created the TracingAdapter tool in detectron2. It flattens/unflattens a model's inputs and outputs into a simple Tuple[Tensor] to make the model traceable. Conceptually, it wraps the model so that the wrapper takes flat inputs, unflattens them into the model's original input format, runs the model, and flattens its outputs.
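The wrapping idea can be sketched in plain Python. This is not detectron2's code: TracingAdapterSketch, flatten and toy_model are hypothetical names. Here the "schema" is a closure that rebuilds the structure, and only lists, tuples and dicts are treated as containers, with tuples (including namedtuples) coming back as plain tuples. In PyTorch, the wrapper would be an nn.Module, and torch.jit.trace would be called on it with the flat inputs.

```python
from typing import Any, Callable, List, Tuple


def flatten(obj: Any) -> Tuple[List[Any], Callable[[List[Any]], Any]]:
    """Return (leaves, rebuild), where rebuild(new_leaves) recreates obj's structure."""
    if isinstance(obj, dict):
        keys = sorted(obj)
        children = [obj[k] for k in keys]
    elif isinstance(obj, (list, tuple)):
        keys, children = None, list(obj)
    else:
        return [obj], lambda leaves: leaves[0]
    leaves: List[Any] = []
    rebuilds, sizes = [], []
    for child in children:
        child_leaves, child_rebuild = flatten(child)
        leaves.extend(child_leaves)
        rebuilds.append(child_rebuild)
        sizes.append(len(child_leaves))
    container = list if isinstance(obj, list) else tuple

    def rebuild(new_leaves: List[Any]) -> Any:
        parts, start = [], 0
        for child_rebuild, n in zip(rebuilds, sizes):
            parts.append(child_rebuild(new_leaves[start:start + n]))
            start += n
        return dict(zip(keys, parts)) if keys is not None else container(parts)

    return leaves, rebuild


class TracingAdapterSketch:
    """Expose model(inputs) -> outputs through a flat-tuple interface."""

    def __init__(self, model: Callable[[Any], Any], example_inputs: Any):
        self.model = model
        flat, self.inputs_schema = flatten(example_inputs)
        self.flat_inputs = tuple(flat)  # what a tracer would be called with
        self.outputs_schema = None      # filled in by the first call

    def __call__(self, *flat_inputs: Any) -> Tuple[Any, ...]:
        inputs = self.inputs_schema(list(flat_inputs))        # unflatten inputs
        outputs = self.model(inputs)
        flat_outputs, self.outputs_schema = flatten(outputs)  # flatten outputs
        return tuple(flat_outputs)


def toy_model(batch):
    # Stand-in for a model with nested inputs and outputs.
    return {"boxes": [x * batch["scale"] for x in batch["images"]],
            "count": len(batch["images"])}


adapter = TracingAdapterSketch(toy_model, {"images": [1.0, 2.0], "scale": 10.0})
flat_out = adapter(*adapter.flat_inputs)       # (10.0, 20.0, 2)
print(adapter.outputs_schema(list(flat_out)))  # {'boxes': [10.0, 20.0], 'count': 2}
```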

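TracingAdapter's flatten uses a schema-based implementation (in detectron2/export/flatten.py). Coincidentally, its interface resembles JAX's pytree: flattening returns the flat values together with a schema, which can later rebuild the container from new values.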

Perception models at Meta accept a wide range of input/output formats: they may take any number of images plus auxiliary data as inputs, and predict boxes, masks, keypoints or any other interesting attributes as outputs. But deployment prefers a flat interface for optimizability and interoperability. TracingAdapter's automatic flattening and unflattening has freed engineers from writing format-conversion glue code when deploying these models.

In addition to deployment, TracingAdapter is also useful in a few other places to smooth the experience of torch.jit.trace:

### TensorFlow tf.nest

tf.nest.flatten and tf.nest.pack_sequence_as implement schema-less flattening and unflattening.

The unflatten function requires a container. It flattens this container on the fly while simultaneously "packing" the flat values into its structure. The official pack_sequence_as documentation includes an example of this; note that dict values are ordered by sorted keys, not by insertion order.

tf.nest.{flatten,pack_sequence_as} are widely used in TensorFlow because many low-level components have a flat interface, especially for interop with C APIs.

tf.nest.map_structure has the same functionality as JAX's tree_map.
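The schema-less behavior described above can be sketched in pure Python. This is not TensorFlow's implementation. The flatten, pack_sequence_as and map_structure functions below are stand-ins that treat only lists, tuples (including namedtuples) and dicts as containers. pack_sequence_as re-traverses the structure while consuming the flat values, visiting dict values in sorted-key order. map_structure is simply the two composed.

```python
from typing import Any, Callable, List


def flatten(structure: Any) -> List[Any]:
    """Leaves in order; dict values are visited in sorted-key order."""
    if isinstance(structure, dict):
        return [leaf for k in sorted(structure) for leaf in flatten(structure[k])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]


def pack_sequence_as(structure: Any, flat_sequence: List[Any]) -> Any:
    """Schema-less unflatten: re-traverse structure while consuming flat_sequence."""
    it = iter(flat_sequence)

    def pack(s: Any) -> Any:
        if isinstance(s, dict):
            packed = {k: pack(s[k]) for k in sorted(s)}  # consume in sorted-key order
            return {k: packed[k] for k in s}             # but keep the original key order
        if isinstance(s, (list, tuple)):
            items = [pack(item) for item in s]
            return type(s)(*items) if hasattr(s, "_fields") else type(s)(items)
        try:
            return next(it)
        except StopIteration:
            raise ValueError("flat_sequence has too few elements") from None

    result = pack(structure)
    end = object()
    if next(it, end) is not end:
        raise ValueError("flat_sequence has too many elements")
    return result


def map_structure(fn: Callable[[Any], Any], structure: Any) -> Any:
    """Apply fn to every leaf and keep the structure (flatten, map, then pack)."""
    return pack_sequence_as(structure, [fn(x) for x in flatten(structure)])


print(pack_sequence_as({"key3": "", "key1": "", "key2": ""}, ["value1", "value2", "value3"]))
# {'key3': 'value3', 'key1': 'value1', 'key2': 'value2'}
print(map_structure(lambda x: x * 2, {"b": [1, 2], "a": (3,)}))
# {'b': [2, 4], 'a': (6,)}
```

No schema is stored, so pack_sequence_as needs the original structure at hand and walks it again on every call. This is the trade-off listed in the pros and cons above.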

### TensorFlow FetchMapper

TFv1's session.run(fetches) supports fetching nested containers (arbitrarily nested lists, tuples, namedtuples and dicts of tensors), as demonstrated in an example from the official documentation.

This powerful interface exists only in TF's Python client. The client interacts with the C API's TF_SessionRun, which only accepts a plain array of inputs/outputs. Therefore, the client needs to:

1. Flatten the container to a plain array of tensors
2. Send this array to the C API to obtain an array of results
3. Unflatten / reconstruct the container using the results

The flatten/unflatten logic uses a schema-based implementation in the client's FetchMapper. This implementation is a bit more complicated due to an extra guarantee that the flattened tensors are unique. (This is to ensure the client won't fetch the same tensor twice in one call; this cannot be done by using tf.nest.)
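The three steps, plus the uniqueness guarantee, can be illustrated with a simplified sketch. This is not FetchMapper's code. Unlike FetchMapper, it rebuilds the result by looking up each leaf in a dict rather than through an explicit schema. The names run_nested and fake_run are hypothetical, fake_run stands in for a flat API such as TF_SessionRun, and fetches are assumed to be hashable.

```python
from typing import Any, Callable, Dict, Hashable, List, Sequence


def _leaves(fetches: Any) -> List[Hashable]:
    if isinstance(fetches, dict):
        return [x for k in sorted(fetches) for x in _leaves(fetches[k])]
    if isinstance(fetches, (list, tuple)):
        return [x for f in fetches for x in _leaves(f)]
    return [fetches]


def _rebuild(fetches: Any, results: Dict[Hashable, Any]) -> Any:
    if isinstance(fetches, dict):
        return {k: _rebuild(v, results) for k, v in fetches.items()}
    if isinstance(fetches, (list, tuple)):
        items = [_rebuild(f, results) for f in fetches]
        return type(fetches)(*items) if hasattr(fetches, "_fields") else type(fetches)(items)
    return results[fetches]


def run_nested(fetches: Any, run_flat: Callable[[List[Hashable]], Sequence[Any]]) -> Any:
    # 1. Flatten, dropping duplicates while keeping first-seen order.
    unique = list(dict.fromkeys(_leaves(fetches)))
    # 2. One call to the flat API; each tensor is fetched exactly once.
    values = run_flat(unique)
    if len(values) != len(unique):
        raise ValueError("run_flat must return one value per fetch")
    # 3. Reconstruct the nested container from the results.
    return _rebuild(fetches, dict(zip(unique, values)))


def fake_run(names: List[str]) -> List[int]:
    # Hypothetical flat backend: "evaluates" each tensor name to its length.
    return [len(n) for n in names]


print(run_nested({"k1": ("a", "bb"), "k2": ["bb", "a"]}, fake_run))
# {'k1': (1, 2), 'k2': [2, 1]}
```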

In addition to builtin Python containers, FetchMapper supports a few other TF containers (such as SparseTensor) and can be extended to new containers by registering conversion functions.

### DeepMind tree library

DeepMind has a tree library as a standalone alternative to tf.nest:

| deepmind/tree | tf.nest | jax.tree_util |
| --- | --- | --- |
| tree.flatten | tf.nest.flatten | jax.tree_util.tree_flatten |
| tree.unflatten_as | tf.nest.pack_sequence_as | jax.tree_util.tree_unflatten |
| tree.map_structure | tf.nest.map_structure | jax.tree_util.tree_map |