
🔪 Understand FSDP Sharding Strategies

Walk every FSDP sharding strategy across the same toy transformer until all-gather and reduce-scatter become numbers, not folklore. By the end you can pick FULL_SHARD vs SHARD_GRAD_OP vs HYBRID_SHARD for a 7B model on 16 GPUs and defend it.

Applied · 14 drops · ~2-week path · 5–8 min/day · technology

Phase 1: FSDP as PyTorch's Native ZeRO-3

See FSDP as PyTorch's native ZeRO-3

4 drops
  1. FSDP is ZeRO-3 wearing a PyTorch hat

    6 min

    FSDP and DeepSpeed ZeRO-3 implement the same idea (shard params, grads, and optimizer state; all-gather on demand) with different APIs and slightly different defaults.

  2. All-gather on the way in, reduce-scatter on the way out

    7 min

    Under FULL_SHARD, each FSDP unit costs one all-gather in forward and one all-gather plus one reduce-scatter in backward; that's the whole comm story.

  3. The FlatParameter is FSDP's unit of work, and you choose it

    7 min

    An auto-wrap policy decides which submodules get fused into one FlatParameter; that decision controls both peak memory and overlap, more than the sharding strategy does.

  4. MixedPrecision is a third axis, orthogonal to sharding

    7 min

    FSDP's MixedPrecision config lets you store params in one dtype, compute in another, and reduce gradients in a third, independent of which sharding strategy you pick.
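
To turn the per-unit communication pattern from drop 2 into numbers, here is a minimal sketch in plain Python. It assumes ring-style, bandwidth-optimal collectives, a single dtype for both parameters and gradients, and FULL_SHARD; the function names are illustrative and are not part of PyTorch.

```python
def ring_bytes_per_rank(numel: int, bytes_per_elem: int, world_size: int) -> float:
    """Bytes each rank sends in one ring all-gather or reduce-scatter.

    Assumption: a bandwidth-optimal ring, where every rank sends
    (N - 1) / N of the full tensor.
    """
    return (world_size - 1) / world_size * numel * bytes_per_elem


def fsdp_unit_comm(numel: int, bytes_per_elem: int, world_size: int) -> dict:
    """Per-step traffic for one FSDP unit under FULL_SHARD."""
    one = ring_bytes_per_rank(numel, bytes_per_elem, world_size)
    return {
        "forward_all_gather": one,       # rebuild params before forward
        "backward_all_gather": one,      # rebuild params again before backward
        "backward_reduce_scatter": one,  # reduce grads and keep only the local shard
        "total": 3 * one,
    }


def ddp_all_reduce(numel: int, bytes_per_elem: int, world_size: int) -> float:
    """A ring all-reduce is a reduce-scatter followed by an all-gather."""
    return 2 * ring_bytes_per_rank(numel, bytes_per_elem, world_size)


if __name__ == "__main__":
    # Hypothetical unit: 1M parameters in a 2-byte dtype on 8 ranks.
    unit = fsdp_unit_comm(1_000_000, 2, 8)
    ddp = ddp_all_reduce(1_000_000, 2, 8)
    print(unit)
    print("FSDP / DDP traffic ratio:", unit["total"] / ddp)
```

Under these assumptions FULL_SHARD moves 1.5x the bytes of a DDP all-reduce for the same parameters, which is the price paid for not keeping a full replica on every rank.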

Phase 2: Wrapping a Transformer and Reading FlatParameters

Wrap a transformer and watch FlatParameters form

5 drops
  1. Wrap a 2-layer transformer in 20 lines and inspect it

    7 min

    You can see exactly which submodules became FlatParameters by walking `model.named_modules()` after wrapping; no profiler needed for the first sanity check.

  2. Flip FULL_SHARD to SHARD_GRAD_OP and measure peak memory

    7 min

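    FULL_SHARD reshards each unit's params as soon as its forward finishes; SHARD_GRAD_OP shards grads and optimizer state but keeps the unsharded params resident from forward until backward completes. The difference shows up as roughly one full copy of the params (in the compute dtype) per rank at peak.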

  3. Activation checkpointing makes sharding look better than it is

    7 min

    FSDP's memory savings target static state (params + grads + optimizer); activations are an independent axis that often dominates peak memory, and activation checkpointing is the right lever for them.

  4. BackwardPrefetch.BACKWARD_PRE is the throughput knob nobody mentions

    6 min

    Setting `backward_prefetch=BackwardPrefetch.BACKWARD_PRE` overlaps the next unit's all-gather with the current unit's backward compute. That is often a 5-15% throughput win, paid for with the memory of one extra unit's unsharded params.

  5. Three knobs, in priority order: wrap policy, strategy, prefetch

    6 min

    Most FSDP tuning collapses to three decisions in this order: wrap policy first (overlap), then strategy (memory), then prefetch (throughput). Get them in the right order and you're 90% of the way to a good config.
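
The memory side of the FULL_SHARD vs SHARD_GRAD_OP comparison in this phase can be sketched as a back-of-the-envelope estimate. This is a rough model, not a measurement: it assumes fp32 params and grads (4 bytes each), Adam with 8 bytes of state per parameter, and it ignores activations, transient unsharded gradient buffers, padding, and allocator fragmentation. The function and argument names are hypothetical.

```python
def peak_static_bytes(
    strategy: str,
    total_params: int,
    largest_unit_params: int,
    world_size: int,
    param_bytes: int = 4,     # assumption: fp32 parameters and gradients
    optim_bytes: int = 8,     # assumption: Adam, two fp32 moments per parameter
    units_in_flight: int = 1, # use 2 to approximate prefetching one extra unit
) -> float:
    """Rough peak bytes per rank for params + grads + optimizer state."""
    # Params, grads, and optimizer state for the whole model.
    state = total_params * (2 * param_bytes + optim_bytes)
    if strategy == "NO_SHARD":
        # DDP-like: every rank holds everything.
        return state
    sharded = state / world_size
    if strategy == "FULL_SHARD":
        # Only the unit(s) currently computing are unsharded.
        unsharded = units_in_flight * largest_unit_params * param_bytes
    elif strategy == "SHARD_GRAD_OP":
        # All units stay unsharded from forward until backward completes.
        unsharded = total_params * param_bytes
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return sharded + unsharded


if __name__ == "__main__":
    # Hypothetical toy model: 100M params, largest wrapped unit 10M, 8 ranks.
    for name in ("NO_SHARD", "SHARD_GRAD_OP", "FULL_SHARD"):
        mb = peak_static_bytes(name, 100_000_000, 10_000_000, 8) / 1e6
        print(f"{name:>14}: {mb:,.0f} MB per rank")
```

Note how `largest_unit_params` only matters for FULL_SHARD: that is the wrap policy showing up in the memory estimate, which is why the wrap policy comes first in the priority order.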

Phase 3: Choosing the Sharding Mode for the Topology

Pick the sharding mode that fits the topology

4 drops
  1. Single-node, plenty of memory: do you really need FULL_SHARD?

    7 min


  2. Two nodes, slow inter-node link: when HYBRID_SHARD wins

    7 min


  3. Activations OOM on long context: which knob first?

    7 min


  4. The throughput dropped 20% after enabling FSDP: what's the bug?

    8 min


Phase 4: Defend a Strategy for 7B on 16 GPUs

Defend a strategy for 7B on 16 GPUs

1 drop
  1. 7B model, 16 GPUs, 2 nodes: pick a strategy and defend it

    8 min

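
For the 7B, 16-GPU, 2-node scenario, one way to frame the FULL_SHARD vs HYBRID_SHARD trade-off is to compare how large the shard group is and what has to cross the inter-node link. The sketch below assumes HYBRID_SHARD's default layout (shard within a node, replicate across nodes), 16 bytes of params + grads + Adam state per parameter, fp32 gradients, and a ring all-reduce between replicas. All names and numbers are illustrative.

```python
def summarize(
    strategy: str,
    total_params: int,
    nodes: int,
    gpus_per_node: int,
    state_bytes_per_param: int = 16,  # assumption: fp32 params + grads + Adam state
    grad_bytes: int = 4,              # assumption: fp32 gradients
) -> dict:
    """Shard-group size, sharded state per rank, and replica all-reduce bytes."""
    if strategy == "FULL_SHARD":
        # One shard group spanning every GPU, so all-gather and
        # reduce-scatter traffic crosses the inter-node link.
        shard_group, replicas = nodes * gpus_per_node, 1
    elif strategy == "HYBRID_SHARD":
        # Shard inside each node; each node holds a full sharded replica.
        shard_group, replicas = gpus_per_node, nodes
    else:
        raise ValueError(f"unknown strategy: {strategy}")

    shard_params = total_params / shard_group
    return {
        "shard_group_size": shard_group,
        "replicas": replicas,
        "sharded_state_bytes": shard_params * state_bytes_per_param,
        # Ring all-reduce of each rank's gradient shard across replicas;
        # under HYBRID_SHARD this is the only collective between nodes.
        "replica_all_reduce_bytes": 2 * (replicas - 1) / replicas * shard_params * grad_bytes,
    }


if __name__ == "__main__":
    for name in ("FULL_SHARD", "HYBRID_SHARD"):
        s = summarize(name, 7_000_000_000, nodes=2, gpus_per_node=8)
        print(
            f"{name:>12}: shard group {s['shard_group_size']}, "
            f"state {s['sharded_state_bytes'] / 1e9:.1f} GB/rank, "
            f"inter-node all-reduce {s['replica_all_reduce_bytes'] / 1e9:.1f} GB/rank/step"
        )
```

Under these assumptions HYBRID_SHARD doubles the sharded state per rank (14 GB vs 7 GB) in exchange for keeping every all-gather and reduce-scatter on the fast intra-node links. Whether that wins depends on the measured inter-node bandwidth and on how much memory headroom the activations leave.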

Frequently asked questions

What is FSDP and how is it different from DDP?
This is covered in the "Understand FSDP Sharding Strategies" learning path. Start with daily 5-minute micro-lessons that build from fundamentals to hands-on application.
What's the difference between FULL_SHARD and SHARD_GRAD_OP?
This is covered in the "Understand FSDP Sharding Strategies" learning path. Start with daily 5-minute micro-lessons that build from fundamentals to hands-on application.
When should I use HYBRID_SHARD instead of FULL_SHARD?
This is covered in the "Understand FSDP Sharding Strategies" learning path. Start with daily 5-minute micro-lessons that build from fundamentals to hands-on application.
How do auto-wrap policies actually decide what becomes a FlatParameter?
This is covered in the "Understand FSDP Sharding Strategies" learning path. Start with daily 5-minute micro-lessons that build from fundamentals to hands-on application.
Why does FSDP throughput drop 20% compared to DDP, and how do I fix it?
This is covered in the "Understand FSDP Sharding Strategies" learning path. Start with daily 5-minute micro-lessons that build from fundamentals to hands-on application.