Distributed Training with Weights & Biases
When you run distributed training jobs with Flyte's Elastic task type, managing Weights & Biases (W&B) logging requires coordinating multiple processes across one or more nodes. The flyte-sdk provides the @wandb_init decorator to automatically handle rank detection and run initialization based on your desired logging strategy.
Control Logging with Run Mode and Rank Scope
The flyte-sdk uses two primary parameters in the @wandb_init decorator to control distributed logging:
- run_mode: Determines how runs are created or shared.
  - "auto" (default): Only the primary rank initializes a run; all other ranks receive None.
  - "shared": All ranks initialize and log to the same shared W&B run.
  - "new": Every rank creates its own unique W&B run.
- rank_scope: Defines which rank counts as primary.
  - "global" (default): One primary rank for the entire multi-node cluster (global rank 0).
  - "worker": One primary rank per node/worker (local rank 0 on each node).
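To make the interaction between the two parameters concrete, here is a small pure-Python sketch of the decision each process faces. The function should_init_run is hypothetical: it restates the rules listed above and is not the flyte-sdk implementation.

```python
from typing import Literal

RunMode = Literal["auto", "shared", "new"]
RankScope = Literal["global", "worker"]


def should_init_run(run_mode: RunMode, rank_scope: RankScope, rank: int, local_rank: int) -> bool:
    """Return True if this process should initialize a W&B run.

    Hypothetical helper restating the documented rules; not flyte-sdk code.
    rank is the global rank, local_rank is the rank within the node.
    """
    if rank_scope not in ("global", "worker"):
        raise ValueError(f"unknown rank_scope: {rank_scope!r}")
    if run_mode in ("shared", "new"):
        # Every rank initializes: either one shared run or one run per rank.
        return True
    if run_mode == "auto":
        # Only the primary rank initializes; the scope decides who is primary.
        if rank_scope == "global":
            return rank == 0
        return local_rank == 0
    raise ValueError(f"unknown run_mode: {run_mode!r}")


# Example: 2 nodes x 4 processes each; global rank = node * 4 + local_rank.
for mode, scope in [("auto", "global"), ("auto", "worker"), ("new", "global")]:
    active = [r for r in range(8) if should_init_run(mode, scope, rank=r, local_rank=r % 4)]
    print(mode, scope, active)
# auto global [0]
# auto worker [0, 4]
# new global [0, 1, 2, 3, 4, 5, 6, 7]
```

With "shared", all eight ranks would also initialize, but they attach to a single run instead of creating eight separate ones.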
Pattern 1: Single Primary Logger (Default)
In the default configuration, only the global rank 0 process initializes a W&B run. This is the most common pattern for distributed training, where you usually want to track only aggregate metrics from a single primary process.
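```python
import flyte
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import get_wandb_run, wandb_init

# Assumed example setup: 2 nodes with 1 process each.
# Add an image, resources, and secrets as your deployment requires.
torch_env = flyte.TaskEnvironment(
    name="torch_env",
    plugin_config=Elastic(nnodes=2, nproc_per_node=1),
)


# Default: run_mode="auto", rank_scope="global"
@wandb_init(project="my-project")
@torch_env.task
def train_distributed():
    # Only global rank 0 gets a W&B run object;
    # other ranks receive None from get_wandb_run()
    run = get_wandb_run()
    if run:
        run.log({"global_loss": 0.01})
```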
Pattern 2: Shared Run Across All Ranks
If you want every process in the distributed cluster to log to the same W&B run, use run_mode="shared". The flyte-sdk enables W&B's shared mode, which allows multiple processes to log concurrently to a single run ID.
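```python
from flyteplugins.wandb import get_wandb_run, wandb_init


# torch_env: a flyte.TaskEnvironment configured with the Elastic plugin
@wandb_init(run_mode="shared", project="my-project")
@torch_env.task
def train_shared_run():
    # Every rank gets the same W&B run object
    run = get_wandb_run()
    # Each rank logs its own metric to the shared run
    run.log({"rank_specific_metric": 0.5})
```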
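Because every rank writes into the same run in shared mode, ranks that log under the same key can end up interleaved in a single chart. One way to keep them apart, sketched here as a suggestion rather than flyte-sdk behavior, is to prefix each key with the rank. rank_metrics is a hypothetical helper.

```python
def rank_metrics(rank: int, metrics: dict[str, float]) -> dict[str, float]:
    """Prefix each metric key with the rank, e.g. "loss" -> "rank_0/loss".

    Hypothetical helper: distinct keys per rank keep each rank's series
    separate inside a shared W&B run.
    """
    return {f"rank_{rank}/{key}": value for key, value in metrics.items()}


print(rank_metrics(2, {"loss": 0.5, "step_time": 0.12}))
# {'rank_2/loss': 0.5, 'rank_2/step_time': 0.12}
```

Inside the task you might call run.log(rank_metrics(rank, {"loss": loss})), taking rank from get_distributed_info(). W&B groups charts into sections by the prefix before the slash, so each rank's metrics appear in their own section.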
Pattern 3: Individual Runs per Rank
To track the performance or system metrics of every individual process separately, use run_mode="new". This creates a unique run for every rank, automatically grouped together in the W&B UI.
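```python
from flyteplugins.wandb import get_wandb_run, wandb_init


# torch_env: a flyte.TaskEnvironment configured with the Elastic plugin
@wandb_init(run_mode="new", project="my-project")
@torch_env.task
def train_individual_runs():
    # Every rank gets a unique W&B run object
    # Run IDs follow the pattern: {base_id}-rank-{global_rank}
    run = get_wandb_run()
    run.log({"local_step_time": 0.123})
```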
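For run_mode="new", the {base_id}-rank-{global_rank} ID pattern is what keeps the per-rank runs distinct. The sketch below builds the kind of keyword arguments a per-rank wandb.init() call could receive (project, id, name, and group are standard wandb.init parameters). per_rank_init_kwargs is hypothetical, and using base_id as the group name is an assumption about how the runs are grouped, not documented flyte-sdk behavior.

```python
def per_rank_init_kwargs(project: str, base_id: str, global_rank: int) -> dict[str, str]:
    """Hypothetical sketch of wandb.init() arguments for one rank.

    The run ID follows the documented {base_id}-rank-{global_rank} pattern.
    Grouping by base_id is an assumption, not flyte-sdk behavior.
    """
    run_id = f"{base_id}-rank-{global_rank}"
    return {"project": project, "id": run_id, "name": run_id, "group": base_id}


# Every rank in an 8-process job gets a distinct ID but the same group.
all_kwargs = [per_rank_init_kwargs("my-project", "abc123", r) for r in range(8)]
print(all_kwargs[3]["id"])  # abc123-rank-3
```

A rank would then pass these to wandb.init(**kwargs). Sharing one group value is what lets the W&B UI collect the per-rank runs together.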
Pattern 4: Worker-Level Logging
In large multi-node setups, you may want one run per node rather than one for the whole cluster or one for every single GPU. Setting rank_scope="worker" makes local rank 0 on each node the primary rank, so with run_mode="auto" each node produces exactly one run.
```python
from flyteplugins.wandb import get_wandb_run, wandb_init


# torch_env: a flyte.TaskEnvironment configured with the Elastic plugin
# Creates 1 run per node (local rank 0 on each node logs)
@wandb_init(run_mode="auto", rank_scope="worker", project="my-project")
@torch_env.task
def train_per_node():
    run = get_wandb_run()
    if run:
        # This code executes only on local rank 0 of each node
        run.log({"node_average_loss": 0.01})
```
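To see which processes end up logging under rank_scope="worker", the sketch below lists the global ranks that act as primary. It assumes the usual torchrun layout in which every node runs the same number of processes and global rank = node index * processes per node + local rank; primary_ranks is a hypothetical helper.

```python
def primary_ranks(num_nodes: int, nproc_per_node: int, rank_scope: str) -> list[int]:
    """Global ranks that act as primary under run_mode="auto".

    Hypothetical helper. Assumes a contiguous layout in which
    global rank = node_index * nproc_per_node + local_rank.
    """
    if rank_scope == "global":
        # A single primary for the whole cluster.
        return [0]
    if rank_scope == "worker":
        # Local rank 0 on each node, i.e. the first global rank of each node.
        return [node * nproc_per_node for node in range(num_nodes)]
    raise ValueError(f"unknown rank_scope: {rank_scope!r}")


print(primary_ranks(num_nodes=4, nproc_per_node=8, rank_scope="worker"))  # [0, 8, 16, 24]
```

On a 4-node, 8-GPU-per-node cluster this gives four runs, compared with one under "global" and 32 under run_mode="new".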
Accessing Distributed Metadata
If you need to implement custom logic based on the distributed topology, use get_distributed_info(). This helper auto-detects environment variables set by torchrun or Flyte's Elastic plugin.
```python
from flyteplugins.wandb import get_distributed_info, wandb_init


# torch_env: a flyte.TaskEnvironment configured with the Elastic plugin
@wandb_init
@torch_env.task
def custom_logic_task():
    dist_info = get_distributed_info()
    if dist_info:
        # dist_info contains: rank, local_rank, world_size,
        # local_world_size, worker_index, num_workers
        print(f"I am rank {dist_info['rank']} on worker {dist_info['worker_index']}")
```
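For intuition about where these values come from, the sketch below derives the same fields from the environment variables torchrun sets for each worker process (RANK, LOCAL_RANK, WORLD_SIZE, LOCAL_WORLD_SIZE, GROUP_RANK). detect_distributed_info is hypothetical and may differ from what get_distributed_info() actually reads; the worker_index fallback and the num_workers calculation assume every node runs the same number of processes.

```python
import os
from typing import Mapping, Optional


def detect_distributed_info(env: Mapping[str, str] = os.environ) -> Optional[dict[str, int]]:
    """Hypothetical re-creation of a torchrun-based topology lookup.

    Returns None when the torchrun variables are absent, i.e. when the
    process was not launched as part of a distributed job.
    """
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return None
    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", "0"))
    local_world_size = int(env.get("LOCAL_WORLD_SIZE", str(world_size)))
    # GROUP_RANK is the node (worker group) index; fall back to a
    # division that assumes equal-sized nodes.
    worker_index = int(env.get("GROUP_RANK", str(rank // local_world_size)))
    return {
        "rank": rank,
        "local_rank": local_rank,
        "world_size": world_size,
        "local_world_size": local_world_size,
        "worker_index": worker_index,
        "num_workers": world_size // local_world_size,
    }
```

Passing a plain dict instead of os.environ makes the function easy to exercise outside a real cluster.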
Troubleshooting and Limitations
- Decorator order: The @wandb_init decorator must be the outermost decorator on your Flyte task, placed above the task decorator.
- Async support: Distributed training with the Elastic plugin does not support async task functions in flyte-sdk. You must use synchronous functions.
- Log downloading: The download_logs=True parameter is not supported for distributed tasks. If it is enabled, flyte-sdk issues a warning and skips the download, which avoids conflicts between multiple workers trying to download the same run data.
- None checks: With the default run_mode="auto", always check whether get_wandb_run() returns None before logging, because it returns None on every non-primary rank.
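If you log from several places in a task, the None check can be folded into one small helper. log_if_primary is a hypothetical name; it accepts anything with a log method, such as the object returned by get_wandb_run().

```python
from typing import Any, Mapping, Optional, Protocol


class SupportsLog(Protocol):
    """Anything with a W&B-style log(data) method."""

    def log(self, data: Mapping[str, Any]) -> None: ...


def log_if_primary(run: Optional[SupportsLog], metrics: Mapping[str, Any]) -> bool:
    """Log metrics only when this rank actually holds a run.

    Returns True if something was logged and False on ranks without a run.
    """
    if run is None:
        # Non-primary rank under run_mode="auto": there is nothing to log to.
        return False
    run.log(dict(metrics))
    return True
```

Calling log_if_primary(get_wandb_run(), {"loss": loss}) is then safe on every rank, whichever run_mode and rank_scope you choose.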