Fully Sharded Data Parallel (FSDP)

Fully Sharded Data Parallel (FSDP) is a distributed training approach designed to efficiently train very large neural network models across multiple GPUs or nodes by sharding (splitting) model parameters, gradients and optimizer states across devices. This method significantly reduces memory usage and enables scaling to models that would not fit on a single device.

Core Principles

  • Sharding: Each device (GPU) holds only a shard (subset) of the model parameters, gradients and optimizer states, not the full model.
  • On-Demand Gathering: Full model parameters are assembled (via all-gather) only when needed for computation (forward and backward passes) and released immediately after use.
  • Communication-Efficient: FSDP decomposes the standard all-reduce operation (used in DDP) into more efficient reduce-scatter and all-gather operations.
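
To make these principles concrete, the sketch below wraps a small PyTorch model in FullyShardedDataParallel so that each rank stores only a shard of the parameters. It is a minimal sketch assuming a single-node launch with torchrun and NCCL-capable GPUs; the model and the script name are placeholders.

```python
# Minimal sketch: wrap a model with FSDP so each rank holds only a parameter shard.
# Assumes launch via `torchrun --nproc_per_node=<num_gpus> fsdp_minimal.py` (hypothetical filename).
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def main():
    dist.init_process_group("nccl")                 # one process per GPU
    local_rank = int(os.environ["LOCAL_RANK"])      # set by torchrun
    torch.cuda.set_device(local_rank)

    model = nn.Sequential(                          # toy model standing in for a large network
        nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)
    ).cuda()

    # Wrapping shards parameters, gradients and optimizer state across ranks.
    fsdp_model = FSDP(model)

    # Outside forward/backward, each rank materializes roughly 1/world_size of the parameters.
    local_numel = sum(p.numel() for p in fsdp_model.parameters())
    print(f"rank {dist.get_rank()}: ~{local_numel} parameter elements stored locally")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```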

Step-by-Step Theory of Operation

1. Initialization

  • The model’s parameters, gradients and optimizer states are sharded across all participating devices.
  • Each device is responsible for only a subset of these tensors, drastically reducing per-device memory requirements.
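
A rough memory estimate shows why this matters. The sketch below assumes fp32 parameters trained with Adam (4 bytes each for the parameter and its gradient, plus 8 bytes for the two optimizer moments, i.e. 16 bytes per parameter) and ignores activations and overhead; the model size and GPU count are example values.

```python
# Rough per-device memory for parameters + gradients + Adam states (fp32, 16 bytes/param).
# Illustrative only; activation memory, padding and framework overhead are ignored.
num_params = 7_000_000_000        # example: a 7B-parameter model (assumption)
bytes_per_param = 4 + 4 + 8       # fp32 param + fp32 grad + Adam m and v
world_size = 8                    # number of GPUs sharding the model (assumption)

full_copy_gb = num_params * bytes_per_param / 1e9
sharded_gb = full_copy_gb / world_size

print(f"DDP-style full copy per GPU : {full_copy_gb:.0f} GB")   # ~112 GB
print(f"FSDP shard per GPU          : {sharded_gb:.0f} GB")     # ~14 GB
```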

2. Forward Pass

  • All-Gather: Before computation, each device collects all parameter shards from other devices to reconstruct the full set of parameters needed for its computation.
  • Computation: The forward pass is performed locally using the assembled parameters and the device’s data batch.
  • Release: After the computation, the gathered parameter shards (not originally belonging to the device) are freed from memory, retaining only the local shard.
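
This gather-compute-release cycle can be imitated directly with torch.distributed primitives. The function below is an illustration of the idea rather than FSDP's internal code: each rank owns one row-wise shard of a weight matrix, all-gathers the full weight just long enough for a matrix multiply, then frees the gathered copy.

```python
# Illustrative forward step: all-gather parameter shards, compute, then free them.
# Assumes the process group is already initialized (NCCL, one GPU per rank).
import torch
import torch.distributed as dist

def sharded_forward(x, weight_shard):
    world_size = dist.get_world_size()

    # 1) All-gather: every rank reconstructs the full weight from the per-rank shards.
    gathered = [torch.empty_like(weight_shard) for _ in range(world_size)]
    dist.all_gather(gathered, weight_shard)
    full_weight = torch.cat(gathered, dim=0)        # shards were split along dim 0

    # 2) Compute: local forward pass with the assembled weight and the local batch.
    out = x @ full_weight.t()

    # 3) Release: drop the gathered copy; only the local shard stays resident.
    del gathered, full_weight
    return out
```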

3. Backward Pass

  • All-Gather (If Needed): If the gathered parameters were freed after the forward pass, they are all-gathered again so the backward computation has the full parameters available.
  • Backward Computation: Gradients are computed for the local batch.
  • Reduce-Scatter: Instead of all-reduce, gradients are reduced and scattered across devices so that each device ends up with only the gradient shard corresponding to its parameter shard.
  • Release: Temporary parameter shards are released, keeping only the local shard.
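
The gradient step can be sketched the same way. Assuming every rank has just produced a full gradient for its local loss, reduce-scatter sums the per-rank gradients and leaves each rank with only the chunk matching its parameter shard; again this is an illustration of the idea, not FSDP's internals.

```python
# Illustrative backward step: reduce-scatter full gradients into per-rank gradient shards.
# Assumes an initialized process group and a full gradient tensor on every rank,
# with a size divisible by world_size so the chunks are equal.
import torch
import torch.distributed as dist

def shard_gradients(full_grad):
    world_size = dist.get_world_size()

    # Split the locally computed full gradient into world_size equal chunks.
    chunks = list(full_grad.chunk(world_size, dim=0))
    grad_shard = torch.empty_like(chunks[0])

    # Sum chunk i across all ranks and deliver it only to rank i.
    dist.reduce_scatter(grad_shard, chunks, op=dist.ReduceOp.SUM)

    # Average to match data-parallel semantics, then keep only the local shard.
    grad_shard /= world_size
    return grad_shard
```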

4. Optimizer Step

  • Each device updates its local parameter shard using the corresponding gradient shard and optimizer state.
  • No device ever needs the full optimizer state or full parameter set at once.
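
In practice none of these collectives are written by hand: the optimizer is constructed over the FSDP-wrapped model's (sharded) parameters and used as usual. A minimal training-step sketch, assuming the fsdp_model from the earlier example plus a per-rank dataloader:

```python
# Sketch of a training step with an FSDP-wrapped model (fsdp_model and dataloader assumed).
# Each rank updates only its own parameter shard inside optimizer.step().
import torch

optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)  # sees only local shards
loss_fn = torch.nn.MSELoss()

for inputs, targets in dataloader:                  # per-rank data shard
    inputs, targets = inputs.cuda(), targets.cuda()

    optimizer.zero_grad(set_to_none=True)
    loss = loss_fn(fsdp_model(inputs), targets)     # forward: all-gather happens inside FSDP
    loss.backward()                                 # backward: reduce-scatter happens inside FSDP
    optimizer.step()                                # local shard update; no full model needed
```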

Collective Communication Operations

FSDP's sharding is built from three collective communication operations (a small equivalence check follows the list):

  • All-Reduce: Used in DDP to synchronize gradients/parameters across all devices (each device ends up with the full sum).
  • Reduce-Scatter: Used in FSDP to sum and distribute gradients efficiently (each device keeps only a chunk of the result).
  • All-Gather: Used in FSDP to assemble full parameters when needed (each device starts with a part, ends up with all parts).
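
The relationship between these operations can be verified with a short script: reduce-scatter followed by all-gather yields the same result as a single all-reduce, which is exactly the decomposition FSDP exploits. A runnable sketch, assuming a torchrun launch on a multi-GPU node:

```python
# Check that reduce-scatter + all-gather reproduces all-reduce (the decomposition FSDP uses).
# Launch with torchrun on a multi-GPU node; NCCL backend assumed.
import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
world_size, rank = dist.get_world_size(), dist.get_rank()

x = torch.arange(4 * world_size, dtype=torch.float32, device="cuda") + rank

# Path 1: all-reduce (DDP style) -- every rank ends up with the full sum.
all_reduced = x.clone()
dist.all_reduce(all_reduced, op=dist.ReduceOp.SUM)

# Path 2: reduce-scatter then all-gather (FSDP style).
chunks = list(x.clone().chunk(world_size))
shard = torch.empty_like(chunks[0])
dist.reduce_scatter(shard, chunks, op=dist.ReduceOp.SUM)   # each rank keeps one summed chunk
gathered = [torch.empty_like(shard) for _ in range(world_size)]
dist.all_gather(gathered, shard)                           # reassemble the full sum

assert torch.allclose(all_reduced, torch.cat(gathered))
if rank == 0:
    print("all-reduce == reduce-scatter + all-gather")
dist.destroy_process_group()
```
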
[Figure: FSDP sharding and its reduce-scatter/all-gather decomposition of all-reduce]

FSDP vs. DDP (Distributed Data Parallel)

Feature | DDP (Standard) | FSDP (Fully Sharded)
Model Replication | Full model on every device | Only a shard of the model per device
Memory Usage | High (full model, gradients, optimizer states) | Low (only shard, gradients, optimizer)
Communication Pattern | All-Reduce for gradients | Reduce-Scatter and All-Gather
Scalability | Limited by device memory | Enables much larger models

  • DDP: Each device keeps a full copy of the model, computes on its own data shard and synchronizes gradients across all devices with all-reduce, so every replica is updated identically.
  • FSDP: Each device stores only a shard of the model and optimizer state; parameter shards are gathered on demand before the forward and backward passes. Gradients are reduce-scattered during backpropagation, each device updates only its own shard, and the updated shards are all-gathered again for the next computation, enabling memory efficiency and localized updates.
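
At the API level the difference is a single wrapper. The sketch below contrasts the two, assuming the process group is initialized and model and local_rank are defined as in the earlier example; in a real script you would pick one wrapper, not both.

```python
# Same model, two wrappers: DDP replicates it, FSDP shards it (pick one in a real script).
# Assumes an initialized process group, a CUDA `model` and `local_rank` as in the earlier sketch.
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

ddp_model = DDP(model, device_ids=[local_rank])  # full replica per GPU, all-reduce on gradients
fsdp_model = FSDP(model)                         # one shard per GPU, reduce-scatter + all-gather
```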

Implementation Details

  • Nested Wrapping: FSDP can wrap individual layers or groups of layers, gathering only what’s needed for each computation step, further optimizing memory and communication.
  • Offloading: FSDP can offload shards to CPU memory when not in use, reducing GPU memory requirements even further.
  • Flexible Sharding: Sharding can be done on different tensor dimensions, supporting advanced use cases and optimizations.
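
In PyTorch these options are exposed as FSDP constructor arguments. The sketch below combines a size-based auto-wrap policy (nested wrapping of sufficiently large submodules) with CPU offload and the default full-sharding strategy; the 100k-parameter threshold is just an example value.

```python
# Nested wrapping and CPU offload via FSDP constructor arguments.
# Assumes the process group is initialized and `model` is a CUDA nn.Module.
import functools
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    CPUOffload,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

fsdp_model = FSDP(
    model,
    # Wrap every submodule with at least 100k parameters as its own FSDP unit,
    # so only that unit's parameters are gathered at a time (threshold is an example).
    auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=100_000),
    # Keep parameter shards in CPU memory when not in use.
    cpu_offload=CPUOffload(offload_params=True),
    # Fully shard parameters, gradients and optimizer state (the default strategy).
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)
```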

Key Advantages

  • Memory Efficiency: By sharding all major tensors, FSDP allows much larger models or batch sizes to fit in device memory.
  • Scalability: Enables training of models that are impossible to fit on a single device or even with standard DDP.
  • Communication Overlap: FSDP can overlap communication and computation for improved performance.
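
Much of this overlap is automatic, but PyTorch exposes prefetch settings that influence it; for instance, backward prefetching issues the next all-gather while the current gradient computation is still running. A configuration sketch under the same model and process-group assumptions as above:

```python
# Prefetch options that help FSDP overlap collective communication with computation.
# Assumes an initialized process group and a CUDA nn.Module `model`, as in the earlier sketches.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, BackwardPrefetch

fsdp_model = FSDP(
    model,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # start the next all-gather before the current gradient computation finishes
    forward_prefetch=True,                            # issue the next forward-pass all-gather early
)
```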
