Fully Sharded Data Parallel (FSDP)

Fully Sharded Data Parallel (FSDP) is a distributed training approach designed to efficiently train very large neural network models across multiple GPUs or nodes by sharding (splitting) model parameters, gradients and optimizer states across devices. This significantly reduces per-device memory usage and enables scaling to models that would not fit on a single device.

Core Principles

- Sharding: Each device (GPU) holds only a shard (subset) of the model parameters, gradients and optimizer states, never the full model.
- On-Demand Gathering: Full parameters are assembled (via all-gather) only when needed for computation in the forward and backward passes, and are released immediately after use.
- Communication Efficiency: FSDP decomposes the standard all-reduce operation used in DDP into more efficient reduce-scatter and all-gather operations.

Step-by-Step Theory of Operation

1. Initialization
- The model's parameters, gradients and optimizer states are sharded across all participating devices.
- Each device is responsible for only a subset of these tensors, drastically reducing per-device memory requirements.

2. Forward Pass
- All-Gather: Before computation, each device collects the parameter shards held by the other devices to reconstruct the full set of parameters it needs.
- Computation: The forward pass is performed locally using the assembled parameters and the device's data batch.
- Release: After the computation, the gathered parameter shards (those not originally belonging to the device) are freed from memory, retaining only the local shard.

3. Backward Pass
- All-Gather (optional): Some implementations require another all-gather to reconstruct the full parameters for gradient computation.
- Backward Computation: Gradients are computed for the local batch.
- Reduce-Scatter: Instead of an all-reduce, gradients are reduced and scattered across devices so that each device ends up with only the gradient shard corresponding to its parameter shard.
- Release: Temporary parameter shards are released, keeping only the local shard.

4. Optimizer Step
- Each device updates its local parameter shard using the corresponding gradient shard and optimizer state.
- No device ever needs the full optimizer state or the full parameter set at once.

Collective Communication Operations

Because parameters, gradients and optimizer states are split across devices, FSDP relies on three collective communication operations to move data between them (see the sketch below):

- All-Reduce: Used in DDP to synchronize gradients across all devices (each device ends up with the full sum).
- Reduce-Scatter: Used in FSDP to sum and distribute gradients efficiently (each device keeps only one chunk of the result).
- All-Gather: Used in FSDP to assemble full parameters when needed (each device starts with one part and ends up with all parts).

[Figure: FSDP Allreduce]
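To make the three collectives concrete, here is a minimal, illustrative sketch (not FSDP internals) that contrasts DDP-style all-reduce with FSDP-style reduce-scatter followed by all-gather on a toy gradient tensor, using torch.distributed directly. It assumes the NCCL backend and a launch via torchrun (e.g. `torchrun --nproc_per_node=2 collectives_demo.py`); the tensor sizes and file name are placeholders.

```python
import os
import torch
import torch.distributed as dist


def demo_collectives():
    # Assumes a launch such as: torchrun --nproc_per_node=2 collectives_demo.py
    dist.init_process_group(backend="nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    torch.cuda.set_device(device)

    # Pretend each rank computed a full gradient of 4 * world_size elements.
    grad = torch.arange(4 * world, dtype=torch.float32, device=device) + rank

    # DDP-style all-reduce: every rank ends up with the complete summed gradient.
    summed = grad.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)

    # FSDP-style reduce-scatter: gradients are summed, but each rank keeps
    # only the chunk that corresponds to its own parameter shard.
    shard = torch.empty(4, device=device)
    dist.reduce_scatter(shard, [c.contiguous() for c in grad.chunk(world)],
                        op=dist.ReduceOp.SUM)

    # FSDP-style all-gather: reassemble the full tensor from the per-rank shards.
    gathered = [torch.empty(4, device=device) for _ in range(world)]
    dist.all_gather(gathered, shard)
    full = torch.cat(gathered)  # equal to `summed` above

    print(f"rank {rank}: local shard = {shard.tolist()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    demo_collectives()
```

Reduce-scatter followed by all-gather moves the same total data as an all-reduce, but between the two steps each rank only has to hold and update its own chunk, which is exactly what FSDP exploits.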
FSDP vs. DDP (Distributed Data Parallel)

| Feature | DDP (Standard) | FSDP (Fully Sharded) |
|---|---|---|
| Model Replication | Full model on every device | Only a shard of the model per device |
| Memory Usage | High (full model, gradients, optimizer states) | Low (only the local shards of each) |
| Communication Pattern | All-reduce for gradients | Reduce-scatter and all-gather |
| Scalability | Limited by device memory | Enables much larger models |

- DDP: Each device keeps a full copy of the model, computes on its own data shard and synchronizes gradients across all devices with an all-reduce, ensuring all replicas are updated identically.
- FSDP: Each device stores only a shard of the model and optimizer state. Parameter shards are gathered as needed before the forward and backward passes, gradients are reduced and scattered during backpropagation, and updated parameter shards are all-gathered again for the next computation, enabling memory efficiency and localized updates.

Implementation Details

- Nested Wrapping: FSDP can wrap individual layers or groups of layers as separate units, gathering only what is needed for each computation step, which further optimizes memory and communication.
- Offloading: FSDP can offload shards to CPU memory when they are not in use, reducing GPU memory requirements even further.
- Flexible Sharding: Sharding can be performed along different tensor dimensions, supporting advanced use cases and optimizations.

A PyTorch sketch combining nested wrapping and CPU offloading follows the list of advantages below.

Key Advantages

- Memory Efficiency: By sharding all major tensors, FSDP allows much larger models or batch sizes to fit in device memory.
- Scalability: Enables training of models that cannot fit on a single device, or even under standard DDP.
- Communication Overlap: FSDP can overlap communication with computation for improved performance.
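As a rough illustration of how the pieces above fit together in PyTorch, the sketch below wraps a placeholder model with FullyShardedDataParallel, using a size-based auto-wrap policy for nested wrapping and optional CPU offloading of parameter shards, then runs a single training step. The model architecture, the 1M-parameter wrapping threshold, the optimizer and all hyperparameters are illustrative assumptions, and the script is assumed to be launched with torchrun on one or more GPUs.

```python
import functools
import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, CPUOffload
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def main():
    # Assumes a launch such as: torchrun --nproc_per_node=4 fsdp_demo.py
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Placeholder model standing in for something much larger.
    model = nn.Sequential(
        nn.Linear(1024, 4096), nn.ReLU(),
        nn.Linear(4096, 4096), nn.ReLU(),
        nn.Linear(4096, 10),
    )

    # Nested wrapping: submodules above ~1M parameters become their own FSDP
    # units, so only one unit's parameters are all-gathered at a time.
    wrap_policy = functools.partial(size_based_auto_wrap_policy,
                                    min_num_params=1_000_000)

    sharded = FSDP(
        model,
        auto_wrap_policy=wrap_policy,
        cpu_offload=CPUOffload(offload_params=True),  # optional shard offloading
        device_id=local_rank,
    )

    # The optimizer sees only this rank's parameter shards.
    opt = torch.optim.AdamW(sharded.parameters(), lr=1e-4)

    x = torch.randn(8, 1024, device=local_rank)
    y = torch.randint(0, 10, (8,), device=local_rank)

    out = sharded(x)                # parameters all-gathered per wrapped unit
    loss = nn.functional.cross_entropy(out, y)
    loss.backward()                 # gradients reduce-scattered back to shards
    opt.step()                      # each rank updates only its local shard
    opt.zero_grad()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

During the forward and backward passes each wrapped unit all-gathers its parameters just in time and reduce-scatters its gradients, so the optimizer step only ever touches the local shard and its optimizer state.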