
Large Model Checkpointing

There are many aspects to optimize when training large models. Training often lasts weeks and involves billions of rows of data, and checkpoints of such models can weigh terabytes. Let’s explore how to handle large model checkpointing.

To make your checkpointing workflow more efficient, we’ll discuss several strategies: asynchronous checkpointing, choosing the right storage and format, making your code aware of the network, and picking a schedule that accounts for redoing work after failures. But first, there are some basics to consider.

Model Checkpointing in a Nutshell

A model checkpoint captures a model at a specific point in training and typically consists of metadata (often JSON or YAML) plus the training state in binary format. The training state includes the model parameters and the optimizer state. Checkpoints allow recovery from failures and preserve the trained model.

Saving and loading checkpoints is a crucial component of any ML framework. For instance, PyTorch uses torch.save to write a pickle-serialized dictionary containing the model’s weights and torch.load to read it back.
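As a minimal sketch, a PyTorch checkpoint usually bundles everything needed to resume training into one dictionary (the file name and dictionary keys below are just illustrative conventions):

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 512)
optimizer = torch.optim.Adam(model.parameters())

# Save: pickle-serialize a dict with everything needed to resume training.
torch.save(
    {
        "step": 1000,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    "checkpoint.pt",
)

# Load: restore the weights and the optimizer statistics.
checkpoint = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
```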

TensorFlow offers several IO formats, such as .keras, .ckpt, and .h5. Flax, a deep learning framework based on JAX, recommends using orbax.checkpoint for checkpointing, but there is also flax.serialization for basic serialization tasks.
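For Flax, a rough orbax.checkpoint sketch could look like the following; the path and the params pytree are illustrative, and the exact API surface depends on your orbax version:

```python
import jax.numpy as jnp
import orbax.checkpoint as ocp

# Any pytree of arrays can be checkpointed, e.g., Flax model params.
params = {"dense": {"kernel": jnp.ones((512, 512)), "bias": jnp.zeros(512)}}

checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save("/tmp/flax_ckpt", params)        # target directory must not already exist
restored = checkpointer.restore("/tmp/flax_ckpt")  # read the pytree back
```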

What Is So Special About Large Model Checkpointing

Checkpoint Size and Model Parameters

The number of parameters in a neural network model plays a crucial role in determining its checkpoint size. Each parameter is typically stored as a 32-bit floating-point number, occupying 4 bytes of memory.

For inference, however, model parameters can be saved as 16-bit data types like float16 or bfloat16, halving the memory footprint to 2 bytes per parameter with little to no impact on prediction accuracy. Advanced optimizers like Adam store additional statistics for each parameter, increasing the checkpoint size further.

Classic vs. Deep Learning Models

While classic machine learning models like logistic regression or gradient boosting have only hundreds of thousands of parameters, resulting in checkpoint sizes of a few megabytes, deep learning models can have millions or even billions of parameters.

Large language models such as GPT-3 and LLaMa range from 1.3 billion to 175 billion parameters, with checkpoint sizes reaching gigabytes or even terabytes in training format.

LLaMa Checkpoint Size Calculations

For instance, the LLaMa 7B model’s checkpoint in inference format with float16 data type is approximately 14 gigabytes, while its training format with optimizer state and float32 parameters is around 69 gigabytes. The LLaMa 70B model’s training checkpoint is a massive 782 gigabytes.
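Here is a rough back-of-the-envelope sketch of where such numbers come from. The bytes-per-parameter values are assumptions (2 bytes for float16 inference weights, about 12 bytes for float32 weights plus Adam’s two float32 moment estimates), and the exact totals depend on what the training state actually contains:

```python
GiB = 1024 ** 3

def checkpoint_size_gib(n_params: float, bytes_per_param: float) -> float:
    """Rough checkpoint size estimate, ignoring metadata and format overhead."""
    return n_params * bytes_per_param / GiB

# Inference format: float16 weights only -> 2 bytes per parameter.
print(f"{checkpoint_size_gib(7e9, 2):.0f} GiB")    # ~13 GiB for a 7B model
# Training format: float32 weights plus Adam's two float32 moment estimates
# -> roughly 12 bytes per parameter (what exactly is stored varies by setup).
print(f"{checkpoint_size_gib(70e9, 12):.0f} GiB")  # ~782 GiB for a 70B model
```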

Storage and Loading Challenges

These large checkpoints create significant storage and loading challenges, especially when training on multiple GPUs or clusters, necessitating specialized logic for checkpoint management.

How to Handle Large Model Checkpointing

Use async checkpointing

Throughout training, model parameters reside in GPU memory (VRAM), but saving a model to storage is a CPU task, requiring the parameters to be stored in RAM. Two steps are necessary: offloading parameters from VRAM to RAM (fast), and saving from RAM to storage (slower, depending on model size and storage throughput).

The trick of asynchronous checkpointing is to continue model training on the GPU while the CPU saves the checkpoint to storage in the background. It is implemented in frameworks such as AsyncCheckpointIO (PyTorch Lightning) and AsyncCheckpointer (orbax.checkpoint).

Note: the host’s RAM must be large enough to hold the checkpoint, and you need to wait for any background saving to finish at the end of training.
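To illustrate the idea only (this is not the actual AsyncCheckpointIO or AsyncCheckpointer implementation), a minimal sketch copies the state to RAM on the training thread and hands the slow disk write to a background thread:

```python
import threading
import torch

def async_save(model, optimizer, step, path):
    """Illustrative async checkpointing: offload to RAM, then save in the background."""
    # Step 1 (fast, briefly blocks training): copy tensors from VRAM to RAM.
    model_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    opt_state = optimizer.state_dict()
    opt_state["state"] = {
        idx: {name: v.cpu() if torch.is_tensor(v) else v for name, v in s.items()}
        for idx, s in opt_state["state"].items()
    }
    state = {"step": step, "model_state_dict": model_state, "optimizer_state_dict": opt_state}

    # Step 2 (slow): write from RAM to storage while the GPU keeps training.
    thread = threading.Thread(target=torch.save, args=(state, path))
    thread.start()
    return thread  # join() this at the end of training to wait for pending saves
```

The frameworks mentioned above wrap this pattern for you and also handle details like waiting for in-flight saves before exiting.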

Know your storage

Cloud providers offer various storage services: network disks, network file systems (NFS), and object storage. Though data is ultimately written to physical disks, each option has its own pros, cons, and cloud logic behind it.

Network disks

Network disks appear as regular disks but are actually remote network services. They can’t be mounted to multiple VMs simultaneously, so you can’t share checkpoints between hosts out of the box. NFS is the alternative when you need a shared file system.

Network file system

NFS is a protocol; the actual implementation depends on the provider. You can use larger block sizes and SSD disks for better NFS performance, or deploy self-managed NFS (e.g., GlusterFS) on top of network disks and VMs.

S3-compatible object storage

Object storage is not a file system!

It stores a mapping of string keys to binary data. “Directories” exist only for user convenience and don’t reflect any actual structure. Operations like atomic move (mv) or overwriting bytes in place are not supported. Check the documentation of your specific implementation for more advanced behavior.

Parallel IO

Object storage is optimized for parallel reads and writes and supports reading random byte ranges. Make sure your code performs parallel IO (e.g., aws s3 cp, boto3’s managed transfer methods like client.copy).
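For example, boto3’s managed transfer API can split a large checkpoint into concurrent multipart requests; the bucket, key, and tuning values below are illustrative:

```python
import boto3
from boto3.s3.transfer import TransferConfig

s3 = boto3.client("s3")

# Split the object into 64 MiB parts and move up to 16 parts in parallel.
config = TransferConfig(
    multipart_threshold=64 * 1024 * 1024,
    multipart_chunksize=64 * 1024 * 1024,
    max_concurrency=16,
)

# Both uploads and downloads use multipart parallel IO with this config.
s3.upload_file("checkpoint.pt", "my-bucket", "ckpts/step_1000/checkpoint.pt", Config=config)
s3.download_file("my-bucket", "ckpts/step_1000/checkpoint.pt", "checkpoint.pt", Config=config)
```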

s3fs

Tools like s3fs, Goofys, and mountpoint-s3 mount S3 buckets as volumes, abstracting away S3 specifics. However:

  • Standard mv is slow; avoid it.
  • cp may underperform S3-specific tools unless s3fs is configured for parallel IO with optimal mount parameters.
  • directory.exists() is ambiguous; its behavior depends on the s3fs implementation.
  • Metadata requests (ls, find) may be slow due to S3’s key-value nature; avoid abusing them.


Checkpoint Format Matters

Checkpoints can be laid out on disk in various ways, e.g., as one serialized blob (torch.save) or as one file per layer. Reading a single 10 GB file is very different from reading 100 files of 0.1 GB each. A single large blob also makes it hard for each host to read only the tensors it needs in multi-host, model-parallel training. Granular, sharded formats like OCDBT combine small blobs into larger files, striking a balance.
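To make the contrast concrete, here is a sketch of the two extremes in PyTorch (a single pickled blob versus one file per tensor); granular formats like OCDBT sit between these:

```python
import os
import torch

def save_as_single_blob(model, path):
    # One big serialized file: simple, but every host must read all of it.
    torch.save(model.state_dict(), path)

def save_one_file_per_tensor(model, directory):
    # One file per tensor: hosts can fetch only the tensors they need,
    # at the cost of many small files and many metadata requests.
    os.makedirs(directory, exist_ok=True)
    for name, tensor in model.state_dict().items():
        torch.save(tensor.cpu(), os.path.join(directory, f"{name}.pt"))
```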

Benefit from the Network in Multi-Host Training

In data-parallel multi-host training, all hosts store the same model, so initially retrieving the checkpoint generates heavy load on storage if all hosts read it simultaneously.

  • With fixed-bandwidth storage, all readers share that bandwidth, so each host reads much more slowly (e.g., 32× slower with 32 hosts).
  • With scalable storage, reading speed remains constant under load.
  • If storage can’t scale well, one host can read the checkpoint into memory, then share it with others via the inter-cluster network using techniques like NCCL Broadcast (a sketch follows this list).
  • With GPU clusters and InfiniBand, checkpoints can even load directly into GPU memory and transfer between hosts over ultra-fast GPU-to-GPU links.
  • For sharded checkpoints, each host reads only its shard, then NCCL AllGather combines them across hosts. orbax.checkpoint has multi-host IO optimizations. The same optimization applies to multi-host saving – each host only saves its shard portion.
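A minimal sketch of the broadcast approach with torch.distributed, assuming the process group has already been initialized with the NCCL backend and that only rank 0 can (or should) read the checkpoint from storage:

```python
import torch
import torch.distributed as dist

def load_and_broadcast(model, path, device):
    """Rank 0 reads the checkpoint from storage; other ranks receive it over the network."""
    if dist.get_rank() == 0:
        state = torch.load(path, map_location="cpu")
        model.load_state_dict(state["model_state_dict"])
    model.to(device)
    # Broadcast every tensor from rank 0: only one host touches the storage backend,
    # the rest get the weights over fast inter-node links (NCCL).
    for tensor in model.state_dict().values():
        dist.broadcast(tensor, src=0)
```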

Choose a Sane Checkpointing Schedule

A schedule defines logic like “save a checkpoint every N steps and keep the last M checkpoints.”

Checkpointing frequency: If training crashes, you’ll have to redo the steps completed after the last checkpoint once you restart. Frequency is therefore a trade-off: infrequent checkpoints mean more redone steps, while frequent checkpoints add training time because the GPU stalls during saving.

Storing multiple checkpoints: Keeping some previous checkpoints is useful, e.g., in case of gradient explosions where you need to tweak hyperparameters and restart from an older checkpoint. Five to ten recent checkpoints plus a few older ones is often sufficient, since storage is cheaper than GPUs but not free.
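A minimal sketch of such a schedule is below; all names are illustrative, and frameworks expose similar knobs (e.g., save interval and a maximum number of checkpoints to keep). It assumes the user-supplied save function writes each checkpoint into its own directory:

```python
import os
import shutil

class CheckpointSchedule:
    """Save every `every_n_steps` steps and keep only the `keep_last` newest checkpoints."""

    def __init__(self, root, every_n_steps=1000, keep_last=5):
        self.root = root
        self.every_n_steps = every_n_steps
        self.keep_last = keep_last
        self.saved = []  # paths of checkpoints written so far, oldest first

    def maybe_save(self, step, save_fn):
        if step % self.every_n_steps != 0:
            return
        path = os.path.join(self.root, f"step_{step:08d}")
        save_fn(path)  # user-supplied function that writes the checkpoint into this directory
        self.saved.append(path)
        while len(self.saved) > self.keep_last:
            shutil.rmtree(self.saved.pop(0))  # drop the oldest checkpoint
```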
