Demystify RAM Usage in Multi-Process Data Loaders

A typical PyTorch training program on 8 GPUs with 4 dataloader workers per GPU would create at least 8 x (4 + 1) = 40 processes. A naive use of PyTorch dataset and dataloader can easily replicate your dataset's RAM usage by 40 times. This issue has probably affected everyone who has done anything nontrivial with PyTorch. In this post, we will explain why it happens, and how to avoid the 40x RAM usage.

All code examples and experiment results are available on github at ppwwyyxx/RAM-multiprocess-dataloader. The content is not specific to PyTorch: it applies to any user of Python's multiprocessing library on Linux.

Motivation for In-RAM Data

Datasets for machine learning are usually not stored in RAM. But it's common to store their "metadata" in RAM, and this may still cause nontrivial RAM usage. The metadata could be:

  • For ImageNet dataset: A million file names and their labels.
  • For COCO dataset: 100k file names and their bounding boxes, segmentations, etc.

As a concrete case, loading the metadata of COCO training set into Python takes ~2.4G of RAM:

# Download from https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json
import json
from typing import Any

def create_coco() -> list[Any]:
    with open("instances_train2017.json") as f:
        obj = json.load(f)
    return obj["annotations"]

We obviously don't want to replicate this 2.4G of RAM across all processes.

In-RAM metadata is needed for flexibility

We acknowledge that there are ways to offload these metadata to disk. For example, people sometimes do:

  • Store all the metadata together with raw data on disk, so metadata are not stored in RAM.
  • Read sequentially from a combined single-file dataset, so that file names or indices are not stored in RAM.

By doing these, the RAM usage of a dataset becomes negligible. However, these methods will sacrifice flexibility and capabilities, such as random-access, perfect shuffle, merging datasets arbitrarily, custom subsampling support, etc. Notably, PyTorch's commonly used map-style datasets support random access & sampling. All of these capabilities require certain metadata in RAM.

This article sets these offloading methods aside. Instead, we'll discuss how to reduce the RAM usage without moving the data out of RAM. The idea is simple: we'll try to let all processes share a single copy of the dataset.

Measure RAM Usage

First let's build tools to measure RAM usage - which is not as easy as it sounds.

Common tools like top -p PID or psutil.Process(PID).memory_info() obtain memory statistics from /proc/{PID}/statm or /proc/{PID}/status, but they are insufficient for our analysis. Instead, we'll use the information provided in

  • /proc/{PID}/smaps: per-memory-mapping RAM usage information, documented in this man page
  • /proc/{PID}/smaps_rollup: aggregation of data from smaps

We'll derive the following important measurements from these files:

  • USS (Unique Set Size): RAM that's unique/private to this process, i.e. not shared with any other process. This is obtained by the sum of "private_*" entries in smaps.
  • Shared: RAM in this process that's also shared with other processes. This is obtained by the sum of "shared_*" entries in smaps.
  • Shared_File: RAM that's shared with other processes through files. It should be no larger than "Shared".
    • This number should be almost the same as the "SHR" column in top/htop.
  • RSS (Resident Set Size): All memory that this process holds in RAM. RSS = USS + Shared.
  • PSS (Proportional Set Size): Like RSS, but it avoids overcounting shared memory multiple times across all processes that are sharing it. It's basically "USS + Shared / (number of processes sharing it)". By definition, we should use total PSS to count the total RAM usage of N processes (see the worked example below).
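
As a quick sanity check of these definitions (the numbers here are made up for illustration): if 5 processes each hold 0.1G of private memory and all map the same 2.5G shared region, then each process has RSS = 0.1 + 2.5 = 2.6G but PSS = 0.1 + 2.5 / 5 = 0.6G. Summing PSS over the 5 processes gives 3.0G, which matches the RAM actually consumed (0.5G of private memory in total + 2.5G shared), whereas summing RSS gives 13G and badly overcounts the shared region.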

To obtain these measurements, we use psutil.Process(PID).memory_maps() which parses smaps under the hood:

from collections import defaultdict

import psutil

def get_mem_info(pid: int) -> dict[str, int]:
    res = defaultdict(int)
    for mmap in psutil.Process(pid).memory_maps():
        res['rss'] += mmap.rss
        res['pss'] += mmap.pss
        res['uss'] += mmap.private_clean + mmap.private_dirty
        res['shared'] += mmap.shared_clean + mmap.shared_dirty
        if mmap.path.startswith('/'):  # looks like a file path
            res['shared_file'] += mmap.shared_clean + mmap.shared_dirty
    return res

Then we create a MemoryMonitor utility to measure and print the results for a list of PIDs. The code is straightforward and can be found here.
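
For reference, here is a minimal sketch of what such a utility could look like, built on get_mem_info above (the interface is illustrative; the version in the repo is what produces the logs in this post):

import os
import time

class MemoryMonitor:
    def __init__(self, pids: list[int] | None = None):
        # By default, monitor the current process only; more PIDs can be added later.
        self.pids = pids if pids is not None else [os.getpid()]

    def add_pid(self, pid: int):
        assert pid not in self.pids
        self.pids.append(pid)

    def str(self):
        # One line per PID, using the measurements from get_mem_info() above.
        lines = []
        for pid in self.pids:
            info = get_mem_info(pid)
            stats = " ".join(f"{k}={v / 2**20:.1f}M" for k, v in info.items())
            lines.append(f"time={int(time.time())} PID={pid} {stats}")
        return "\n".join(lines)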

Copy-on-read Overhead and "Memory Leak"

We start with a naive implementation of a dataset that produces items from a list:

import torch.utils.data

class NaiveDatasetFromList(torch.utils.data.Dataset):
    def __init__(self, lst):
        self.lst = lst

    def __len__(self):
        return len(self.lst)

    def __getitem__(self, idx: int):
        return self.lst[idx]

Then we launch subprocesses to read from this dataset with the list of COCO data. To make a cleaner demo, we don't use PyTorch's dataloader, but simply launch 4 subprocesses ourselves:

import pickle
import time

import torch.multiprocessing
import torch.utils.data

# NaiveDatasetFromList and create_coco are defined in the snippets above.

def worker(_, dataset: torch.utils.data.Dataset):
    while True:
        for sample in dataset:
            # read the data, with a fake latency
            time.sleep(0.000001)
            result = pickle.dumps(sample)

if __name__ == "__main__":
    ds = NaiveDatasetFromList(create_coco())
    ctx = torch.multiprocessing.start_processes(
        worker, (ds, ), nprocs=4, join=False, daemon=True, start_method='fork')
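
To watch the memory usage over time, the main process needs a monitoring loop. Here is a sketch of what it could look like, continuing inside the if __name__ == "__main__": block above and reusing the MemoryMonitor interface sketched earlier (the repo's main-naive.py is the actual code):

    monitor = MemoryMonitor()
    for pid in ctx.pids():      # also track the 4 worker processes
        monitor.add_pid(pid)
    while True:
        print(monitor.str())
        time.sleep(5)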

We then add our MemoryMonitor to it, as sketched above. The full code and its output logs are available on github. Each segment in the log contains memory measurements for the main process + 4 workers:

$ ./main-naive.py
time PID rss pss uss shared shared_file
------ ------ ----- ----- ----- -------- -------------
34724 791339 2.7G 2.0G 1.8G 993.8M 163.5M
34724 791625 2.6G 1.9G 1.8G 848.6M 16.4M
34724 791626 2.6G 1.9G 1.8G 848.6M 16.4M
34724 791627 2.6G 1.9G 1.8G 848.6M 16.4M
34724 791628 2.6G 1.9G 1.8G 848.8M 16.5M

The code looks completely innocent. However, if we plot the memory usage of any dataloader worker over time, we seem to find a memory leak! This is the notorious "dataloader leaks memory" issue discussed in multiple places, e.g. this PyTorch issue and Edward's podcast.

In fact, the growth of RAM usage does stop eventually, so this issue is not a memory leak. But in reality, users often run out of memory before reaching that point, and may wrongly conclude that there is a "memory leak".

The root cause of this issue is "copy-on-read" of forked CPython objects.

Copy-on-read of forked CPython objects

Linux has a copy-on-write mechanism: when a process forks, the child process will share its entire memory space with the parent, and only copy the relevant pages when necessary, i.e. when the child process needs to write to the page. This mechanism allows read-only pages to be shared to reduce total memory usage.

The copy-on-write behavior can be clearly observed in the above figure: at time=0, the worker has 2.6G of shared RAM, 0 USS, and only about 2.6G / 5 ≈ 0.5G of PSS, because the RAM is shared among 5 processes (4 workers + 1 main).

However, this mechanism did not help us when we read our dataset. The problem is that our dataset is a large nested data structure that contains many small Python objects. Even though the dataset is "read-only" in theory, accessing any Python object will increment its refcount - causing a lot of memory writes. With these writes, memory can no longer be shared among parent and child processes. In other words, objects are not only copy-on-write, but also copy-on-read. Therefore, in the figure we see that the "Shared" RAM decreases and "USS" increases, since many pages are copied from shared memory into each process.
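
To make the refcount writes concrete, here is a tiny illustration (simplified: CPython stores the reference count in each object's header, so creating or destroying a reference writes to the memory page holding that object):

import sys

record = {"file_name": "0001.jpg", "bbox": [1, 2, 3, 4]}  # a "read-only" metadata record
print(sys.getrefcount(record))  # e.g. 2: one for `record`, one for the call's temporary argument
alias = record                  # merely creating another reference to the object...
print(sys.getrefcount(record))  # ...increments the refcount, i.e. writes to the object's page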

The end result is that each child process has to replicate all the pages that contain object refcounts in the dataset. For a dataset with many small objects, this is almost the size of the dataset itself. In the output log, we see that this program uses 10G of total PSS in the end, with each child process replicating 1.8G of USS.

Serialize to a Numpy Array

The copy-on-read issue is due to CPython's reference counting. There are ways to change CPython's behavior, e.g. by gc.freeze, but it has far-reaching consequences and I failed to make it work for the example here. However, there is a simple and transparent way to solve the issue: store the dataset using a very small number of Python objects, so there are very few refcounts to touch! Below is a minimal implementation that stores a list using 2 numpy arrays:

import pickle
from typing import Any

import numpy as np

class NumpySerializedList:
    def __init__(self, lst: list[Any]):
        lst = [np.frombuffer(pickle.dumps(x), dtype=np.uint8) for x in lst]
        self._addr = np.cumsum([len(x) for x in lst])
        self._lst = np.concatenate(lst)

    def __len__(self):
        return len(self._addr)

    def __getitem__(self, idx: int):
        start = 0 if idx == 0 else self._addr[idx - 1]
        end = self._addr[idx]
        return pickle.loads(memoryview(self._lst[start:end]))

Detectron2 enables this type of serialization by default (since this commit by Yanghan). To compare different serialization mechanisms, we borrow its code into a serialization util, and use it here:

- ds = NaiveDatasetFromList(create_coco())
+ from serialize import NumpySerializedList
+ ds = NaiveDatasetFromList(NumpySerializedList(create_coco()))

Just by this simple one-line change, the RAM usage is greatly reduced. The end of the output log is shown below.

$ ./main-numpyserialize.py
PID rss pss uss shared shared_file
------ ----- ------ ----- -------- -------------
877767 1.6G 396.3M 20.2M 1.6G 184.8M
877901 1.5G 306.5M 3.8M 1.5G 22.3M
877902 1.5G 306.5M 3.7M 1.5G 22.3M
877903 1.5G 306.6M 3.9M 1.5G 22.3M
877904 1.5G 306.4M 3.6M 1.5G 22.3M

We can see that:

  • The total PSS usage is only 1.6G -- a 6x reduction.
  • All processes have almost 0 USS, which means everything is shared! In fact, from the logs we can see that 1.6G is exactly the memory usage of the main process before starting subprocesses. Subprocesses add no extra memory usage.
  • The reduction factor is better than #processes because pickle.dumps not only serializes but also compresses the data. We benefit from both sharing and compression by applying this optimization, at the cost of a tiny pickle.loads overhead in each access.

More on compression (not important)

Actually, after compression, the dataset only takes ~500M (printed at the beginning of the log). So a question arises: why does the main process use 1.6G of RAM before starting subprocesses?

This is just an artifact of modern memory allocators: they do not always release freed memory back to the OS. Indeed, if we run this simple serialization/compression code:

monitor = MemoryMonitor()
print("Initial", monitor.str())
lst = create_coco()
print("JSON", monitor.str())
lst = NumpySerializedList(lst)
print("Serialized", monitor.str())
del lst; import gc; gc.collect()
print("End", monitor.str())

We see that we seem to "lose" ~700MB of RAM even after we've deleted everything:

Initial PID=1156792, rss=328.7M, pss=238.7M, uss=161.4M, shared=167.3M, shared_file=167.3M
JSON PID=1156792, rss=2.8G, pss=2.7G, uss=2.6G, shared=167.3M, shared_file=167.3M
Serialized PID=1156792, rss=1.6G, pss=1.5G, uss=1.5G, shared=167.3M, shared_file=167.3M
End PID=1156792, rss=1.1G, pss=1.0G, uss=986.2M, shared=167.3M, shared_file=167.3M

Using a better allocator, e.g. by export LD_PRELOAD=libjemalloc.so, can make this issue largely disappear.

This artifact is typically not a big concern, since allocators will find opportunities to reuse these free buffers. (It may matter with start_method="fork", though, because reusing these free buffers can trigger copy-on-write! But I'm not going to talk more about that.)

Pickle Overhead in Spawn / Forkserver

In our code above, we launched subprocesses using a start_method="fork" argument. "fork, spawn, forkserver" are the 3 "start methods" of Python's multiprocessing library. This article is a good reference that explains their differences.

Since start_method="fork" is unsafe (in practice, it causes various crashes & deadlocks) and might no longer be the default in the future, we want to rerun our code above with start_method="spawn" or "forkserver". Sadly, the serialized array is no longer shared among workers. Each worker has a large USS:

$ ./main-numpyserialize.py spawn
PID rss pss uss shared shared_file
------- ------ ------ ------ -------- -------------
1177291 1.6G 1.5G 1.5G 168.7M 168.7M
1177405 840.8M 698.3M 672.1M 168.7M 168.7M
1177419 840.9M 698.3M 672.1M 168.8M 168.8M
1177443 840.7M 698.3M 672.1M 168.6M 168.6M
1177456 840.6M 698.5M 672.2M 168.4M 168.4M

The reason our trick no longer works is that "spawn" and "forkserver" don't benefit from the copy-on-write mechanism: they start a "fresh" subprocess with a fresh memory space instead of sharing the parent's. Everything the child process needs to access is pickled in the parent process and sent to the child. This ensures safe behavior, but is bad for start-up speed and memory usage.

In our case, the entire dataset will be pickled and sent to child processes. This is why each child process consumes a large USS.

Serialize to a torch.Tensor

It turns out there is a simple fix to this problem: just store the serialized dataset in a torch.Tensor instead of a numpy array. The reason it works is that multiprocessing uses a customizable pickle implementation called ForkingPickler, and PyTorch customizes how torch.Tensor is pickled by it: the tensor data is not serialized to bytes. Instead, during pickling the tensor is moved to shared memory files (typically under /dev/shm) to be accessed by other processes directly.
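
As an illustration, here is a minimal sketch of a tensor-backed version of the earlier NumpySerializedList (the class name and details are illustrative; the serialization util in the repo is what the experiments below actually use):

import pickle
from typing import Any

import numpy as np
import torch

class TorchSerializedList:
    """Like NumpySerializedList, but backed by torch.Tensors so that
    ForkingPickler moves the data to shared memory instead of copying it."""

    def __init__(self, lst: list[Any]):
        buffers = [np.frombuffer(pickle.dumps(x), dtype=np.uint8) for x in lst]
        self._addr = torch.from_numpy(np.cumsum([len(x) for x in buffers]))
        self._lst = torch.from_numpy(np.concatenate(buffers))

    def __len__(self):
        return len(self._addr)

    def __getitem__(self, idx: int):
        start = 0 if idx == 0 else self._addr[idx - 1].item()
        end = self._addr[idx].item()
        return pickle.loads(memoryview(self._lst[start:end].numpy()))

Because the payload lives in two tensors, pickling the dataset for "spawn" / "forkserver" workers moves it to shared memory once, and the workers map it instead of receiving their own copy.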

To test tensor-based serialization, we run ./main-torchserialize.py spawn using the code here, and observe the following memory usage in the workers (raw log is here):

  • "Shared_File" grows because workers will load from the shared torch.Tensor as needed. This is different from start_method="fork" where the entire memory space is shared at the beginning.
  • "Shared_File" stops growing when the worker has accessed the entire shared tensor.
  • The size of "Shared_File" in the end is roughly 500M (size of serialized dataset) + 170M, where 170M is the size of all the binary files that import torch needs to load such as libtorch.so. This can be easily verified by printing the measurements after import torch.

After applying tensor-based serialization, the total PSS usage in the end is 2.2G -- still worse than our earlier number using start_method="fork". The next section optimizes it further.

Per-Process Import Overhead

The last culprit in the above experiment is the ~160MB of per-worker USS shown in the above figure: this is just the memory footprint of import torch, mainly PyTorch's global variables, etc. Since every child process launched by "spawn / forkserver" is a "fresh" one, each of them needs to import torch independently, hence each has 160MB of USS.

Luckily, "forkserver" provides a way to share the import torch RAM usage through copy-on-write. By calling the undocumented Python API multiprocessing.set_forkserver_preload(["torch"]) before launching processes, each child process will be "less fresh": the torch library is preloaded (and shared), and don't need to be imported by each process independently.

Below are the experiment results. Code and full logs are on github:

$ ./main-torchserialize.py forkserver
PID rss pss uss shared shared_file
------- ------ ------ ------ -------- -------------
1204121 1.6G 1.1G 988.6M 681.5M 681.5M
1204230 707.7M 152.1M 16.9M 690.9M 559.5M
1204231 707.7M 152.2M 16.9M 690.9M 559.5M
1204232 707.7M 152.1M 16.8M 690.9M 559.5M
1204233 707.7M 152.1M 16.8M 691.0M 559.5M

  • The total PSS is only 1.7G, which is roughly the same as our best number using start_method="fork".
  • The USS of each worker is negligible, which means we've successfully shared everything. There is no per-worker memory overhead anymore.

(Note that this optimization may be unsafe if import torch creates any threads. My observation is that threads are indeed created due to import numpy inside torch, but they can be disabled with environment variables.)

Share Datasets among Multiple GPU Processes

So far we've only looked at a single dataloader (with 4 workers). In reality, the only scalable way to use PyTorch on multiple GPUs is to use one process per GPU, each of which has its own dataloader and dataloader workers. This gives a total of #GPUs x (#DL workers + 1) processes organized like below:

We modified the previous experiment slightly into this code to run on 2 GPUs. The memory usage looks like this:

$ ./main-multigpu-naive.py
PID rss pss uss shared shared_file
------- ------ ------ ------- -------- -------------
1495766 1.7G 1.1G 1017.0M 694.2M 694.2M # GPU worker 0
1495938 757.7M 198.5M 67.8M 689.9M 580.7M
1495939 757.6M 198.5M 67.8M 689.8M 580.6M
1495940 757.7M 198.5M 67.8M 689.9M 580.6M
1495941 757.7M 198.5M 67.8M 689.9M 580.7M
1495767 1.7G 1.1G 1015.9M 693.9M 693.9M # GPU worker 1
1495934 757.9M 198.5M 67.7M 690.1M 580.8M
1495935 757.7M 198.4M 67.7M 690.0M 580.6M
1495936 757.7M 198.4M 67.7M 690.0M 580.6M
1495937 757.9M 198.4M 67.7M 690.2M 580.8M

Our previous optimization of dataloader workers is still effective - the dataloader workers have a tiny USS. However, the RAM usage is now replicated #GPUs times, because we let each GPU worker read the dataset independently.

An inconvenient solution to this problem is to load and serialize the dataset before launching the GPU workers. By doing this, all GPU workers share the dataset just like the dataloader workers do. However, this limits flexibility and often requires significant refactoring, due to reasons such as:

  • Dataset would have to be made ready much earlier than usual
  • Per-GPU data loading logic (e.g. sharding) may need to be modified
  • Most launchers (e.g. torchrun, accelerate) don't support this at all

Another simple solution to this problem is again to use torch.Tensor and ForkingPickler to share the dataset among GPU workers, except that now we need to manage the sharing explicitly like this:

if comm.get_local_rank() == 0:  # GPU0 reads data and moves it to shared memory.
    # Move data to shared memory, obtain a handle to send to each local worker.
    handles = [None] + [
        bytes(mp.reduction.ForkingPickler.dumps(tensor_dataset))
        for _ in range(comm.get_local_size() - 1)]
else:
    handles = None
# Each GPU receives its handle from GPU0.
handle = local_scatter(handles)

if comm.get_local_rank() > 0:
    # Materialize a tensor from shared memory.
    tensor_dataset = ForkingPickler.loads(handle)

This logic is implemented as another serialization util here. Using it as a drop-in replacement (full code here), the dataset is no longer replicated across GPU workers:

$ ./main-multigpu-sharedmem.py
PID rss pss uss shared shared_file
------- ------ ------ ------- -------- -------------
1533910 1.7G 1.1G 1015.4M 693.4M 693.4M # GPU worker 0
1534032 757.9M 152.9M 67.9M 690.0M 580.8M
1534033 757.9M 152.9M 67.9M 690.0M 580.8M
1534034 757.9M 152.9M 67.9M 690.0M 580.8M
1534035 757.9M 152.9M 67.9M 690.0M 580.8M
1533911 374.2M 220.0M 192.6M 181.6M 181.6M # GPU worker 1
1534036 757.8M 152.7M 67.7M 690.1M 580.7M
1534037 757.8M 152.7M 67.6M 690.1M 580.7M
1534038 757.8M 152.7M 67.6M 690.2M 580.7M
1534039 757.8M 152.7M 67.6M 690.2M 580.7M

GPU worker 1 still has a small amount of extra USS; that's just the footprint of import torch we saw earlier, which can be avoided using set_forkserver_preload.

Note that the multiprocessing library itself also provides shared memory support. This PR contains an implementation of our serialization util without using PyTorch.
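
For the curious, here is a rough sketch of that idea (not the PR's actual code) using only the standard library's multiprocessing.shared_memory; error handling and cleanup (close() / unlink()) are omitted:

import pickle
from multiprocessing import shared_memory

import numpy as np

def write_to_shm(lst):
    """Serialize a list of objects into one shared memory block.

    Returns the SharedMemory block and the end offset of each item; pass the
    block's name and the offsets to the other processes.
    """
    buffers = [pickle.dumps(x) for x in lst]
    addr = np.cumsum([len(b) for b in buffers])
    shm = shared_memory.SharedMemory(create=True, size=int(addr[-1]))
    offset = 0
    for b in buffers:
        shm.buf[offset:offset + len(b)] = b
        offset += len(b)
    return shm, addr

def read_item(shm_name: str, addr, idx: int):
    # In real code, attach once per process instead of once per item.
    shm = shared_memory.SharedMemory(name=shm_name)
    start = 0 if idx == 0 else int(addr[idx - 1])
    return pickle.loads(shm.buf[start:int(addr[idx])])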

Summary

We've successfully reduced the total RAM usage by (approximately) a factor of #GPUs x (#DL workers + 1), i.e. the total number of processes -- 40x for the setup in the introduction.

The essence of the solution is to let all processes share memory through a single torch.Tensor object, which needs to be moved to Linux shared memory by PyTorch's custom pickling routine. The TLDR on how to achieve sharing is:

  1. Don't let dataloader workers access many Python objects in their parent. Serialize all objects into a single torch.Tensor (not a numpy array) for workers to access.
  2. Don't let all GPU workers load data independently. Load in one GPU worker, and share with others through a torch.Tensor.

For list-like data, all of these can be implemented transparently using the serialization routines developed in this article.

Multi-processing is often the only way to achieve true parallelism in Python (until PEP703), but it comes with many tricky problems. This article hopefully provides an in-depth view of the problem of RAM usage.
