Demystify RAM Usage in Multi-Process Data Loaders
A typical PyTorch training program on 8 GPUs with 4 dataloader workers per GPU would create at least 8 x (4 + 1) = 40 processes. A naive use of datasets and dataloaders can easily replicate the dataset's RAM usage 40 times; this article explains why that happens and how to avoid it.
All code examples and experiment results are available on github at ppwwyyxx/RAM-multiprocess-dataloader. The content is not specific to PyTorch: it applies to any user of Python's multiprocessing library on Linux.
Motivation for In-RAM Data
Datasets for machine learning are usually not stored in RAM. But it's common to store their "metadata" in RAM, and this may still cause nontrivial RAM usage. The metadata could be:
- For the ImageNet dataset: a million file names and their labels.
- For the COCO dataset: 100k file names and their bounding boxes, segmentations, etc.
As a concrete case, loading the metadata of the COCO training set into Python takes ~2.4G of RAM:
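(A rough sketch of this measurement, assuming detectron2's load_coco_json and a local copy of the COCO annotations; the exact script in the repository may differ.)

```python
import os

import psutil
from detectron2.data.datasets import load_coco_json

# Load the COCO train2017 metadata: file names, boxes, segmentation polygons, ...
roidbs = load_coco_json(
    "datasets/coco/annotations/instances_train2017.json",
    "datasets/coco/train2017/",
)
print(f"{len(roidbs)} images loaded")

# RSS of this process after loading: roughly the in-RAM size of the metadata,
# plus the interpreter and the imported libraries.
print("RSS (GB):", psutil.Process(os.getpid()).memory_info().rss / 2**30)
```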
We obviously don't want to replicate this 2.4G of RAM across all processes.
In-RAM metadata is needed for flexibility
We acknowledge that there are ways to offload these metadata to disk. For example, people sometimes do:
- Store all the metadata together with raw data on disk, so metadata are not stored in RAM.
- Read sequentially from a combined single-file dataset, so that file names or indices are not stored in RAM.
By doing these, the RAM usage of a dataset becomes negligible. However, these methods sacrifice flexibility and capabilities such as random access, perfect shuffling, arbitrary merging of datasets, custom subsampling, etc. Notably, PyTorch's commonly used map-style datasets support random access & sampling. All of these capabilities require certain metadata in RAM.
This article ignores any of these offloading methods. Instead, we'll discuss how to reduce the RAM usage without moving these data out of RAM. The idea is simple: we'll try to let all processes share a single copy of the dataset.
Measure RAM Usage
First let's build tools to measure RAM usage - which is not as easy as it sounds.
Common tools like top -p PID or psutil.Process(PID).memory_info() obtain memory statistics from /proc/{PID}/statm or /proc/{PID}/status, but they are insufficient for our analysis. Instead, we'll use the information provided in:
- /proc/{PID}/smaps: per-memory-mapping RAM usage information, documented in this man page
- /proc/{PID}/smaps_rollup: aggregation of data from smaps
We'll derive the following important measurements from it:
- USS (Unique Set Size): RAM that's unique/private to this process, i.e. not shared with any other process. This is obtained by the sum of "private_*" entries in smaps.
- Shared: RAM in this process that's also shared with other processes. This is obtained by the sum of "shared_*" entries in smaps.
- Shared_File: RAM that's shared with other processes through files. It should be no larger than "Shared". This number should be almost the same as the "SHR" column in top/htop.
- RSS (Resident Set Size): All memory that this process holds in RAM. RSS = USS + Shared.
- PSS (Proportional Set Size): Like RSS, but it avoids overcounting shared memory multiple times across all the processes that share it. It's basically "USS + Shared / (number of processes sharing it)". By definition, we should use total PSS to count the total RAM usage of N processes.
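For example, if a process has 0 USS and 2.6G of RAM shared among 5 processes, its PSS is roughly 2.6G / 5 ≈ 0.5G.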
To obtain these measurements, we use psutil.Process(PID).memory_maps(), which parses smaps under the hood:
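(A minimal sketch, assuming Linux and a recent psutil; the MemoryMonitor in the repository builds on the same idea.)

```python
from collections import defaultdict

import psutil

def get_mem_info(pid: int) -> dict:
    """Aggregate the per-mapping entries of /proc/{pid}/smaps into the measurements above."""
    res = defaultdict(int)
    for mmap in psutil.Process(pid).memory_maps():
        res["rss"] += mmap.rss
        res["pss"] += mmap.pss
        res["uss"] += mmap.private_clean + mmap.private_dirty
        res["shared"] += mmap.shared_clean + mmap.shared_dirty
        if mmap.path.startswith("/"):  # mappings backed by a file
            res["shared_file"] += mmap.shared_clean + mmap.shared_dirty
    return res
```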
Then we create a MemoryMonitor utility to measure and print the results for a list of PIDs. The code is straightforward and can be found here.
Copy-on-read Overhead and "Memory Leak"
We start with a naive implementation of a dataset that produces items from a list:
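(A sketch along these lines; the exact code in the repository may differ slightly.)

```python
import torch.utils.data as data

class DatasetFromList(data.Dataset):
    """A naive map-style dataset that returns items of an in-memory list."""

    def __init__(self, lst):
        self.lst = lst

    def __len__(self):
        return len(self.lst)

    def __getitem__(self, idx: int):
        return self.lst[idx]
```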
Then we launch subprocesses to read from this dataset with the list of COCO data. To make a cleaner demo, we don't use PyTorch's dataloader, but just launch 4 subprocesses by ourselves:
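(A sketch of the experiment; create_coco() is a hypothetical helper that loads the COCO metadata list from before, and the full script on github also wires in the MemoryMonitor.)

```python
import multiprocessing as mp

def worker(worker_id: int, dataset):
    # Simulate a dataloader worker: sweep over every item of the dataset, repeatedly.
    for _ in range(100):
        for i in range(len(dataset)):
            _ = dataset[i]

if __name__ == "__main__":
    ds = DatasetFromList(create_coco())  # create_coco(): hypothetical metadata loader
    ctx = mp.get_context("fork")
    workers = [ctx.Process(target=worker, args=(i, ds)) for i in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
```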
We then add our MemoryMonitor to it. The full code and its output logs are available on github. Each segment in the log contains memory measurements for the main process + 4 workers.
The code looks completely innocent. However, if we plot the memory usage of any dataloader worker over time, we seem to find a memory leak! This is the notorious "dataloader leaks memory" issue discussed in multiple places, e.g. this PyTorch issue and Edward's podcast.
In fact, the growth of RAM usage does stop in the end, so this issue is not a memory leak. But in reality, users often do not see the end before the system OOMs, and they may wrongly conclude that this is a "memory leak".
The root cause of this issue is "copy-on-read" of forked CPython objects.
Copy-on-read of forked CPython objects
Linux has a copy-on-write mechanism: when a process forks, the child process will share its entire memory space with the parent, and only copy the relevant pages when necessary, i.e. when the child process needs to write to the page. This mechanism allows read-only pages to be shared to reduce total memory usage.
The copy-on-write behavior can be clearly observed in the above figure: at time=0, the worker has 2.6G of shared RAM and 0 USS, because right after forking it shares its entire memory space with the parent.
However, this mechanism did not help us when we read our dataset. The problem is that our dataset is a large nested data structure that contains many small Python objects. Even though the dataset is "read-only" in theory, accessing any Python object will increment its refcount - causing a lot of memory writes. With these writes, memory can no longer be shared among parent and child processes. In other words, objects are not only copy-on-write, but also copy-on-read. Therefore, in the figure we see that the "Shared" RAM decreases and "USS" increases, since many pages are copied from shared memory into each process.
The end game is that each child process has to replicate all the pages that contain object refcounts in the dataset. For a dataset with many objects, this is almost the size of the dataset itself. In the output log, we see that this program uses 10G total PSS in the end, where each child process replicates 1.8G of USS.
Serialize to a Numpy Array
The copy-on-read issue is due to CPython's reference counting.
There are ways to change CPython's behavior, e.g. gc.freeze, but it has far-reaching consequences and I failed to make it work for the example here. However, there is a simple and transparent way to solve the issue: store the dataset using very few Python objects, so there are very few refcounts!
Below is a minimal implementation that stores a list using 2 numpy arrays:
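(A sketch in the spirit of detectron2's NumpySerializedList; details may differ from the util used in the experiments.)

```python
import pickle

import numpy as np

class NumpySerializedList:
    """Store a list as two numpy arrays: a flat byte buffer plus end offsets."""

    def __init__(self, lst: list):
        serialized = [
            np.frombuffer(pickle.dumps(x, protocol=-1), dtype=np.uint8) for x in lst
        ]
        self._addr = np.cumsum(np.asarray([len(x) for x in serialized], dtype=np.int64))
        self._lst = np.concatenate(serialized)

    def __len__(self):
        return len(self._addr)

    def __getitem__(self, idx: int):
        start = 0 if idx == 0 else self._addr[idx - 1].item()
        end = self._addr[idx].item()
        return pickle.loads(memoryview(self._lst[start:end]))
```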
Detectron2 enables this type of serialization by default (since this commit by Yanghan). To compare different serialization mechanisms, we borrow its code into a serialization util, and use it here:
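(A sketch of the change, reusing the DatasetFromList and NumpySerializedList sketches above; the actual experiment uses the serialization util just mentioned.)

```python
class SerializedDatasetFromList(DatasetFromList):
    """Same as DatasetFromList, but the list is stored with very few Python objects."""

    def __init__(self, lst):
        self.lst = NumpySerializedList(lst)  # the one-line change
```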
Just by this simple one-line change, the RAM usage is greatly reduced, as shown at the end of the output log file.
We can see that:
- The total PSS usage is only 1.6G -- a 6x reduction.
- All processes have almost 0 USS, which means everything is shared! In fact, from the logs we can see that 1.6G is exactly the memory usage of the main process before starting subprocesses. Subprocesses add no extra memory usage.
- The reduction factor is better than #processes because pickle.dumps not only serializes but also compresses the data. We benefit from both sharing and compression by applying this optimization, at the cost of a tiny pickle.loads overhead in each access.
More on compression (not important)
Actually, after compression, the dataset only takes ~500M (printed at the beginning of log). So a question arises: why does the main process use 1.6G RAM before starting subprocesses?
This is in fact just an artifact of modern memory allocators: they do not always release freed memory back to the OS. If we run this simple serialization/compression code:
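(A sketch of such an experiment, reusing the hypothetical create_coco() helper and the NumpySerializedList sketch, and measuring RSS with psutil.)

```python
import os

import psutil

def rss_gb() -> float:
    return psutil.Process(os.getpid()).memory_info().rss / 2**30

lst = create_coco()                    # load the COCO metadata
print("after loading:      ", rss_gb())

serialized = NumpySerializedList(lst)  # pickle + concatenate into numpy arrays
print("after serializing:  ", rss_gb())

del lst, serialized
print("after deleting all: ", rss_gb())  # RSS stays higher than expected
```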
We see that we seem to "lose" ~700MB of RAM even after we've deleted everything.
Using a better allocator, e.g. by export LD_PRELOAD=libjemalloc.so, can make this issue largely disappear.
This artifact is typically not a big concern, since allocators will find opportunities to reuse these free buffers.
(Well, they may be concerning in start_method="fork" because reusing these free buffers may trigger copy-on-write! But I'm not going to talk more about that.)
Pickle Overhead in Spawn / Forkserver
In our code above, we launched subprocesses using a start_method="fork" argument.
"fork, spawn, forkserver" are the 3 "start methods" of Python's multiprocessing library. This article is a good reference that explains their differences.
Since start_method="fork" is unsafe (in practice, it causes various crashes & deadlocks) and might no longer be the default in the future, we want to rerun our code above with start_method="spawn" or "forkserver". Sadly, the serialized array is no longer shared among workers: each worker ends up with a large USS.
The reason why our trick no longer works is that "spawn" and "forkserver" don't benefit from the copy-on-write mechanism. They will start a "fresh" subprocess with fresh memory space, instead of sharing with the parent. Everything the child process needs to access is pickled in the parent process and sent to the child. This ensures safe behavior, but is bad for start-up speed and memory usage.
In our case, the entire dataset will be pickled and sent to child processes. This is why each child process consumes a large USS.
Serialize to a torch.Tensor
It turns out there is a simple fix to this problem: just store the serialized dataset in a torch.Tensor instead of a numpy array. The reason it works is that multiprocessing uses a customizable pickle implementation called ForkingPickler, and PyTorch customizes how torch.Tensor should be pickled by it: the tensor data will not be serialized to bytes. Instead, during pickling the tensor is moved to shared memory files (typically under /dev/shm) to be accessed by other processes directly.
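A sketch of the change, in the spirit of detectron2's TorchSerializedList and building on the NumpySerializedList sketch above:

```python
import pickle

import torch

class TorchSerializedList(NumpySerializedList):
    """Like NumpySerializedList, but the two arrays are stored as torch.Tensors,
    so that pickling through ForkingPickler shares them via /dev/shm instead of copying."""

    def __init__(self, lst: list):
        super().__init__(lst)
        self._addr = torch.from_numpy(self._addr)
        self._lst = torch.from_numpy(self._lst)

    def __getitem__(self, idx: int):
        start = 0 if idx == 0 else self._addr[idx - 1].item()
        end = self._addr[idx].item()
        return pickle.loads(memoryview(self._lst[start:end].numpy()))
```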
To test tensor-based serialization, we run ./main-torchserialize.py spawn using the code here, and observe the following memory usage in workers (raw log is here):
- "Shared_File" grows because workers will load from the shared torch.Tensor as needed. This is different from start_method="fork", where the entire memory space is shared at the beginning.
- "Shared_File" stops growing when the worker has accessed the entire shared tensor.
- The size of "Shared_File" in the end is roughly 500M (size of the serialized dataset) + 170M, where 170M is the size of all the binary files that import torch needs to load, such as libtorch.so. This can be easily verified by printing the measurements after import torch.
After applying tensor-based serialization, the total PSS usage in the end is 2.2G -- still worse than our earlier number using start_method="fork". The next section will optimize it further.
Per-Process Import Overhead
The last culprit in the above experiment is the 160MB per-worker USS in the above figure: this is just the memory footprint of import torch, mainly for PyTorch's global variables, etc. Since every child process launched by "spawn / forkserver" is a "fresh" one, they all need to import torch independently, hence each has 160MB of USS.
Luckily, "forkserver" provides a way to share the import torch RAM usage through copy-on-write. By calling the undocumented Python API multiprocessing.set_forkserver_preload(["torch"]) before launching processes, each child process will be "less fresh": the torch library is preloaded (and shared), and doesn't need to be imported by each process independently.
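A minimal sketch of how this can be wired up, reusing the worker, DatasetFromList, TorchSerializedList and create_coco() pieces from the sketches above:

```python
import multiprocessing as mp

if __name__ == "__main__":
    # Ask the forkserver to import torch once, before any worker is forked from it,
    # so that PyTorch's memory footprint is shared through copy-on-write.
    mp.set_forkserver_preload(["torch"])
    ctx = mp.get_context("forkserver")

    ds = DatasetFromList(TorchSerializedList(create_coco()))
    workers = [ctx.Process(target=worker, args=(i, ds)) for i in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
```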
Below are the experiment results. Code and full logs are on github:
- The total PSS is only 1.7G, which is roughly the same as our best number using start_method="fork".
- The USS of each worker is negligible, which means we've successfully shared everything. There is no per-worker memory overhead anymore.
(Note that this optimization may be unsafe if import torch creates any threads. My observation is that threads are indeed created due to import numpy inside torch, but they can be disabled with environment variables.)
Share Datasets among Multiple GPU Processes
So far we've only looked at a single dataloader (with 4 workers). In reality, the only scalable way to use PyTorch on multiple GPUs is to use one process per GPU, each with its own dataloader and dataloader workers. This gives a total of #GPUs x (#DL workers + 1) processes.
We modified the previous experiment slightly into this code to run on 2 GPUs and measured the memory usage of every process.
Our previous optimization on dataloader workers is still effective - dataloader workers have a tiny USS. However, RAM usage is now replicated #GPUs times, because we let each GPU worker read the dataset independently.
An inconvenient solution to this problem is to load and serialize the dataset before launching GPU workers. By doing this, all GPU workers share the dataset just like the dataloader workers do. However, this limits flexibility and often requires significant refactoring, for reasons such as:
- Dataset would have to be made ready much earlier than usual
- Per-GPU data loading logic (e.g. sharding) may need to be modified
- Most launchers (e.g. torchrun, accelerate) don't support this at all
Another simple solution to this problem is again to use torch.Tensor and ForkingPickler to share the dataset among GPU workers, except that now we need to manage the sharing explicitly, like this:
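(A rough, single-machine sketch of the idea, assuming one process per GPU with torch.distributed already initialized, plus the TorchSerializedList and create_coco() pieces from above; the serialization util mentioned below takes care of the remaining details such as the sharing strategy and per-machine sharing.)

```python
from multiprocessing.reduction import ForkingPickler

import torch
import torch.distributed as dist

# Use named shared-memory files so that unrelated processes can attach to the same storage.
torch.multiprocessing.set_sharing_strategy("file_system")

if dist.get_rank() == 0:
    # Only one GPU worker loads and serializes the dataset.
    dataset = TorchSerializedList(create_coco())
    # Pickling the tensors with ForkingPickler moves their storage into shared memory
    # and produces a small handle rather than a copy of the data.
    handle = [bytes(ForkingPickler.dumps((dataset._addr, dataset._lst)))]
else:
    dataset, handle = None, [None]

# Broadcast only the tiny handle; the data itself stays in shared memory.
dist.broadcast_object_list(handle, src=0)

if dist.get_rank() > 0:
    dataset = TorchSerializedList.__new__(TorchSerializedList)  # skip __init__; attach shared tensors
    dataset._addr, dataset._lst = ForkingPickler.loads(handle[0])
```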
This logic is implemented as another serialization util here. When using it as a drop-in replacement (full code here), the dataset is no longer replicated by GPU workers.
GPU worker 1 still has a small amount of extra USS. That's just the footprint of import torch that we saw earlier, which can be avoided using set_forkserver_preload.
Note that the multiprocessing library itself also provides shared memory support.
This PR contains an implementation of our serialization util without using PyTorch.
Summary
We've successfully reduced the total RAM usage by (approximately) a factor of #GPUs x (#DL workers + 1).
The essence of the solution is to let all processes share memory through a single torch.Tensor object, which needs to be moved to Linux shared memory by PyTorch's custom pickling routine. The TLDR on how to achieve sharing is:
- Don't let dataloader workers access many Python objects in their parent. Serialize all objects into a single torch.Tensor (but not numpy array) for workers to access.
- Don't let all GPU workers load data independently. Load in one GPU worker, and share with others through a torch.Tensor.
For list-like data, all of these can be implemented transparently using the serialization routines developed in this article.
Multi-processing is often the only way to achieve true parallelism in Python (until PEP703), but it comes with many tricky problems. This article hopefully provides an in-depth view of the problem of RAM usage.