
Weights & Biases Run Integration

The Weights & Biases (W&B) integration in flyte-sdk provides a seamless way to track experiments, manage runs, and visualize results directly from Flyte tasks. It automatically handles run initialization and lifecycle management, and it generates W&B links in the Flyte console.

Basic Task Integration

To track a Flyte task in W&B, use the @wandb_init decorator. This decorator must be the outermost decorator on your task function. Inside the task, use get_wandb_run() to retrieve the active W&B run object for logging metrics or artifacts.

```python
import flyte
from flyteplugins.wandb import get_wandb_run, wandb_init

env = flyte.TaskEnvironment(name="wandb-example")

@wandb_init(project="my-project", entity="my-team")
@env.task
async def train_model(learning_rate: float) -> str:
    # Retrieve the automatically initialized run
    wandb_run = get_wandb_run()

    # Log metrics as you normally would with wandb
    wandb_run.log({"loss": 0.5, "learning_rate": learning_rate})

    return wandb_run.id
```

The @wandb_init decorator in plugins/wandb/src/flyteplugins/wandb/_decorator.py performs several automated steps:

  1. Initialization: Calls wandb.init() with parameters derived from the decorator or context.
  2. ID Generation: If no ID is provided, it generates a unique run ID based on the Flyte action name (ctx.action.run_name).
  3. UI Links: Automatically adds a "Weights & Biases" link to the task in the Flyte UI using the Wandb link provider.
  4. Lifecycle: Automatically calls run.finish() when the task completes.
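To make these steps concrete, here is a minimal, self-contained sketch of a decorator with the same shape. It is not the plugin's implementation. FakeRun, sketch_wandb_init, and get_current_run are hypothetical names. The action name is passed in explicitly instead of being read from ctx.action.run_name, and the UI-link step is omitted. The sketch shows how a getter can return the active run inside the task (here, through a context variable) and why the run is finished even when the task raises.

```python
import asyncio
import contextvars
import functools
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeRun:
    """Hypothetical stand-in for a wandb Run so the sketch runs without wandb."""
    id: str
    logged: list = field(default_factory=list)
    finished: bool = False

    def log(self, metrics: dict) -> None:
        self.logged.append(metrics)

    def finish(self) -> None:
        self.finished = True


# Each asyncio task sees its own value of this variable.
_current_run: contextvars.ContextVar = contextvars.ContextVar("current_run", default=None)


def get_current_run() -> Optional[FakeRun]:
    """Return the run installed by the decorator, or None outside a decorated task."""
    return _current_run.get()


def sketch_wandb_init(action_name: str, run_id: Optional[str] = None):
    """Hypothetical decorator: initialize a run, expose it, and always finish it."""
    def decorator(fn):
        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            # ID generation: fall back to an ID derived from the action name.
            run = FakeRun(id=run_id or f"{action_name}-run")
            token = _current_run.set(run)
            try:
                return await fn(*args, **kwargs)
            finally:
                # Lifecycle: finish the run even if the task raised.
                run.finish()
                _current_run.reset(token)
        return wrapper
    return decorator


@sketch_wandb_init(action_name="train-action")
async def train(lr: float) -> str:
    run = get_current_run()
    run.log({"lr": lr})
    return run.id


if __name__ == "__main__":
    print(asyncio.run(train(0.01)))  # prints "train-action-run"
```

A contextvars.ContextVar is used instead of a module-level global so that concurrently running async tasks do not see each other's runs.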

Configuration Management

You can use wandb_config to set default W&B settings for an entire run or for a specific scope. This avoids repeating the project and entity names in every decorator.

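```python
import flyte
from flyteplugins.wandb import get_wandb_run, wandb_config, wandb_init

env = flyte.TaskEnvironment(name="wandb-example")

@wandb_init  # project and entity come from the run's custom_context
@env.task
async def train_model(learning_rate: float):
    run = get_wandb_run()
    run.log({"lr": learning_rate})

@env.task
async def main(lr: float):
    await train_model(learning_rate=lr)

if __name__ == "__main__":
    flyte.init_from_config()
    # Set run-level defaults that every @wandb_init task in this run inherits
    flyte.with_runcontext(
        custom_context=wandb_config(project="my-project", entity="my-team"),
    ).run(main, lr=0.01)
```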

The wandb_config function in plugins/wandb/src/flyteplugins/wandb/_context.py creates a configuration object that is stored in Flyte's custom_context. The @wandb_init decorator reads from this context to populate its parameters.
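The sketch below shows one way a run-level config and decorator arguments could be combined. It rests on two hypothetical assumptions: settings are flattened into string key/value pairs under a wandb_ prefix, and explicit decorator arguments take precedence over context defaults. The plugin's actual key names and precedence rules are not described here. sketch_wandb_config and resolve_init_params are illustrative names.

```python
from typing import Any, Dict, Optional

# Hypothetical key prefix; the plugin's real custom_context keys may differ.
PREFIX = "wandb_"


def sketch_wandb_config(**settings: Any) -> Dict[str, str]:
    """Flatten W&B settings into string key/value pairs for a run's custom context."""
    return {f"{PREFIX}{key}": str(value) for key, value in settings.items() if value is not None}


def resolve_init_params(custom_context: Dict[str, str], **explicit: Optional[str]) -> Dict[str, str]:
    """Start from the context defaults, then let explicit decorator arguments override them."""
    params = {
        key[len(PREFIX):]: value
        for key, value in custom_context.items()
        if key.startswith(PREFIX)
    }
    params.update({key: value for key, value in explicit.items() if value is not None})
    return params


if __name__ == "__main__":
    ctx = sketch_wandb_config(project="my-project", entity="my-team")
    print(resolve_init_params(ctx, project="override"))
    # {'project': 'override', 'entity': 'my-team'}
```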

Run Management and Reuse

When parent tasks call child tasks, the run_mode parameter controls whether a task reuses an existing W&B run or creates a new one.

| Mode | Behavior |
|---|---|
| auto (default) | Creates a new run if no parent run exists; otherwise, reuses the parent's run ID. |
| new | Always creates a new, unique W&B run. |
| shared | Always attempts to share the parent's run ID. |

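The run_mode table can be read as a small decision function. The following is a hypothetical sketch, and resolve_run_id is not part of the plugin. In particular, raising an error for shared mode when there is no parent is an assumption, because the table only says that shared mode attempts to share the parent's ID.

```python
import uuid
from typing import Optional


def resolve_run_id(run_mode: str, parent_run_id: Optional[str], new_id: Optional[str] = None) -> str:
    """Choose the W&B run ID for a task, following the run_mode table (hypothetical sketch)."""
    fresh = new_id or uuid.uuid4().hex
    if run_mode == "new":
        return fresh  # always a new, unique run
    if run_mode == "auto":
        return parent_run_id or fresh  # reuse the parent's run when there is one
    if run_mode == "shared":
        if parent_run_id is None:
            # Assumption: treat "nothing to share" as an error in this sketch.
            raise LookupError("run_mode='shared' but no parent run ID is available")
        return parent_run_id
    raise ValueError(f"unknown run_mode: {run_mode!r}")


if __name__ == "__main__":
    parent = resolve_run_id("new", None)  # a parent task with run_mode="new"
    child = resolve_run_id("auto", parent)  # a child task with the default "auto"
    print(child == parent)  # True
```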
Example: Parent/Child Run Sharing

```python
import flyte
from flyteplugins.wandb import get_wandb_run, wandb_init

env = flyte.TaskEnvironment(name="wandb-example")

@wandb_init(run_mode="new")
@env.task
async def parent_task():
    run = get_wandb_run()
    run.log({"parent_val": 1})

    # Child reuses parent's run by default (run_mode="auto")
    await child_task()

@wandb_init
@env.task
async def child_task():
    run = get_wandb_run()
    run.log({"child_val": 2})  # Logs to the same run as the parent
```

Distributed Training

The W&B plugin integrates with Flyte's Elastic plugin to handle distributed training (e.g., PyTorch DDP). It auto-detects distributed environment variables like RANK and WORLD_SIZE.

The rank_scope parameter controls which processes initialize W&B runs:

  • global (default): Only the global rank 0 process initializes a run. Other ranks receive None from get_wandb_run().
  • worker: The local rank 0 process on each node/worker initializes a run (resulting in one run per node).
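
```python
import flyte
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import get_wandb_run, wandb_init

env = flyte.TaskEnvironment(
    name="distributed-training",
    plugin_config=Elastic(nnodes=2, nproc_per_node=2),
)

@wandb_init(rank_scope="global")
@env.task
async def train_distributed():
    run = get_wandb_run()
    if run:
        # Only global rank 0 has a run; other ranks receive None
        run.log({"global_loss": 0.1})
```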

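The following is a rough sketch of the rank_scope decision. It assumes the process reads the RANK and LOCAL_RANK variables that torchrun sets for each worker. The helper name should_init_run is hypothetical, and treating unset variables as rank 0 (a non-distributed run) is also an assumption.

```python
import os
from typing import Mapping, Optional


def should_init_run(rank_scope: str = "global", env: Optional[Mapping[str, str]] = None) -> bool:
    """Decide whether this process should create a W&B run (hypothetical sketch)."""
    env = os.environ if env is None else env
    # Outside a distributed launch these are unset; treat the process as rank 0.
    rank = int(env.get("RANK", "0"))
    local_rank = int(env.get("LOCAL_RANK", "0"))
    if rank_scope == "global":
        return rank == 0  # one run for the whole job
    if rank_scope == "worker":
        return local_rank == 0  # one run per node
    raise ValueError(f"unknown rank_scope: {rank_scope!r}")


if __name__ == "__main__":
    # Two nodes with two processes each, as in Elastic(nnodes=2, nproc_per_node=2)
    procs = [{"RANK": str(r), "LOCAL_RANK": str(r % 2)} for r in range(4)]
    print([should_init_run("global", p) for p in procs])  # [True, False, False, False]
    print([should_init_run("worker", p) for p in procs])  # [True, False, True, False]
```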
When all ranks must log to the same run, use run_mode="shared". The plugin then enables W&B's shared mode and designates a primary rank to manage the run state.

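For reference, W&B's distributed-training guide enables shared mode per process through wandb.Settings. It marks one process as primary and labels each process. The sketch below builds those keyword arguments for each rank. It illustrates that assumption and is not the plugin's code. The x_primary and x_label setting names come from W&B's guide and should be checked against your installed wandb version. shared_mode_settings is a hypothetical helper.

```python
from typing import Any, Dict


def shared_mode_settings(rank: int, primary_rank: int = 0) -> Dict[str, Any]:
    """Per-rank keyword arguments for wandb.Settings in shared mode (illustrative sketch)."""
    return {
        "mode": "shared",
        "x_primary": rank == primary_rank,  # the primary process manages the run state
        "x_label": f"rank_{rank}",  # tells each process's logs and system metrics apart
    }


# In a real worker (requires wandb and a run ID that every rank agrees on):
#   import wandb
#   run = wandb.init(id=shared_run_id, settings=wandb.Settings(**shared_mode_settings(rank)))

if __name__ == "__main__":
    for rank in range(3):
        print(rank, shared_mode_settings(rank)["x_primary"])
```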
Hyperparameter Sweeps

Use @wandb_sweep to manage W&B sweeps. The decorator handles sweep creation and exposes the sweep ID to the task through get_wandb_sweep_id().

```python
import flyte
from flyteplugins.wandb import get_wandb_sweep_id, wandb_sweep, wandb_sweep_config

env = flyte.TaskEnvironment(name="wandb-sweep")

@wandb_sweep
@env.task
async def run_sweep():
    sweep_id = get_wandb_sweep_id()
    # Launch agents using wandb.agent(sweep_id, ...)
    ...

if __name__ == "__main__":
    flyte.init_from_config()
    # Configure the sweep for this run
    flyte.with_runcontext(
        custom_context=wandb_sweep_config(
            method="random",
            metric={"name": "loss", "goal": "minimize"},
            parameters={"lr": {"min": 0.001, "max": 0.1}},
        ),
    ).run(run_sweep)
```

The wandb_sweep decorator (in plugins/wandb/src/flyteplugins/wandb/_decorator.py) deterministically generates a sweep name and registers the sweep with W&B if it does not already exist in the context.

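One way to get a deterministic name that is registered only once is to hash the inputs and cache the resulting sweep ID, as in this hypothetical sketch. The plugin may hash different inputs or store the ID elsewhere. The create callback stands in for a call such as wandb.sweep(config, project=...), which returns a sweep ID. sweep_name and get_or_create_sweep are illustrative names.

```python
import hashlib
import json
from typing import Callable, Dict


def sweep_name(action_name: str, sweep_config: dict) -> str:
    """Derive a stable sweep name from the action name and a canonicalized config."""
    payload = json.dumps(sweep_config, sort_keys=True)
    digest = hashlib.sha256(f"{action_name}:{payload}".encode()).hexdigest()[:8]
    return f"{action_name}-sweep-{digest}"


def get_or_create_sweep(
    context: Dict[str, str],
    action_name: str,
    sweep_config: dict,
    create: Callable[[dict], str],
) -> str:
    """Return the sweep ID cached in context, creating the sweep only on first use."""
    name = sweep_name(action_name, sweep_config)
    if name not in context:
        # e.g. create = lambda cfg: wandb.sweep(cfg, project="my-project")
        context[name] = create({**sweep_config, "name": name})
    return context[name]


if __name__ == "__main__":
    calls = []

    def fake_create(config: dict) -> str:
        calls.append(config)
        return "sweep-abc123"

    cfg = {"method": "random", "metric": {"name": "loss", "goal": "minimize"}}
    ctx: Dict[str, str] = {}
    first = get_or_create_sweep(ctx, "action-1", cfg, fake_create)
    second = get_or_create_sweep(ctx, "action-1", cfg, fake_create)
    print(first, second, len(calls))  # sweep-abc123 sweep-abc123 1
```

Because json.dumps(..., sort_keys=True) canonicalizes the config, including nested dictionaries, reordering keys does not change the name. Calling the function again with the same inputs therefore returns the cached ID and does not create a duplicate sweep.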
Data Retrieval and UI Integration

Automatic Log Downloading

You can configure flyte-sdk to automatically download W&B run logs and metrics after a task completes by setting download_logs=True. These logs appear as a Dir output in the Flyte UI trace.

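```python
import flyte
from flyteplugins.wandb import wandb_init
env = flyte.TaskEnvironment(name="wandb-example")

@wandb_init(download_logs=True)
@env.task
async def train_and_export():
    ...
```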

Manual Data Retrieval

To programmatically access data from a completed run (e.g., in a downstream evaluation task), use download_wandb_run_dir. This utility downloads all synced files, summary.json, and optionally metrics_history.json from the W&B cloud.

```python
import flyte
from flyteplugins.wandb import download_wandb_run_dir

env = flyte.TaskEnvironment(name="wandb-eval")

@env.task
async def evaluate(run_id: str):
    local_path = download_wandb_run_dir(run_id=run_id)
    # Access files in local_path/summary.json, etc.
```

The implementation in plugins/wandb/src/flyteplugins/wandb/__init__.py uses wandb.Api() to fetch run data from the W&B backend, so a run's results remain accessible even if the task that produced them ran on a different node.

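The sketch below shows the kind of wandb.Api() calls such a utility can make: api.run("entity/project/run_id"), run.files() with File.download(), run.summary, and run.scan_history(). It is a hypothetical stand-in for download_wandb_run_dir, not its source. The API object is passed in so that the function can be tested with fakes. The layout of metrics_history.json (a list with one JSON object per logged step) is an assumption.

```python
import json
import os
from typing import Any


def sketch_download_run_dir(api: Any, run_path: str, root: str, include_history: bool = False) -> str:
    """Hypothetical stand-in for a run-download helper built on a wandb.Api()-like object."""
    run = api.run(run_path)  # run_path looks like "entity/project/run_id"
    os.makedirs(root, exist_ok=True)

    # Synced files (for example output.log or config.yaml).
    for remote_file in run.files():
        remote_file.download(root=root, replace=True)

    # wandb exposes run.summary's data as _json_dict; fall back to a plain mapping.
    summary = getattr(run.summary, "_json_dict", run.summary)
    with open(os.path.join(root, "summary.json"), "w") as fh:
        json.dump(dict(summary), fh, indent=2, default=str)

    if include_history:
        # scan_history() yields one dict per logged step, paging through the full history.
        history = list(run.scan_history())
        with open(os.path.join(root, "metrics_history.json"), "w") as fh:
            json.dump(history, fh, default=str)
    return root


# Real usage (requires wandb, network access, and WANDB_API_KEY):
#   import wandb
#   sketch_download_run_dir(wandb.Api(), "my-team/my-project/<run_id>", "/tmp/wandb-run")
```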
Configuration Requirements

To authenticate with Weights & Biases, ensure the WANDB_API_KEY environment variable is set in your Flyte execution environment. This is required for both logging metrics and downloading run data via the API.
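A small fail-fast check can surface a missing key before any training starts. The helper below is hypothetical and only inspects the environment. W&B itself reads WANDB_API_KEY when it authenticates.

```python
import os
from typing import Mapping, Optional


def require_wandb_api_key(env: Optional[Mapping[str, str]] = None) -> str:
    """Return WANDB_API_KEY, or raise a clear error if it is missing or blank."""
    env = os.environ if env is None else env
    key = env.get("WANDB_API_KEY", "").strip()
    if not key:
        raise RuntimeError(
            "WANDB_API_KEY is not set in this execution environment; "
            "expose it to the task (for example, as a secret mapped to an env var)."
        )
    return key
```

Calling it at the top of a task turns an authentication failure deep inside wandb.init() into an immediate, explicit error.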