
Ecosystem & Tooling Integrations

flyte-sdk provides first-class integrations with popular machine learning and configuration tools, allowing you to track experiments, execute notebooks as tasks, and manage complex configurations without leaving the Flyte ecosystem.

MLflow Integration

You can automatically track experiments and log metrics, parameters, and models to MLflow by using the @mlflow_run decorator. This integration handles run initialization and provides direct links to the MLflow UI from the Flyte console.

Automatic Logging

Use autolog=True and specify your framework (e.g., "sklearn", "pytorch", "tensorflow") to capture training details without manual logging calls.

from flyteplugins.mlflow import Mlflow, mlflow_run
import flyte

env = flyte.TaskEnvironment(name="mlflow-env")

@mlflow_run(autolog=True, framework="sklearn")
@env.task(links=(Mlflow(),))
async def train_model(n_samples: int = 100) -> None:
    from sklearn.linear_model import LogisticRegression
    import numpy as np

    # Generate data and train
    X = np.random.randn(n_samples, 4)
    y = (X[:, 0] + X[:, 1] > 0).astype(int)

    model = LogisticRegression()
    # MLflow autolog captures parameters and metrics here
    model.fit(X, y)

Manual Logging and Run Sharing

If you need to log custom metrics or share a single MLflow run across multiple tasks, use get_mlflow_run() to access the active run object.

from flyteplugins.mlflow import get_mlflow_run, mlflow_run

@mlflow_run
@env.task
async def log_custom_metric(value: float):
    import mlflow

    run = get_mlflow_run()
    mlflow.log_metric("custom_accuracy", value)

Note: The @mlflow_run decorator must be placed above the @env.task decorator to ensure the MLflow context is initialized before the task starts.
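
The ordering rule follows from how Python stacks decorators: the one written on top wraps everything beneath it, so its setup code runs first at call time. The sketch below illustrates this with two hypothetical stand-ins (fake_tracking_run and fake_task are invented for this example and are not part of flyteplugins); it shows the general Python mechanism, not how the plugin is implemented.

```python
import functools

events = []  # records the order in which the wrappers run


def fake_task(fn):
    """Hypothetical stand-in for @env.task: marks where the task starts."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        events.append("task start")
        return fn(*args, **kwargs)
    return wrapper


def fake_tracking_run(fn):
    """Hypothetical stand-in for @mlflow_run: opens a context around the call."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        events.append("run opened")
        try:
            return fn(*args, **kwargs)
        finally:
            events.append("run closed")
    return wrapper


@fake_tracking_run  # outermost: applied last, so it runs first
@fake_task
def train(value):
    events.append("body")
    return value * 2


result = train(21)
print(events)
```

Swapping the two decorators would put "task start" ahead of "run opened", which is the situation the note warns about.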


Weights & Biases Integration

The Weights & Biases (W&B) integration lets you initialize runs, perform hyperparameter sweeps, and log from distributed training jobs (such as PyTorch DDP).

Initializing Runs

Use the @wandb_init decorator to start a W&B run. You can retrieve the run object using get_wandb_run().

from flyteplugins.wandb import wandb_init, get_wandb_run
import flyte

@wandb_init(project="my-ml-project", entity="my-team")
@env.task
async def train_with_wandb(lr: float) -> str:
    wandb_run = get_wandb_run()
    wandb_run.log({"learning_rate": lr, "loss": 0.05})
    return wandb_run.id

Hyperparameter Sweeps

flyte-sdk supports parallel hyperparameter optimization using @wandb_sweep. This allows you to launch multiple "agents" as separate Flyte tasks that pull trials from the W&B cloud controller.

import asyncio

import wandb
from flyteplugins.wandb import wandb_sweep, get_wandb_sweep_id

@wandb_sweep
@env.task
async def sweep_agent(sweep_id: str, count: int = 5):
    # This agent runs 'count' trials from the specified sweep.
    # my_objective_fn is your own training function (not defined here).
    wandb.agent(sweep_id, function=my_objective_fn, count=count)

@wandb_sweep
@env.task
async def run_parallel_sweep():
    sweep_id = get_wandb_sweep_id()
    # Launch 3 agents in parallel
    await asyncio.gather(*[sweep_agent(sweep_id=sweep_id) for _ in range(3)])

Decorator Order: Like MLflow, @wandb_init and @wandb_sweep must be the outermost decorators (placed above @env.task).
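
The fan-out in run_parallel_sweep relies on asyncio.gather, which schedules all agent coroutines concurrently and waits for every one to finish. A minimal sketch of that pattern in plain asyncio follows; fake_agent is a hypothetical stand-in that only reports how many trials it would run, and no W&B or Flyte calls are made.

```python
import asyncio


async def fake_agent(agent_id: int, count: int = 5) -> int:
    """Hypothetical stand-in for a sweep agent task; returns its trial count."""
    await asyncio.sleep(0)  # yield control, as a real remote call would
    return count


async def main() -> int:
    # Launch 3 agents concurrently and wait for all of them to finish
    per_agent = await asyncio.gather(*[fake_agent(i) for i in range(3)])
    return sum(per_agent)


total_trials = asyncio.run(main())
print(total_trials)
```

With three agents and the default count of 5, the sweep runs up to 15 trials in total, so size the agent count and count together.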


Jupyter Notebooks (Papermill)

The NotebookTask allows you to execute Jupyter notebooks as Flyte tasks. This is useful for data exploration or reporting where the notebook itself is the primary artifact.

Defining a Notebook Task

You define a NotebookTask by specifying the path to the .ipynb file and mapping its inputs and outputs.

from flyteplugins.papermill import NotebookTask
import flyte

add_numbers = NotebookTask(
    name="notebook_math",
    notebook_path="notebooks/math_operations.ipynb",
    task_environment=env,
    inputs={"x": int, "y": float},
    outputs={"result": float},
)

Recording Outputs in the Notebook

Inside your Jupyter notebook, use the record_outputs helper to return values back to Flyte.

# Inside notebooks/math_operations.ipynb
from flyteplugins.papermill import record_outputs

# 'x' and 'y' are injected by Papermill from the task inputs
result = x + y
record_outputs(result=result)

Gotcha: NotebookTask inputs must be JSON-serializable primitives or one of the specifically supported types, such as flyte.File, flyte.Dir, or pandas.DataFrame.
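
To catch unsupported primitive inputs before submitting a run, a quick check with the standard json module can help. This is a minimal sketch under stated assumptions: is_json_serializable is a hypothetical helper, not part of flyteplugins, and it covers only the JSON-serializable case, not flyte.File, flyte.Dir, or pandas.DataFrame, which the plugin handles separately.

```python
import json


def is_json_serializable(value) -> bool:
    """Return True if the standard json encoder accepts the value."""
    try:
        json.dumps(value)
        return True
    except (TypeError, ValueError):
        return False


# Example inputs; the set under "bad" cannot be encoded as JSON
candidates = {"x": 3, "y": 2.5, "tags": ["a", "b"], "bad": {1, 2, 3}}
rejected = [name for name, value in candidates.items()
            if not is_json_serializable(value)]
print(rejected)
```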


Hydra Integration

flyte-sdk integrates with Hydra to manage complex configurations and perform multi-run sweeps (e.g., grid search or Optuna optimization).

Running Tasks with Hydra Configs

Use hydra_run to execute a Flyte task with a configuration composed by Hydra.

from flyteplugins.hydra import hydra_run

# Assuming 'train_task' is a Flyte task and 'conf/config.yaml' exists
run = hydra_run(
    target=train_task,
    config_path="conf",
    config_name="config",
    overrides=["model.lr=0.01", "batch_size=32"],
)

Performing Sweeps

Use hydra_sweep to launch multiple runs based on Hydra's multirun syntax.

from flyteplugins.hydra import hydra_sweep

# Launches a grid search over learning rates
runs = hydra_sweep(
    target=train_task,
    config_path="conf",
    config_name="config",
    overrides=["model.lr=0.001,0.01,0.1"],
)
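
In Hydra's multirun syntax, a comma-separated override value is swept, and several swept keys combine as a Cartesian product. The sketch below shows roughly how such a grid expands into per-run override lists; expand_overrides is a hypothetical helper written for illustration, and it ignores the rest of Hydra's override grammar (quoting, range(), choice(), and so on).

```python
from itertools import product


def expand_overrides(overrides):
    """Expand comma-separated override values into one override list per run."""
    keys, choices = [], []
    for item in overrides:
        key, _, raw = item.partition("=")
        keys.append(key)
        choices.append(raw.split(","))
    return [
        [f"{key}={value}" for key, value in zip(keys, combo)]
        for combo in product(*choices)
    ]


jobs = expand_overrides(["model.lr=0.001,0.01,0.1", "batch_size=32"])
for job in jobs:
    print(job)
```

Three learning rates and a single batch size give three runs; adding a second swept key with two values would give six.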

Hydra Launcher

You can also use Flyte as a Hydra launcher directly from the command line if your script uses @hydra.main.

python train.py hydra/launcher=flyte hydra.launcher.mode=remote

When using the launcher, child task overrides (like memory or CPU) should be applied in your code using apply_task_env to ensure the Hydra-composed environment propagates correctly.