
Distributed Training with Weights & Biases

When you run distributed training jobs with Flyte's Elastic task type, managing Weights & Biases (W&B) logging requires coordinating multiple processes across one or more nodes. The flyte-sdk provides the @wandb_init decorator to automatically handle rank detection and run initialization based on your desired logging strategy.

Control Logging with Run Mode and Rank Scope

The flyte-sdk uses two primary parameters in the @wandb_init decorator to control distributed logging:

  • run_mode: Determines how runs are created or shared.
    • "auto" (default): Only the primary rank initializes a run; get_wandb_run() returns None on all other ranks.
    • "shared": All ranks initialize and log to the same shared W&B run.
    • "new": Every rank creates its own unique W&B run.
  • rank_scope: Defines the boundary for the primary rank.
    • "global" (default): One primary rank for the entire multi-node cluster (Global Rank 0).
    • "worker": One primary rank per node/worker (Local Rank 0 on each node).
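The way these two parameters decide which processes start a run can be sketched as a small pure-Python function. `initializes_run` is a hypothetical helper written for illustration, not part of flyte-sdk. It models only the behavior listed above: "shared" and "new" initialize on every rank, and "auto" initializes only on the primary rank of the chosen scope. It does not model whether "shared" runs are further split per worker when rank_scope="worker", because this page does not describe that case.

```python
def initializes_run(run_mode: str, rank_scope: str, global_rank: int, local_rank: int) -> bool:
    """Hypothetical model of whether a process starts a W&B run (not flyte-sdk code)."""
    if run_mode not in ("auto", "shared", "new"):
        raise ValueError(f"unknown run_mode: {run_mode!r}")
    if rank_scope not in ("global", "worker"):
        raise ValueError(f"unknown rank_scope: {rank_scope!r}")

    if run_mode in ("shared", "new"):
        # Every rank initializes: either into one shared run or into its own run.
        return True

    # run_mode == "auto": only the primary rank of the chosen scope initializes.
    # "global" compares the cluster-wide rank; "worker" compares the per-node rank.
    primary_rank = global_rank if rank_scope == "global" else local_rank
    return primary_rank == 0


if __name__ == "__main__":
    # Example topology (assumed): 2 nodes x 2 processes, so local_rank = global_rank % 2.
    for g in range(4):
        print(g, initializes_run("auto", "worker", g, g % 2))
```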

Pattern 1: Single Primary Logger (Default)

In the default configuration, only the global rank 0 process initializes a W&B run. This is the most common pattern for distributed training when you want a single run for the whole job, with metrics logged from one process.

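```python
import flyte
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import wandb_init, get_wandb_run

# Example environment (assumed values): 2 nodes with 2 processes per node
torch_env = flyte.TaskEnvironment(
    name="torch_env",
    plugin_config=Elastic(nnodes=2, nproc_per_node=2),
)

# Default: run_mode="auto", rank_scope="global"
@wandb_init(project="my-project")
@torch_env.task
def train_distributed():
    # Only global rank 0 gets a W&B run object;
    # other ranks receive None from get_wandb_run()
    run = get_wandb_run()

    if run:
        run.log({"global_loss": 0.01})
```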
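To see how many processes take the `None` branch, the sketch below enumerates every process in a cluster and lists the global ranks that get no run under the default settings. `layout` and `ranks_receiving_none` are hypothetical helpers. They assume torchrun's usual static layout, where every node runs the same number of processes and global ranks are numbered node by node (global rank = node index × processes per node + local rank).

```python
from typing import List, Tuple


def layout(nnodes: int, nproc_per_node: int) -> List[Tuple[int, int, int]]:
    """Return (node_index, local_rank, global_rank) for every process.

    Assumes the static torchrun layout described above.
    """
    return [
        (node, local, node * nproc_per_node + local)
        for node in range(nnodes)
        for local in range(nproc_per_node)
    ]


def ranks_receiving_none(nnodes: int, nproc_per_node: int) -> List[int]:
    """Global ranks that get None from get_wandb_run() with the defaults.

    With run_mode="auto" and rank_scope="global", only global rank 0 has a run.
    """
    return [g for _, _, g in layout(nnodes, nproc_per_node) if g != 0]
```

In a 2-node job with 4 processes per node, for example, 8 processes call `get_wandb_run()` and 7 of them must skip logging, which is why the `if run:` guard matters.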

Pattern 2: Shared Run Across All Ranks

If you want every process in the distributed cluster to log to the same W&B run, use run_mode="shared". The flyte-sdk configures W&B's internal shared mode, allowing concurrent logging from multiple processes to a single run ID.

```python
@wandb_init(run_mode="shared", project="my-project")
@torch_env.task
def train_shared_run():
    # Every rank gets the same W&B run object
    run = get_wandb_run()

    # Each rank logs its own specific metric to the shared run
    run.log({"rank_specific_metric": 0.5})
```
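Because every rank writes into one run, ranks that log the same key all contribute to a single series. If you want to tell the ranks apart, one option is to prefix each key with the rank that produced it. `per_rank_keys` is a hypothetical helper, and the `rank_<n>/` prefix is an illustrative naming convention, not something flyte-sdk or W&B requires.

```python
from typing import Dict


def per_rank_keys(metrics: Dict[str, float], rank: int) -> Dict[str, float]:
    """Prefix every metric key with the rank that produced it.

    The "rank_<n>/" prefix is an illustrative convention only.
    """
    return {f"rank_{rank}/{name}": value for name, value in metrics.items()}
```

Inside the task, you could combine it with the rank reported by `get_distributed_info()`, for example `run.log(per_rank_keys({"loss": loss}, dist_info["rank"]))`.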

Pattern 3: Individual Runs per Rank

To track the performance or system metrics of every individual process separately, use run_mode="new". This creates a unique run for every rank, automatically grouped together in the W&B UI.

```python
@wandb_init(run_mode="new", project="my-project")
@torch_env.task
def train_individual_runs():
    # Every rank gets a unique W&B run object
    # Run IDs follow the pattern: {base_id}-rank-{global_rank}
    run = get_wandb_run()
    run.log({"local_step_time": 0.123})
```
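The run ID pattern noted in the code comment above can be illustrated with a short sketch that lists the IDs a job would produce for a given world size. `per_rank_run_ids` is a hypothetical helper, and the base ID passed to it is a placeholder; how flyte-sdk chooses the actual base ID is not described here.

```python
from typing import List


def per_rank_run_ids(base_id: str, world_size: int) -> List[str]:
    """Run IDs following the documented {base_id}-rank-{global_rank} pattern."""
    return [f"{base_id}-rank-{rank}" for rank in range(world_size)]
```

For example, `per_rank_run_ids("abc123", 4)` lists one ID per global rank, from `abc123-rank-0` through `abc123-rank-3`.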

Pattern 4: Worker-Level Logging

In large multi-node setups, you may want one W&B run per node rather than one for the whole cluster or one for every GPU process. Setting rank_scope="worker" shifts primary-rank detection to the node level, so local rank 0 on each node becomes a primary rank.

```python
# Creates one run per node (local rank 0 on each node logs)
@wandb_init(run_mode="auto", rank_scope="worker", project="my-project")
@torch_env.task
def train_per_node():
    run = get_wandb_run()
    if run:
        # This code executes only on local rank 0 of each node
        run.log({"node_average_loss": 0.01})
```
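To see which processes end up logging with worker scope, the sketch below lists their global ranks. `worker_scope_loggers` is a hypothetical helper. It assumes the static torchrun layout in which local rank 0 on node n has global rank n × processes per node.

```python
from typing import List


def worker_scope_loggers(nnodes: int, nproc_per_node: int) -> List[int]:
    """Global ranks that get a run with run_mode="auto", rank_scope="worker".

    Each node's local rank 0 is its primary. Assumes every node runs
    nproc_per_node processes and global ranks are numbered node by node.
    """
    return [node * nproc_per_node for node in range(nnodes)]
```

With 3 nodes of 4 processes each, for example, this gives three runs, one each from global ranks 0, 4, and 8.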

Accessing Distributed Metadata

If you need to implement custom logic based on the distributed topology, use get_distributed_info(). This helper auto-detects environment variables set by torchrun or Flyte's Elastic plugin.

```python
from flyteplugins.wandb import get_distributed_info

@wandb_init
@torch_env.task
def custom_logic_task():
    dist_info = get_distributed_info()
    if dist_info:
        # dist_info contains: rank, local_rank, world_size,
        # local_world_size, worker_index, num_workers
        print(f"I am rank {dist_info['rank']} on worker {dist_info['worker_index']}")
```
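To make the idea of "auto-detecting environment variables" concrete, here is a hypothetical stand-in, `read_distributed_info`, that builds a dictionary with the same keys from the variables torchrun sets for each worker (RANK, LOCAL_RANK, WORLD_SIZE, LOCAL_WORLD_SIZE, GROUP_RANK). It is not the plugin's implementation. Mapping `worker_index` to GROUP_RANK and deriving `num_workers` as WORLD_SIZE // LOCAL_WORLD_SIZE are assumptions that hold only when every node runs the same number of processes.

```python
import os
from typing import Dict, Mapping, Optional


def read_distributed_info(env: Mapping[str, str] = os.environ) -> Optional[Dict[str, int]]:
    """Hypothetical stand-in for get_distributed_info() based on torchrun's env vars."""
    required = ("RANK", "LOCAL_RANK", "WORLD_SIZE", "LOCAL_WORLD_SIZE")
    if not all(name in env for name in required):
        # Not launched by torchrun (or a compatible launcher).
        return None

    world_size = int(env["WORLD_SIZE"])
    local_world_size = int(env["LOCAL_WORLD_SIZE"])
    return {
        "rank": int(env["RANK"]),
        "local_rank": int(env["LOCAL_RANK"]),
        "world_size": world_size,
        "local_world_size": local_world_size,
        # Assumption: GROUP_RANK (torchrun's index for this node's worker group).
        "worker_index": int(env.get("GROUP_RANK", "0")),
        # Assumption: every node runs the same number of processes.
        "num_workers": world_size // local_world_size,
    }
```

Passing a plain dictionary instead of `os.environ` makes the function easy to exercise outside a real distributed job.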

Troubleshooting and Limitations

  • Decorator Order: The @wandb_init decorator must be the outermost decorator on your Flyte task, placed above the task decorator (for example, @torch_env.task).
  • Async Support: Distributed training using the Elastic plugin does not support async task functions in flyte-sdk. You must use synchronous functions.
  • Log Downloading: The download_logs=True parameter is not supported for distributed tasks. If enabled, flyte-sdk will issue a warning and skip the download to avoid conflicts between multiple workers attempting to download the same run data.
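  • None Checks: With run_mode="auto" (the default), always check whether get_wandb_run() returns None before logging. It returns None on every non-primary rank: all ranks except global rank 0 with rank_scope="global", or all ranks except local rank 0 on each node with rank_scope="worker".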
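One way to keep the None check out of your training loop is to wrap logging in a small guard. `log_if_present` and `RecordingRun` are hypothetical names used for illustration. `RecordingRun` is a stand-in that only records calls to `log`, so you can exercise the guard without a W&B backend.

```python
from typing import Any, Dict, List, Optional


def log_if_present(run: Optional[Any], metrics: Dict[str, float]) -> bool:
    """Log metrics only when this rank actually has a W&B run.

    Returns True if the metrics were logged, False on ranks that received None.
    """
    if run is None:
        return False
    run.log(metrics)
    return True


class RecordingRun:
    """Stand-in for a W&B run in local tests; records every log() call."""

    def __init__(self) -> None:
        self.logged: List[Dict[str, float]] = []

    def log(self, metrics: Dict[str, float]) -> None:
        self.logged.append(dict(metrics))
```

In a task, `log_if_present(get_wandb_run(), {"loss": loss})` behaves the same on every rank, and only ranks that hold a run actually send data.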