Hyperparameter Sweeps with Weights & Biases
Hyperparameter sweeps in flyte-sdk are managed through the Weights & Biases (W&B) integration, which automates sweep creation, agent coordination, and result tracking. This integration allows you to define sweep configurations in your Flyte workflow and launch parallel agents as distributed Flyte tasks.
Running a Parallel Hyperparameter Sweep
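To run a sweep, pass a sweep configuration built with `wandb_sweep_config` in the Flyte run context, decorate the task that creates the sweep with `@wandb_sweep`, and decorate your objective function with `@wandb_init`. The example below creates one sweep and launches several agent tasks in parallel, each running trials from that sweep.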
```python
import asyncio

import wandb
import flyte
from flyteplugins.wandb import (
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
    wandb_config,
    get_wandb_sweep_id,
    get_wandb_context,
)


# 1. Define the objective function
@wandb_init
def objective():
    run = wandb.run
    # Access hyperparameters from run.config
    lr = run.config.learning_rate
    batch_size = run.config.batch_size
    # Training logic...
    loss = 1.0 / (lr * batch_size)
    run.log({"loss": loss})


# 2. Define the agent task
@wandb_sweep
@flyte.task
async def sweep_agent(sweep_id: str, count: int = 5):
    # Use the standard wandb.agent to run `count` trials from the shared sweep
    wandb.agent(
        sweep_id,
        function=objective,
        count=count,
        project=get_wandb_context().project,
    )


# 3. Define the controller task to launch parallel agents
@wandb_sweep
@flyte.task
async def run_parallel_sweep(num_agents: int = 3) -> str:
    # Retrieve the sweep_id created by the @wandb_sweep decorator
    sweep_id = get_wandb_sweep_id()
    # Launch multiple agents in parallel using asyncio.gather
    agent_tasks = [
        sweep_agent(sweep_id=sweep_id, count=5)
        for _ in range(num_agents)
    ]
    await asyncio.gather(*agent_tasks)
    return sweep_id


# 4. Execute with configuration
if __name__ == "__main__":
    run = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64]},
                },
            ),
        }
    ).run(run_parallel_sweep)
```
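To make the control flow concrete without a W&B account or a Flyte cluster, here is a local, pure-Python simulation of the same pattern. It is a sketch, not how W&B or the plugin work internally: the hypothetical `sample_params` assumes W&B's documented `random` semantics for this config (a float `min`/`max` pair is sampled uniformly, a `values` list is sampled uniformly at random), and the hypothetical `agent` coroutine stands in for one `sweep_agent` task calling `wandb.agent`. With three agents and `count=5`, the real sweep runs up to 15 trials; the simulation runs exactly 15.

```python
import asyncio
import random

# Same shape as the config passed to wandb_sweep_config in the example above.
SWEEP_CONFIG = {
    "method": "random",
    "metric": {"name": "loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"min": 0.0001, "max": 0.1},
        "batch_size": {"values": [16, 32, 64]},
    },
}


def sample_params(parameters: dict, rng: random.Random) -> dict:
    """Draw one trial's hyperparameters (hypothetical stand-in for the sweep controller)."""
    params = {}
    for name, spec in parameters.items():
        if "values" in spec:
            params[name] = rng.choice(spec["values"])
        else:
            # Assumes float bounds: sample uniformly between min and max.
            params[name] = rng.uniform(spec["min"], spec["max"])
    return params


def objective(params: dict) -> float:
    # Same toy loss as the example's objective function.
    return 1.0 / (params["learning_rate"] * params["batch_size"])


async def agent(agent_id: int, count: int, rng: random.Random, results: list) -> None:
    """Hypothetical stand-in for one sweep_agent task running `count` trials."""
    for _ in range(count):
        params = sample_params(SWEEP_CONFIG["parameters"], rng)
        results.append({"agent": agent_id, **params, "loss": objective(params)})
        await asyncio.sleep(0)  # yield so the agents interleave


async def run_parallel_sweep(num_agents: int = 3, count: int = 5, seed: int = 0) -> list:
    rng = random.Random(seed)
    results: list = []
    # Fan out the agents concurrently, as the controller task does with asyncio.gather.
    await asyncio.gather(*(agent(i, count, rng, results) for i in range(num_agents)))
    return results


if __name__ == "__main__":
    trials = asyncio.run(run_parallel_sweep())
    best = min(trials, key=lambda t: t["loss"])
    print(f"{len(trials)} trials; best: {best}")
```

Because every agent draws from the same sampler, the agents share one search rather than running independent sweeps, which is why the real example passes a single `sweep_id` to all of them.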
Key Components
- `wandb_sweep_config`: Defines the sweep parameters (`method`, `metric`, `parameters`). It must be provided in the Flyte run context for `@wandb_sweep` to work.
- `wandb_config`: Sets the W&B `project` and `entity` in the run context; in the example, `get_wandb_context().project` reads the project back.
- `@wandb_sweep`: When applied to a Flyte task, this decorator creates a new W&B sweep from the configuration in the context and injects the `sweep_id` into the Flyte context, where `get_wandb_sweep_id()` can read it.
- `@wandb_init`: Applied to an objective function or task, it initializes a W&B run. In a sweep context, it automatically links the run to the active sweep.
- `get_wandb_sweep_id()`: Retrieves the ID of the sweep created by the enclosing `@wandb_sweep` task. You need it to pass the `sweep_id` to `wandb.agent`.
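The split between `@wandb_sweep` and `get_wandb_sweep_id()` follows a common pattern: a decorator reads configuration from an ambient context, performs setup, and publishes a value that code running inside it can read back. The toy below models that pattern with Python's `contextvars` within a single process; it is not the plugin's implementation, and `toy_sweep`, `toy_get_sweep_id`, `RUN_CONTEXT`, and the `sweep-…` ID format are all hypothetical. No W&B API is called. Like `@wandb_sweep`, it raises a `RuntimeError` when no sweep configuration is present.

```python
import contextvars
import functools
import uuid

# Hypothetical ambient context, standing in for the Flyte run context.
RUN_CONTEXT = contextvars.ContextVar("run_context", default={})
# Where the toy decorator publishes the sweep ID for code running inside it.
_SWEEP_ID = contextvars.ContextVar("sweep_id", default=None)


def toy_sweep(func):
    """Hypothetical analogue of @wandb_sweep: create a 'sweep' and expose its ID."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        config = RUN_CONTEXT.get().get("sweep_config")
        if config is None:
            raise RuntimeError("no sweep configuration in the run context")
        # Stand-in for creating a real sweep from `config`; the ID format is made up.
        token = _SWEEP_ID.set(f"sweep-{uuid.uuid4().hex[:8]}")
        try:
            return func(*args, **kwargs)
        finally:
            _SWEEP_ID.reset(token)  # restore the previous value on exit

    return wrapper


def toy_get_sweep_id() -> str:
    """Hypothetical analogue of get_wandb_sweep_id()."""
    sweep_id = _SWEEP_ID.get()
    if sweep_id is None:
        raise RuntimeError("called outside a @toy_sweep function")
    return sweep_id


@toy_sweep
def controller() -> str:
    # Mirrors run_parallel_sweep: read back the ID the decorator just published.
    return toy_get_sweep_id()


if __name__ == "__main__":
    RUN_CONTEXT.set({"sweep_config": {"method": "random"}})
    print(controller())  # prints a generated ID starting with "sweep-"
```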
Downloading Sweep Results
To download all run data (metrics, logs, and files) for a sweep once it completes, set `download_logs=True` on `@wandb_sweep`, or call the `download_wandb_sweep_logs` helper from a subsequent task.
```python
import flyte
from flyteplugins.wandb import (
    download_wandb_sweep_logs,
    get_wandb_sweep_id,
    wandb_sweep,
)


# Option 1: download the logs automatically when the sweep task completes
@wandb_sweep(download_logs=True)
@flyte.task
async def sweep_with_logs() -> str:
    sweep_id = get_wandb_sweep_id()
    # ... run agents ...
    return sweep_id


# Option 2: download the logs manually in a subsequent task
@flyte.task
async def process_results(sweep_id: str):
    # Downloads all runs in the sweep to a Flyte Dir
    logs_dir = await download_wandb_sweep_logs(sweep_id)
    print(f"Logs downloaded to: {logs_dir.remote_source}")
```
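Once the results are available, a typical next step is to pick the best trial using the same `metric` spec given to `wandb_sweep_config`. The sketch below assumes each run's final metrics have already been loaded into a plain dict; how to read them from the downloaded directory depends on its file layout, which this page does not specify. `best_run` is a hypothetical helper, not part of `flyteplugins.wandb`.

```python
from typing import Optional


def best_run(runs: list, metric: dict) -> Optional[dict]:
    """Return the run summary that best satisfies a W&B-style metric spec.

    `runs` is a list of per-run summary dicts (hypothetical format). Runs that
    never logged the metric are skipped; returns None if no run has it.
    """
    name, goal = metric["name"], metric["goal"]
    if goal not in ("minimize", "maximize"):
        raise ValueError(f"unknown goal: {goal!r}")
    scored = [r for r in runs if isinstance(r.get(name), (int, float))]
    if not scored:
        return None
    pick = min if goal == "minimize" else max
    return pick(scored, key=lambda r: r[name])


if __name__ == "__main__":
    runs = [
        {"id": "a", "loss": 0.8},
        {"id": "b", "loss": 0.3},
        {"id": "c"},  # e.g. a trial that crashed before logging loss
    ]
    print(best_run(runs, {"name": "loss", "goal": "minimize"})["id"])  # → b
```

Skipping runs without the metric matters in practice, because failed or interrupted trials in a sweep may never log it.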
Troubleshooting and Best Practices
- Decorator Order: The `@wandb_sweep` and `@wandb_init` decorators must be the outermost decorators on a Flyte task, that is, placed above `@flyte.task`.
- Async Support: Although Flyte tasks can be `async`, the W&B integration does not support `async` task functions for distributed training with Flyte's `Elastic` plugin. Use synchronous functions for distributed training tasks.
- Missing Configuration: If `@wandb_sweep` is used without a `wandb_sweep_config` in the context, it raises a `RuntimeError`. Use `flyte.with_runcontext` to provide the sweep definition.
- Distributed Logs: The `download_logs=True` option is not supported for distributed tasks (tasks configured with the `Elastic` plugin). For these tasks, manage or download logs manually.
- Authentication: Make sure `WANDB_API_KEY` is available in the task environment, for example via `flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")` in your `TaskEnvironment`.
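A fail-fast check at the top of a task turns a missing secret into a clear error message instead of an authentication failure inside the first W&B call. This is a minimal sketch assuming the secret is exposed as the `WANDB_API_KEY` environment variable, as described above; `require_wandb_api_key` is a hypothetical helper, not part of the plugin.

```python
import os
from collections.abc import Mapping
from typing import Optional


def require_wandb_api_key(env: Optional[Mapping] = None) -> None:
    """Raise a clear error if WANDB_API_KEY is missing or blank.

    Hypothetical helper; `env` defaults to the process environment so the
    check can be exercised with a plain dict. The key itself is never returned
    or printed, to avoid leaking it into logs.
    """
    env = os.environ if env is None else env
    if not env.get("WANDB_API_KEY", "").strip():
        raise RuntimeError(
            "WANDB_API_KEY is not set; attach it to the TaskEnvironment, e.g. "
            'flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")'
        )


if __name__ == "__main__":
    require_wandb_api_key({"WANDB_API_KEY": "example-key"})  # passes silently
```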