Scaling LLM Training w/ FSDP

May 24, 2024

Estimating GPU Usage for Full Fine-Tuning

  • Model memory calculator
  • Assumptions: Adam optimizer and batch size of 1
  • Each parameter is 4 bytes (at full/fp32 precision)
  • Backward pass ~= 2x the model size
  • The optimizer step ~= 4x the model size (1x model, 1x gradients, 2x optimizer)
  • Example: Llama-3-8B takes 56GB of memory at half precision (a rough version of this arithmetic is sketched below)
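
As a rough sketch of that arithmetic (an illustrative helper, not the actual calculator behind accelerate estimate-memory):

```python
def estimate_fft_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Rough peak-memory estimate for full fine-tuning with Adam."""
    model = num_params * bytes_per_param       # 1x model weights
    gradients = model                          # 1x gradients
    optimizer_states = 2 * model               # 2x Adam states (momentum + variance)
    return (model + gradients + optimizer_states) / 1e9   # ~4x the model size, in GB

# Llama-3-8B at half precision (2 bytes/param): ~64 GB by this rough 4x rule,
# in the same ballpark as the ~56 GB figure quoted above.
print(estimate_fft_memory_gb(8e9, bytes_per_param=2))
```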

Distributed Training

Types of Training

  • Single GPU: No distributed techniques at play
  • Distributed Data Parallelism (DDP): A full copy of the model exists on each device, but data is chunked between each GPU
  • Fully Sharded Data Parallelism (FSDP) & DeepSpeed (DS): Split chunks of the model and optimizer states across GPUs, allowing bigger models to be trained across multiple smaller GPUs

DeepSpeed vs FSDP

  • Extremely similar; they mostly differ in naming conventions and in slight implementation tweaks

Fully Sharded Data Parallelism (FSDP)

FSDP

  • Splits the model into shards
  • Each GPU holds only its shard in VRAM, so chunks of the training loop happen on that local slice of memory
  • Occasionally torch needs to know what's happening with the other model chunks to align gradients (see the minimal sketch below)
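
To make that concrete, here is a minimal sketch of wrapping a model with PyTorch's FSDP. The tiny Linear model is a placeholder, and the script assumes one process per GPU on a single node (e.g. launched via torchrun).

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")             # one process per GPU (via torchrun)
torch.cuda.set_device(dist.get_rank())      # single-node assumption

model = torch.nn.Linear(4096, 4096).cuda()  # placeholder model
model = FSDP(model)                         # parameters are now sharded across ranks
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Standard training step: FSDP gathers the shards it needs for the forward and
# backward passes, then synchronizes gradients before the optimizer step.
out = model(torch.randn(8, 4096, device="cuda"))
out.sum().backward()
optimizer.step()
```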

FSDP Parameters

sharding_strategy

  • Dictates how much of the training state gets divided (sharded) across the GPUs
  • FULL_SHARD: Includes optimizer states, gradients, and parameters
  • SHARD_GRAD_OP: Includes optimizer states and gradients
  • NO_SHARD: Normal DDP
  • HYBRID_SHARD: Includes optimizer states, gradients, and parameters, but each node keeps a full copy of the model. Can significantly increase training speed since there is less communication between nodes (see the sketch below).
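
A sketch of picking the strategy directly on the torch FSDP wrapper (accelerate's fsdp_sharding_strategy option maps onto the same enum), continuing the placeholder setup above:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

model = FSDP(
    torch.nn.Linear(4096, 4096).cuda(),    # placeholder model
    # FULL_SHARD:    shard optimizer states, gradients, and parameters
    # SHARD_GRAD_OP: shard optimizer states and gradients only
    # NO_SHARD:      plain DDP behaviour
    # HYBRID_SHARD:  FULL_SHARD within each node, a full replica per node
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)
```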

auto_wrap_policy

  • How the model should be split
  • Can be either TRANSFORMER_BASED_WRAP or SIZE_BASED_WRAP
  • TRANSFORMER/fsdp_transformer_layer_cls_to_wrap:
    • Need to declare the layer class to wrap
    • Generally transformers has good defaults
  • SIZE/fsdp_min_num_params:
    • Minimum number of parameters a chunk must reach before it is wrapped into its own shard
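
A sketch of what the two policies correspond to in the underlying torch wrap utilities; LlamaDecoderLayer is just an example of a layer class you might declare for the transformer-based policy.

```python
import functools
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    size_based_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# TRANSFORMER_BASED_WRAP: shard at the boundaries of the declared layer class
transformer_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

# SIZE_BASED_WRAP: wrap once roughly this many parameters have accumulated
size_policy = functools.partial(
    size_based_auto_wrap_policy,
    min_num_params=int(1e8),
)

# model = FSDP(model, auto_wrap_policy=transformer_policy)
```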

offload_params

  • Offloads the parameters and gradients to the CPU if they can't fit into memory
  • Allows you to train much larger models locally, but will be much slower
  • Case: a full fine-tune (FFT) of Llama-3-8B with fsdp_offload_params on 2x 4090 GPUs took ~72 hrs, vs. an hour or two on 1x H100
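
Roughly what this option toggles in the underlying torch API (placeholder model again):

```python
import torch
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

model = FSDP(
    torch.nn.Linear(4096, 4096).cuda(),           # placeholder model
    cpu_offload=CPUOffload(offload_params=True),  # params/grads sit in CPU RAM when not in use
)
```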

cpu_ram_efficient_loading & sync_module_states

  • Uses the idea behind big model inference (the meta device) to get the model onto the GPUs in a low-RAM scenario
  • Rather than needing model_size * n_gpus of CPU RAM, we can load the model on a single process and then send the weights directly to each shard at the right time via sync_module_states (sketched below)
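
A sketch of that idea in plain PyTorch. build_model() is a hypothetical stand-in for your model constructor, and the param_init_fn shown here is one common way to materialize the meta-device copies before rank 0's weights are broadcast; in accelerate, the fsdp_cpu_ram_efficient_loading and fsdp_sync_module_states config options are meant to handle this for you.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
if dist.get_rank() == 0:
    model = build_model()          # hypothetical constructor; real weights loaded once into CPU RAM
else:
    with torch.device("meta"):
        model = build_model()      # shapes/dtypes only, no memory allocated

model = FSDP(
    model,
    sync_module_states=True,       # rank 0 broadcasts its weights to every other shard
    # materialize the meta tensors on GPU so the broadcast has somewhere to land
    param_init_fn=lambda m: m.to_empty(device=torch.device("cuda"), recurse=False),
)
```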

Accelerate

Accelerate Intro

  • CLI interface
  • Training library
  • Big model inference

A CLI Interface

  • accelerate config: Configure the environment
  • accelerate estimate-memory: Estimate how much VRAM a model needs
  • accelerate launch: How to run your script

A Training Library

  • Accelerate's DataLoaders and schedulers work off of a sharding mindset
  • Rather than repeating the same data across n nodes, we instead split it
  • Speeds up training linearly
  • Given a batch size of 16 on a single GPU, to recreate this across 8 GPUs you would use a batch size of 2
  • This also means the scheduler will be stepped n GPUs at a time per "global step"
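
A minimal sketch of that batch-size arithmetic with the Accelerator; the model, dataset, and scheduler below are placeholders:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

# Placeholders: a per-device batch size of 2 on 8 GPUs recreates a global batch size of 16
model = torch.nn.Linear(128, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)
dataset = TensorDataset(torch.randn(1024, 128), torch.randint(0, 2, (1024,)))
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

model, optimizer, dataloader, scheduler = accelerator.prepare(
    model, optimizer, dataloader, scheduler
)

for inputs, labels in dataloader:      # each process only sees its own slice of the data
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    accelerator.backward(loss)
    optimizer.step()
    scheduler.step()                   # stepped on every process each "global step"
    optimizer.zero_grad()
```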

A Training Library: Mixed Precision

  • Accelerate does not convert the model weights to BF16/FP16
  • Instead, it wraps the forward pass with autocast so the computation (and the resulting gradients) happen in the lower precision automatically
  • This preserves the original precision of the weights, which leads to stable training and better fine-tuning later on.
  • If you call .bfloat16() on the weights, you are STUCK in bf16 permanently
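
A small sketch contrasting the two routes (placeholder model):

```python
import torch
from accelerate import Accelerator

# autocast route: the weights stay in their original precision; only the
# forward pass (and the resulting gradients) run through bf16 autocast
accelerator = Accelerator(mixed_precision="bf16")
model = accelerator.prepare(torch.nn.Linear(128, 2))   # placeholder model

# casting route (what the warning above is about): the weights themselves are
# converted and the full-precision copy is gone for good
# model = torch.nn.Linear(128, 2).bfloat16()
```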

References:

  1. Zach Mueller