
Hyperparameter Sweeps with Weights & Biases

Hyperparameter sweeps in flyte-sdk are managed through the Weights & Biases (W&B) integration, which automates sweep creation, agent coordination, and result tracking. This integration allows you to define sweep configurations in your Flyte workflow and launch parallel agents as distributed Flyte tasks.

Running a Parallel Hyperparameter Sweep

To run a sweep, define the sweep configuration with wandb_sweep_config, apply the @wandb_sweep decorator to the task that creates the sweep, and apply @wandb_init to your objective function.

import asyncio
import wandb
import flyte
from flyteplugins.wandb import (
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
    wandb_config,
    get_wandb_sweep_id,
    get_wandb_context,
)

# 1. Define the objective function
@wandb_init
def objective():
    run = wandb.run
    # Access hyperparameters from run.config
    lr = run.config.learning_rate
    batch_size = run.config.batch_size

    # Training logic...
    loss = 1.0 / (lr * batch_size)
    run.log({"loss": loss})

# 2. Define the agent task
@wandb_sweep
@flyte.task
async def sweep_agent(sweep_id: str, count: int = 5):
    # Use the standard wandb.agent to run trials
    wandb.agent(
        sweep_id,
        function=objective,
        count=count,
        project=get_wandb_context().project,
    )

# 3. Define the controller task to launch parallel agents
@wandb_sweep
@flyte.task
async def run_parallel_sweep(num_agents: int = 3) -> str:
    # Retrieve the sweep_id created by the @wandb_sweep decorator
    sweep_id = get_wandb_sweep_id()

    # Launch multiple agents in parallel using asyncio.gather
    agent_tasks = [
        sweep_agent(sweep_id=sweep_id, count=5)
        for _ in range(num_agents)
    ]
    await asyncio.gather(*agent_tasks)
    return sweep_id

# 4. Execute with configuration
if __name__ == "__main__":
    run = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64]},
                },
            ),
        }
    ).run(run_parallel_sweep)
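The fan-out in step 3 can be tried in isolation. The following is a minimal sketch using only the standard library: fake_agent is a hypothetical stand-in for sweep_agent and makes no W&B or Flyte calls. It only illustrates how asyncio.gather runs the agents concurrently and why the sweep executes up to num_agents × count trials in total.

```python
import asyncio


async def fake_agent(agent_id: int, count: int) -> int:
    """Hypothetical stand-in for sweep_agent: pretends to run `count` trials."""
    for _ in range(count):
        # Yield control to the event loop, as a real awaitable task would.
        await asyncio.sleep(0)
    return count


async def fan_out(num_agents: int = 3, count: int = 5) -> int:
    # Same shape as run_parallel_sweep: build the coroutines, then await them together.
    results = await asyncio.gather(
        *(fake_agent(agent_id, count) for agent_id in range(num_agents))
    )
    # Each agent runs at most `count` trials, so this is the total trial budget.
    return sum(results)


total_trials = asyncio.run(fan_out(num_agents=3, count=5))
print(total_trials)
```

When sizing a sweep, it may help to pick num_agents and count together so that their product matches the number of trials you actually want.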

Key Components

  • wandb_sweep_config: Defines the sweep parameters (method, metric, parameters). This must be provided in the Flyte run context for @wandb_sweep to function.
  • @wandb_sweep: When applied to a Flyte task, this decorator creates a new W&B sweep using the configuration from the context. It injects the sweep_id into the Flyte context, making it available via get_wandb_sweep_id().
  • @wandb_init: Applied to the objective function or task, it initializes a W&B run. In a sweep context, it automatically links the run to the active sweep.
  • get_wandb_sweep_id(): Retrieves the ID of the sweep created by @wandb_sweep. Call it inside the decorated task, as run_parallel_sweep does, to obtain the sweep_id that is passed to the agent tasks and on to wandb.agent.
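To get a feel for what the example's sweep configuration describes, here is a rough local sketch of random search over the same parameter space. It does not call W&B: sample_trial is a hypothetical helper, and it assumes a {"min", "max"} entry is drawn uniformly and a {"values"} entry is a categorical choice, which is how such entries are commonly interpreted. The actual sampling is performed by the W&B sweep controller.

```python
import random

# Same parameter space as the wandb_sweep_config call in the example.
parameters = {
    "learning_rate": {"min": 0.0001, "max": 0.1},
    "batch_size": {"values": [16, 32, 64]},
}


def sample_trial(space: dict, rng: random.Random) -> dict:
    """Hypothetical helper: draw one hyperparameter configuration from `space`."""
    trial = {}
    for name, spec in space.items():
        if "values" in spec:
            # Categorical parameter: pick one of the listed values.
            trial[name] = rng.choice(spec["values"])
        else:
            # Continuous parameter: assumed uniform between min and max.
            trial[name] = rng.uniform(spec["min"], spec["max"])
    return trial


rng = random.Random(0)  # fixed seed so repeated runs draw the same trials
trials = [sample_trial(parameters, rng) for _ in range(5)]

for trial in trials:
    # Toy objective from the example's objective() function.
    loss = 1.0 / (trial["learning_rate"] * trial["batch_size"])
    print(trial, round(loss, 3))
```

Because the toy loss decreases as either hyperparameter grows, a sweep with goal "minimize" should favor trials near the upper end of both ranges.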

Downloading Sweep Results

After a sweep completes, you can download all of its run data (metrics, logs, and files) either automatically, by setting download_logs=True on @wandb_sweep, or manually, by calling the download_wandb_sweep_logs helper.

import flyte
from flyteplugins.wandb import (
    download_wandb_sweep_logs,
    get_wandb_sweep_id,
    wandb_sweep,
)

# Download logs automatically when the sweep task finishes
@wandb_sweep(download_logs=True)
@flyte.task
async def sweep_with_logs():
    sweep_id = get_wandb_sweep_id()
    # ... run agents ...
    return sweep_id

# Or manually in a subsequent task
@flyte.task
async def process_results(sweep_id: str):
    # Downloads all runs in the sweep to a Flyte Dir
    logs_dir = await download_wandb_sweep_logs(sweep_id)
    print(f"Logs downloaded to: {logs_dir.remote_source}")

Troubleshooting and Best Practices

  • Decorator Order: The @wandb_sweep and @wandb_init decorators must be the outermost decorators on a Flyte task, that is, placed above @flyte.task as in the examples on this page.
  • Async Support: Flyte tasks can be async, but the W&B integration does not support async task functions for Elastic (distributed) training. Use synchronous functions for distributed training tasks.
  • Missing Configuration: If @wandb_sweep is used without a wandb_sweep_config in the context, it will raise a RuntimeError. Ensure you use flyte.with_runcontext to provide the necessary sweep definition.
  • Distributed Logs: The download_logs=True parameter is not supported for distributed tasks (tasks using Elastic plugin configuration). For these tasks, logs must be managed or downloaded manually.
  • Authentication: Ensure the WANDB_API_KEY is available in the task environment. You can provide this via flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY") in your TaskEnvironment.
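For the authentication point above, a small fail-fast check at the start of a task can turn a confusing login failure into a clear error. This is a sketch only: require_wandb_api_key is a hypothetical helper, not part of flyteplugins.wandb, and it assumes the secret has been exposed as the WANDB_API_KEY environment variable.

```python
import os
from typing import Mapping, Optional


def require_wandb_api_key(env: Optional[Mapping[str, str]] = None) -> str:
    """Hypothetical helper: return the W&B API key or raise a descriptive error."""
    env = os.environ if env is None else env
    key = env.get("WANDB_API_KEY", "").strip()
    if not key:
        raise RuntimeError(
            "WANDB_API_KEY is not set; check that the secret is attached to the "
            "task environment and exposed as an environment variable."
        )
    return key


# Demonstration with explicit mappings instead of the real process environment.
print(require_wandb_api_key({"WANDB_API_KEY": "example-key"}))

try:
    require_wandb_api_key({})
except RuntimeError as exc:
    print(f"Preflight check failed: {exc}")
```

Calling the helper with no argument inside a task would check the real process environment before any agent is started.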