Part of a series about distributed AI across multiple GPUs.
Introduction
In the previous post, we saw how Distributed Data Parallelism (DDP) accelerates training by splitting batches across GPUs. DDP solves the throughput problem, but it introduces a new challenge: memory redundancy.
In vanilla DDP, every GPU holds a complete copy of the model parameters, gradients, and optimizer states. For large models like GPT-3 (175B parameters), this redundancy becomes a huge waste of precious VRAM.
ZeRO (Zero Redundancy Optimizer) solves this. There are three stages:
- ZeRO-1 partitions only optimizer states
- ZeRO-2 partitions optimizer states + gradients
- ZeRO-3 partitions optimizer states + gradients + model parameters
ZeRO isn’t a parallelism technique, because all GPUs still run the same forward and backward passes. It’s a memory optimization strategy that eliminates redundancy across GPUs, letting you train larger models on the same hardware.
The Memory Problem in DDP
Let’s break down what actually consumes memory during training. For a model with P parameters:
- Model Parameters: P values (the weights of your neural network)
- Gradients: P values (one gradient per parameter)
- Optimizer States (Adam): 2P values (first moment and second moment for each parameter)
- Activations: intermediate outputs stored during the forward pass for use in the backward pass
The first three scale with model size and are redundant across GPUs in DDP. Activations scale with batch size, sequence length, and number of neurons, and are unique per GPU since each GPU processes different data. ZeRO doesn’t touch activation memory.
Let’s calculate the memory usage for a 7B-parameter model using Adam and FP32:
- Parameters: 7 billion * 4 bytes = 28 GB
- Gradients: 7 billion * 4 bytes = 28 GB
- Optimizer states: 7 billion * 2 * 4 bytes = 56 GB
- Memory per GPU in DDP: 112 GB
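The arithmetic above can be reproduced with a few lines of Python. This is a minimal sketch rather than anything from the original post: the helper name `ddp_model_state_gb` and its arguments are made up for illustration, and it assumes decimal gigabytes (1 GB = 10^9 bytes), FP32 values, and Adam’s two states per parameter.

```python
# Hypothetical helper (not from the original post): per-GPU memory for the
# model state under plain DDP, ignoring activations.
GB = 1e9  # assumption: decimal gigabytes, 1 GB = 10^9 bytes


def ddp_model_state_gb(num_params, bytes_per_value=4, optimizer_values_per_param=2):
    """Return (parameters, gradients, optimizer states, total) in GB."""
    params = num_params * bytes_per_value / GB
    grads = num_params * bytes_per_value / GB
    opt_states = num_params * optimizer_values_per_param * bytes_per_value / GB
    return params, grads, opt_states, params + grads + opt_states


params_gb, grads_gb, opt_gb, total_gb = ddp_model_state_gb(7e9)
print(f"parameters:       {params_gb:.0f} GB")
print(f"gradients:        {grads_gb:.0f} GB")
print(f"optimizer states: {opt_gb:.0f} GB")
print(f"total per GPU:    {total_gb:.0f} GB")
```

Changing `bytes_per_value` or `optimizer_values_per_param` shows how the total moves for other precisions or optimizers, under the same simplifying assumptions.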
Activations add significant memory on top of this, but since they’re unique per GPU, ZeRO can’t partition them. Techniques like activation checkpointing can help: they discard some activations and then recompute them as needed during the backward pass. But that’s outside the scope of this article.
Let’s understand how ZeRO works by implementing it from the ground up, starting with ZeRO-1 and working our way to ZeRO-3.
ZeRO-1: Optimizer State Partitioning
In ZeRO-1, only the optimizer states are partitioned. Each GPU:
- Still holds the full model parameters and gradients
- Stores only 1/N of the optimizer states (N = number of GPUs)
- Updates only the corresponding 1/N of the parameters
This is the sequence of actions taken during training:
- Forward pass: each GPU processes its own micro-batch
- Backward pass: compute gradients
- all-reduce gradients: every GPU gets all the gradients
- Optimizer step: each GPU updates its parameter partition
- all-gather parameters: sync the updated model across GPUs

Here’s a simplified implementation:
import torch
import torch.distributed as dist


class ZeRO_1:
    def __init__(self, model, optimizer_cls):
        self.model = model
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.param_shards = list()  # each rank holds only its shard of the optimizer states
        self.param_metadata = list()  # metadata to reconstruct shards
        for param in self.model.parameters():
            original_shape = param.data.shape
            flat = param.data.view(-1)
            numel = flat.numel()
            # pad so the flattened tensor splits evenly across ranks
            remainder = numel % self.world_size
            pad_size = (self.world_size - remainder) % self.world_size
            padded_numel = numel + pad_size
            shard_size = padded_numel // self.world_size
            shard_start = self.rank * shard_size
            shard_end = shard_start + shard_size
            self.param_metadata.append(
                {
                    "original_shape": original_shape,
                    "numel": numel,
                    "padded_numel": padded_numel,
                    "shard_size": shard_size,
                    "shard_start": shard_start,
                    "shard_end": shard_end,
                }
            )
            if pad_size > 0:
                flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
            else:
                flat_padded = flat
            shard = flat_padded[shard_start:shard_end].clone()
            shard.requires_grad_(True)
            self.param_shards.append(shard)
        self.optimizer = optimizer_cls(self.param_shards)

    def training_step(self, inputs, targets, loss_fn):
        output = self.model(inputs)  # forward
        loss = loss_fn(output, targets)  # compute loss
        loss.backward()  # backward
        self._sync_gradients()  # all-reduce gradients across GPUs
        self.optimizer.step()  # update local shard of parameters
        self._sync_params()  # all-gather model params
        # clear gradients for the next step
        for param in self.model.parameters():
            param.grad = None

    def _sync_gradients(self):
        for idx, param in enumerate(self.model.parameters()):
            meta = self.param_metadata[idx]
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad /= self.world_size
            flat_grad = param.grad.view(-1)
            pad_size = meta["padded_numel"] - meta["numel"]
            if pad_size > 0:
                # pad like the parameter, so the last rank's slice matches its shard size
                flat_grad = torch.cat([flat_grad, flat_grad.new_zeros(pad_size)])
            self.param_shards[idx].grad = flat_grad[meta["shard_start"]:meta["shard_end"]]

    def _sync_params(self):
        for idx, param in enumerate(self.model.parameters()):
            meta = self.param_metadata[idx]
            full_flat = torch.empty(meta["padded_numel"], device=param.device, dtype=param.dtype)
            dist.all_gather_into_tensor(
                output_tensor=full_flat,
                input_tensor=self.param_shards[idx].data,
            )
            # drop the padding and restore the original shape
            reconstructed = full_flat[:meta["numel"]].view(meta["original_shape"])
            param.data.copy_(reconstructed)
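The padding and shard-boundary arithmetic in `__init__` is easy to check in isolation. The sketch below lifts it into a standalone function (`shard_bounds` is a hypothetical name, not something defined by the class above) and needs neither PyTorch nor a process group.

```python
def shard_bounds(numel, world_size, rank):
    """Return (padded_numel, shard_start, shard_end) for one rank."""
    remainder = numel % world_size
    pad_size = (world_size - remainder) % world_size  # 0 when already divisible
    padded_numel = numel + pad_size
    shard_size = padded_numel // world_size
    shard_start = rank * shard_size
    return padded_numel, shard_start, shard_start + shard_size


# A 10-element parameter split across 4 ranks is padded to 12 elements.
for rank in range(4):
    print(rank, shard_bounds(numel=10, world_size=4, rank=rank))
# 0 (12, 0, 3)
# 1 (12, 3, 6)
# 2 (12, 6, 9)
# 3 (12, 9, 12)
```

The last rank’s slice runs past the 10 real elements, which is why the zeros are appended before slicing and trimmed again after the all-gather.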
Notice that the all-reduce syncs all gradients, but each GPU only uses the gradients for its own parameter partition, so it’s overcommunicating. ZeRO-2 fixes this by sharding the gradients too.
In practice, you’d never use ZeRO-1, as ZeRO-2 gives you better memory savings at essentially the same cost. But it’s still worth going over for learning purposes.
Memory with ZeRO-1, 7B model, 8 GPUs:
- Parameters: 28 GB (fully replicated)
- Gradients: 28 GB (fully replicated)
- Optimizer states: 56 GB / 8 = 7 GB
- Total per GPU: 63 GB (down from 112 GB)
ZeRO-2: Gradient Partitioning
ZeRO-2 partitions both optimizer states and gradients. Since each GPU only updates a partition of the parameters, it only needs the corresponding gradients.
ZeRO-1 uses all-reduce, which gives every GPU all the gradients. ZeRO-2 replaces this with reduce-scatter: each GPU receives only the gradients it actually needs. This saves both memory and communication bandwidth.
Training steps:
- Forward pass: each GPU processes its own micro-batch
- Backward pass: compute gradients
- reduce-scatter gradients: each GPU gets only its partition
- Optimizer step: each GPU updates its parameter partition
- all-gather parameters: sync the updated model across GPUs

The implementation is similar to ZeRO-1, but the gradient synchronization step uses reduce-scatter instead of all-reduce.
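To see the difference in what each rank ends up holding, here is a minimal pure-Python sketch that simulates both collectives on plain lists. The function names are illustrative only, and the simulation assumes the gradient length is already a multiple of the number of ranks; in a real run the reduce-scatter would presumably be a call such as `torch.distributed.reduce_scatter_tensor` over the flattened, padded gradient.

```python
def all_reduce_mean(per_rank_grads):
    """Every rank ends up with the full averaged gradient."""
    world_size = len(per_rank_grads)
    mean = [sum(vals) / world_size for vals in zip(*per_rank_grads)]
    return [list(mean) for _ in range(world_size)]


def reduce_scatter_mean(per_rank_grads):
    """Rank r ends up with only the r-th slice of the averaged gradient."""
    world_size = len(per_rank_grads)
    numel = len(per_rank_grads[0])
    if numel % world_size != 0:
        raise ValueError("pad the gradient to a multiple of world_size first")
    shard_size = numel // world_size
    mean = [sum(vals) / world_size for vals in zip(*per_rank_grads)]
    return [mean[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]


# Two simulated ranks, each with a local gradient of four values.
local_grads = [[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]]
print(all_reduce_mean(local_grads))      # [[2.0, 3.0, 4.0, 5.0], [2.0, 3.0, 4.0, 5.0]]
print(reduce_scatter_mean(local_grads))  # [[2.0, 3.0], [4.0, 5.0]]
```

After the reduce-scatter, each rank stores half as many gradient values as after the all-reduce, and those values are exactly the ones its optimizer shard needs.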
But wait, if every GPU computes all gradients during backprop, how does this actually save VRAM? Here’s how:
- As the parameter gradients are computed layer by layer, they’re immediately reduce-scattered and the local copy is freed (our simplified implementation doesn’t do this).
- During backprop, you only need the gradient with respect to the next layer’s activations to compute the current parameter’s gradient, i.e., you don’t need the entire gradient graph.
- That way you can free up the memory for gradients as you move backwards, keeping only the assigned partition for each GPU.
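A toy simulation makes the saving concrete. The sketch below counts how many gradient elements sit on one GPU during the backward pass, with and without reduce-scattering each layer’s gradient as soon as it is produced. The function name and layer sizes are hypothetical, sizes are assumed divisible by the number of GPUs, and real frameworks bucket gradients, so treat the numbers as illustrative only.

```python
def peak_gradient_elements(layer_sizes, world_size, eager_reduce_scatter):
    """Peak number of gradient elements resident on one GPU during backward."""
    resident = 0
    peak = 0
    for size in reversed(layer_sizes):  # backward visits layers last to first
        resident += size                # this layer's full gradient is produced
        peak = max(peak, resident)
        if eager_reduce_scatter:
            # reduce-scatter right away: keep 1/world_size, free the rest
            resident -= size - size // world_size
    return peak


layers = [4000, 4000, 4000, 4000]  # hypothetical per-layer parameter counts
print(peak_gradient_elements(layers, world_size=8, eager_reduce_scatter=False))  # 16000
print(peak_gradient_elements(layers, world_size=8, eager_reduce_scatter=True))   # 5500
```

With eager reduce-scatter, the peak is one full layer plus the shards already kept from later layers, instead of the whole model’s gradients.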
Memory with ZeRO-2, 7B model, 8 GPUs:
- Parameters: 28 GB (fully replicated)
- Gradients: 28 GB / 8 = 3.5 GB
- Optimizer states: 56 GB / 8 = 7 GB
- Total per GPU: 38.5 GB (down from 112 GB)
ZeRO-3: Parameter Partitioning
ZeRO-3 partitions optimizer states, gradients, and parameters. Each GPU stores only 1/N of the entire model state.
During the forward and backward passes, each layer needs its full parameters, but each GPU only stores a fraction. So we all-gather parameters just-in-time, use them, then discard them immediately after.
Training steps:
- Forward pass (per layer):
  - All-gather the layer’s parameters from all GPUs
  - Run the layer’s forward pass using the previous layer’s activations as input
  - Discard the gathered parameters (keep only the local partition)
  - Repeat these steps until all layers are done
- Backward pass (per layer, in reverse):
  - All-gather the layer’s parameters again
  - Compute gradients for the current layer using activation gradients from the next layer
  - Reduce-scatter the gradients (each GPU keeps its shard)
  - Discard the gathered parameters (keep only the local partition)
  - Repeat these steps until all layers are done
- Each GPU runs an optimizer step on its partition
- No final all-gather is needed, since parameters are gathered layer by layer during the forward pass

Here’s a simplified implementation:
class ZeRO_3(ZeRO_2):
    """
    ZeRO-3: Shard optimizer states (stage 1) + gradients (stage 2) + model parameters (stage 3).
    At rest, each rank holds only param_shards[idx], a 1/world_size slice
    of each parameter. Full parameters are materialised temporarily during
    the forward and backward passes via all_gather, then immediately freed.
    """

    def __init__(self, model, optimizer_cls):
        self.model = model
        self.rank = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.param_metadata = []
        shard_list = []
        self._param_to_idx = {}
        for idx, param in enumerate(self.model.parameters()):
            original_shape = param.data.shape
            flat = param.data.view(-1)
            numel = flat.numel()
            remainder = numel % self.world_size
            pad_size = (self.world_size - remainder) % self.world_size
            padded_numel = numel + pad_size
            shard_size = padded_numel // self.world_size
            shard_start = self.rank * shard_size
            shard_end = shard_start + shard_size
            self.param_metadata.append(
                {
                    "original_shape": original_shape,
                    "numel": numel,
                    "padded_numel": padded_numel,
                    "shard_size": shard_size,
                    "shard_start": shard_start,
                    "shard_end": shard_end,
                }
            )
            if pad_size > 0:
                flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
            else:
                flat_padded = flat
            shard = flat_padded[shard_start:shard_end].clone()
            shard_list.append(shard)
            # Replace the full tensor with only this rank's shard.
            # The model's param.data now points to a tiny slice; the full
            # weight will be reconstructed on demand during forward/backward.
            param.data = shard.detach()
            self._param_to_idx[param] = idx
        self.param_shards = [s.requires_grad_(True) for s in shard_list]
        self.optimizer = optimizer_cls(self.param_shards)
        self._register_hooks()

    def _gather_param(self, idx, device, dtype):
        """All-gather the full parameter tensor for parameter `idx`."""
        meta = self.param_metadata[idx]
        full_flat = torch.empty(meta["padded_numel"], device=device, dtype=dtype)
        dist.all_gather_into_tensor(
            output_tensor=full_flat,
            input_tensor=self.param_shards[idx].data,
        )
        return full_flat[: meta["numel"]].view(meta["original_shape"])

    def _gather_module_params(self, module):
        """Gather full params for every parameter that belongs to this module only (not children)."""
        for param in module.parameters(recurse=False):
            idx = self._param_to_idx[param]
            param.data = self._gather_param(idx, param.device, param.dtype)

    def _reshard_module_params(self, module):
        """Reshard params back to the local shard for every direct param of this module."""
        for param in module.parameters(recurse=False):
            idx = self._param_to_idx[param]
            param.data = self.param_shards[idx].data

    def _register_hooks(self):
        self._hooks = []
        for module in self.model.modules():
            # Skip container modules that don't have any direct parameters
            if not list(module.parameters(recurse=False)):
                continue
            # Forward: gather -> run -> reshard
            h1 = module.register_forward_pre_hook(
                lambda mod, _inputs: self._gather_module_params(mod)
            )
            h2 = module.register_forward_hook(
                lambda mod, _inputs, _output: self._reshard_module_params(mod)
            )
            # Backward: gather before grad computation -> reshard after
            h3 = module.register_full_backward_pre_hook(
                lambda mod, _grad_output: self._gather_module_params(mod)
            )
            h4 = module.register_full_backward_hook(
                lambda mod, _grad_input, _grad_output: self._reshard_module_params(mod)
            )
            self._hooks.extend([h1, h2, h3, h4])

    def training_step(self, inputs, targets, loss_fn):
        # Hooks handle all gather/reshard around each module automatically
        output = self.model(inputs)
        loss = loss_fn(output, targets)
        loss.backward()
        self._sync_gradients()
        # Each rank updates only its local shard
        self.optimizer.step()
        for param in self.model.parameters():
            param.grad = None
Each layer’s parameters are gathered right before they’re needed and freed immediately after. This keeps peak memory minimal at the cost of more communication. In practice, implementations overlap the all-gather for layer N+1 with the forward of layer N to hide latency.
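The effect on peak parameter memory can be estimated with a small model. The sketch below is an illustration under simple assumptions, not a measurement: layer sizes are hypothetical and divisible by the number of GPUs, each GPU always keeps its shards, and `layers_in_flight` (a made-up knob) says how many consecutive layers are fully gathered at once, so 1 means strict just-in-time gathering and 2 approximates prefetching the next layer.

```python
def peak_param_elements(layer_sizes, world_size, layers_in_flight=1):
    """Peak parameter elements on one GPU: local shards plus gathered layers."""
    shard_total = sum(size // world_size for size in layer_sizes)
    # Largest total size of any run of `layers_in_flight` consecutive layers.
    largest_window = max(
        sum(layer_sizes[i:i + layers_in_flight])
        for i in range(len(layer_sizes))
    )
    return shard_total + largest_window


layers = [4000, 4000, 4000, 4000]  # hypothetical per-layer parameter counts
print(sum(layers))                                                     # 16000, full replication
print(peak_param_elements(layers, world_size=8))                       # 6000
print(peak_param_elements(layers, world_size=8, layers_in_flight=2))   # 10000
```

Prefetching raises the peak because two gathered layers coexist, which is the memory price paid for hiding the all-gather latency.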
Memory with ZeRO-3, 7B model, 8 GPUs:
- Parameters: 28 GB / 8 = 3.5 GB
- Gradients: 28 GB / 8 = 3.5 GB
- Optimizer states: 56 GB / 8 = 7 GB
- Total per GPU: 14 GB (down from 112 GB)
That’s an 8x reduction in memory usage, which is exactly what we’d expect from partitioning across 8 GPUs.
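The per-stage totals quoted in this article all follow one pattern, which the following sketch captures. The function name `zero_model_state_gb` is hypothetical, stage 0 stands for plain DDP, and the same assumptions apply as in the text: FP32 values, Adam with two states per parameter, decimal gigabytes, and no activation memory.

```python
GB = 1e9  # assumption: decimal gigabytes, 1 GB = 10^9 bytes


def zero_model_state_gb(num_params, num_gpus, stage, bytes_per_value=4):
    """Per-GPU model-state memory in GB for ZeRO stage 0 (plain DDP) to 3."""
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0, 1, 2 or 3")
    params = num_params * bytes_per_value / GB
    grads = num_params * bytes_per_value / GB
    opt_states = 2 * num_params * bytes_per_value / GB  # Adam: two values per parameter
    if stage >= 1:
        opt_states /= num_gpus  # ZeRO-1 shards optimizer states
    if stage >= 2:
        grads /= num_gpus       # ZeRO-2 also shards gradients
    if stage >= 3:
        params /= num_gpus      # ZeRO-3 also shards parameters
    return params + grads + opt_states


for stage in range(4):
    print(f"stage {stage}: {zero_model_state_gb(7e9, 8, stage):.1f} GB per GPU")
# stage 0: 112.0 GB per GPU
# stage 1: 63.0 GB per GPU
# stage 2: 38.5 GB per GPU
# stage 3: 14.0 GB per GPU
```

Only stage 3 scales the whole model state by 1/N; stages 1 and 2 leave the replicated parameters as a fixed floor no matter how many GPUs are added.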
Utilizing ZeRO in PyTorch
PyTorch ships with two implementations of ZeRO-3: FSDP1 (older, less optimized) and FSDP2 (newer, recommended). Always use FSDP2.
FSDP (Fully Sharded Data Parallel) handles parameter gathering, gradient scattering, communication overlap, and memory management automatically:
from torch.distributed.fsdp import fully_shard

model = Transformer()  # placeholder: your own nn.Module exposing its blocks as .layers
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)
You need to apply fully_shard layer by layer and then wrap the whole model.
Conclusion
ZeRO trades extra communication for lower memory use, so it’s not a free lunch. In general it’s not worth it for smaller models (e.g. BERT), but it’s a game changer for larger models.
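A rough way to compare the communication side is to add up the collectives each scheme issues per training step, as listed in the sections above. The sketch below does that under a stated assumption that is not from this article: bandwidth-optimal ring collectives, where a reduce-scatter or an all-gather moves about (N-1)/N of the tensor per GPU and an all-reduce costs the two combined. Names are illustrative, and latency, overlap, and bucketing are ignored.

```python
def relative_comm_volume(collectives, num_gpus):
    """Per-GPU traffic per step, in multiples of the model size (ring assumption)."""
    unit = (num_gpus - 1) / num_gpus
    cost = {"reduce_scatter": unit, "all_gather": unit, "all_reduce": 2 * unit}
    return sum(cost[name] for name in collectives)


# Collectives per step, following the training-step lists in this article.
schedules = {
    "DDP": ["all_reduce"],
    "ZeRO-1": ["all_reduce", "all_gather"],
    "ZeRO-2": ["reduce_scatter", "all_gather"],
    "ZeRO-3": ["all_gather", "all_gather", "reduce_scatter"],
}
for name, collectives in schedules.items():
    print(name, relative_comm_volume(collectives, num_gpus=8))
# DDP 1.75
# ZeRO-1 2.625
# ZeRO-2 1.75
# ZeRO-3 2.625
```

Under this simple model, ZeRO-2 matches DDP’s traffic while ZeRO-3 moves about 1.5x as much, which is the communication it pays for sharding the parameters.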
Congratulations on making it to the end! In this post, you learned about:
- The memory redundancy problem in standard DDP
- How ZeRO partitions optimizer states, gradients, and parameters across GPUs
- The three stages of ZeRO and their memory/communication trade-offs
- How to use ZeRO-3 via PyTorch’s FSDP
In the next article, we’ll explore Tensor Parallelism, a model parallelism technique that speeds up a layer’s computation by distributing work across GPUs.
