MLflow Tracking and UI Links

The flyte-sdk MLflow plugin automates the lifecycle of MLflow runs within Flyte tasks. It starts and ends runs, manages experiments, and generates links that let you navigate from the Flyte UI to the corresponding run in the MLflow UI.

Managing Runs with @mlflow_run

When you want to track a Flyte task's execution in MLflow, apply the @mlflow_run decorator. This decorator manages the mlflow.start_run() and mlflow.end_run() calls for you, ensuring that the MLflow run state is correctly synchronized with the Flyte task's execution.
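To make the lifecycle concrete, here is a minimal sketch of a decorator that brackets an async function with start and end events. It is an illustration only, not the plugin's implementation: `demo_run` and the `events` list are hypothetical stand-ins, and the status strings merely mirror the `FINISHED` and `FAILED` values that `mlflow.end_run()` accepts.

```python
import asyncio
import functools

# Hypothetical stand-in for MLflow calls: records lifecycle events in order.
events = []

def demo_run(fn):
    """Illustrative decorator: open a run before the task body, close it after."""
    @functools.wraps(fn)
    async def wrapper(*args, **kwargs):
        events.append("start_run")
        try:
            result = await fn(*args, **kwargs)
        except Exception:
            # A failing task still closes its run, with a failure status.
            events.append("end_run(FAILED)")
            raise
        events.append("end_run(FINISHED)")
        return result
    return wrapper

@demo_run
async def train():
    events.append("task body")
    return 0.5

@demo_run
async def broken():
    raise RuntimeError("training failed")
```

Running `asyncio.run(train())` records the start event, the task body, and the end event in that order, which is the synchronization the decorator is meant to guarantee.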

Decorator Placement

The @mlflow_run decorator must be the outermost decorator on your task function. This ensures it can wrap the entire task execution, including any other decorators like @env.task.

from flyteplugins.mlflow import mlflow_run

@mlflow_run(experiment_name="/my-experiment")
@env.task
async def my_task():
    import mlflow
    mlflow.log_param("param1", 42)

Run Modes

The run_mode parameter in @mlflow_run (defined in plugins/mlflow/src/flyteplugins/mlflow/_context.py) determines how the task interacts with existing MLflow runs:

  • auto (Default): Reuses a parent task's MLflow run if one is active. If no run is active, it creates a new one.
  • new: Always creates a new, independent MLflow run, even if a parent run exists.
  • nested: Creates a new MLflow run that is marked as a child of the parent run using the mlflow.parentRunId tag. This is particularly useful for Hyperparameter Optimization (HPO) where each trial is a separate Flyte task.
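The three modes above can be summarized as a small decision function. This is a sketch for illustration, not the code in _context.py: `resolve_run_action` and its return shape are hypothetical, and the behavior of `nested` when no parent run exists is an assumption.

```python
from typing import Optional

def resolve_run_action(run_mode: str, parent_run_id: Optional[str]) -> dict:
    """Hypothetical helper: decide whether to reuse or create a run."""
    if run_mode == "auto":
        if parent_run_id is not None:
            # Reuse the parent task's run when one is active.
            return {"action": "reuse", "run_id": parent_run_id, "tags": {}}
        return {"action": "create", "run_id": None, "tags": {}}
    if run_mode == "new":
        # Always an independent run, regardless of any parent.
        return {"action": "create", "run_id": None, "tags": {}}
    if run_mode == "nested":
        # Child run linked to its parent through the mlflow.parentRunId tag.
        # Assumption: with no parent, fall back to an untagged new run.
        tags = {"mlflow.parentRunId": parent_run_id} if parent_run_id else {}
        return {"action": "create", "run_id": None, "tags": tags}
    raise ValueError(f"unknown run_mode: {run_mode!r}")
```

For example, `resolve_run_action("nested", "abc123")` yields a create action whose tags carry the parent run ID, while `resolve_run_action("auto", "abc123")` reuses the parent run.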

Example of nested runs for HPO from plugins/mlflow/examples/example_mlflow_hpo.py:

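from flyteplugins.mlflow import Mlflow, mlflow_run

@mlflow_run(run_mode="nested")
@env.task(links=(Mlflow(),))  # trailing comma makes this a one-element tuple
async def run_trial(trial_number: int, **params) -> float:
    import mlflow
    mlflow.log_params(params)
    # ... training logic that computes rmse ...
    return rmse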

Autologging

You can enable MLflow's autologging capabilities directly through the decorator. This automatically captures parameters, metrics, and models for supported frameworks like Scikit-Learn, PyTorch, and XGBoost.

Pass autolog=True and optionally specify the framework to use framework-specific autologging.

@mlflow_run(autolog=True, framework="sklearn")
@env.task
async def train_model():
    from sklearn.linear_model import LogisticRegression
    # X and y are the training features and labels, loaded elsewhere
    model = LogisticRegression()
    model.fit(X, y)  # Parameters and metrics are logged automatically

You can further control autologging behavior using log_models, log_datasets, and autolog_kwargs, which are passed directly to the underlying mlflow.autolog() call.

MLflow UI Links

The Mlflow link class (found in plugins/mlflow/src/flyteplugins/mlflow/_link.py) attaches links to the MLflow UI directly to your Flyte tasks. When configured, these links appear in the Flyte UI and lead to the specific MLflow run associated with that task.

To enable automatic link generation, you must provide a link_host. You can also provide a link_template if your MLflow UI uses a non-standard URL structure.

import flyte
from flyteplugins.mlflow import Mlflow, mlflow_config, mlflow_run

# Configure globally for the workflow execution
run = flyte.with_runcontext(
    custom_context=mlflow_config(
        link_host="https://mlflow.example.com",
        link_template="{host}/#experiments/{experiment_id}/runs/{run_id}",
    ),
).run(parent_task)

# In the task definition, include the Mlflow() link
@mlflow_run
@env.task(links=(Mlflow(),))
async def my_task():
    ...
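To show how a link_template expands into a URL, the following sketch fills in the placeholders with plain string formatting. `render_mlflow_link` is a hypothetical helper, not part of the plugin; the placeholder names are taken from the template above, and stripping a trailing slash from the host is an assumption.

```python
DEFAULT_TEMPLATE = "{host}/#experiments/{experiment_id}/runs/{run_id}"

def render_mlflow_link(
    link_host: str,
    experiment_id: str,
    run_id: str,
    link_template: str = DEFAULT_TEMPLATE,
) -> str:
    """Hypothetical helper: substitute run details into a link template."""
    return link_template.format(
        host=link_host.rstrip("/"),  # assumption: avoid a doubled slash
        experiment_id=experiment_id,
        run_id=run_id,
    )

url = render_mlflow_link("https://mlflow.example.com", "12", "abc123")
# url == "https://mlflow.example.com/#experiments/12/runs/abc123"
```

A deployment with a different URL layout would pass its own template string that uses the same three placeholders.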

For tasks using run_mode="nested", the Mlflow link automatically detects the parent relationship. Since the specific child run ID is often not known until the task starts, the link in the Flyte UI will default to the parent run and be labeled "MLflow (parent)".

Global and Local Configuration

The mlflow_config() function (in plugins/mlflow/src/flyteplugins/mlflow/_context.py) provides a way to set MLflow parameters that apply across multiple tasks.

Global Configuration

Use flyte.with_runcontext() to set global defaults like the tracking_uri or experiment_name for an entire workflow run.

import flyte
from flyteplugins.mlflow import mlflow_config

config = mlflow_config(
    tracking_uri="http://mlflow-server:5000",
    experiment_name="/shared-experiment",
    link_host="http://mlflow-ui",
)

flyte.with_runcontext(custom_context=config).run(my_workflow)

Local Overrides

mlflow_config() also acts as a context manager for per-task overrides within a parent task.

@mlflow_run
@env.task
async def parent_task():
    with mlflow_config(run_mode="new", autolog=True):
        # This child task will create a new run with autologging enabled
        await child_task()
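The override pattern, where settings apply inside a with block and revert afterwards, can be sketched with the standard contextvars module. This is an illustration under assumptions, not the plugin's implementation: `demo_config` and `current_config` are hypothetical names, and merging overrides on top of the enclosing configuration is an assumed behavior.

```python
import contextvars
from contextlib import contextmanager

# Hypothetical store for the active configuration; never mutated in place.
_config = contextvars.ContextVar("demo_mlflow_config", default={})

def current_config() -> dict:
    """Return a copy of the configuration active in this context."""
    return dict(_config.get())

@contextmanager
def demo_config(**overrides):
    """Layer overrides on the active configuration for the block's duration."""
    merged = {**_config.get(), **overrides}
    token = _config.set(merged)
    try:
        yield merged
    finally:
        # Restore whatever configuration was active before the block.
        _config.reset(token)
```

Inside `with demo_config(run_mode="new"):` a call to `current_config()` sees the override alongside any outer settings, and the outer values return once the block exits.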

Distributed Training

In distributed training scenarios (e.g., using PyTorch DistributedDataParallel), flyte-sdk ensures that only the process with RANK=0 performs MLflow logging. This prevents duplicate runs and conflicting log entries.

The @mlflow_run decorator automatically checks the RANK environment variable. If the rank is non-zero, the decorator skips all MLflow initialization, and get_mlflow_run() will return None.

from flyteplugins.mlflow import get_mlflow_run

@mlflow_run
@env.task
def distributed_task():
    run = get_mlflow_run()
    if run:
        # This block only executes on Rank 0
        import mlflow
        mlflow.log_metric("accuracy", 0.95)
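The rank check itself amounts to reading one environment variable. The sketch below shows one way to write it. `is_primary_rank` is a hypothetical helper, not the plugin's function, and treating an unset RANK as rank 0 is an assumption made so that non-distributed tasks still log.

```python
import os
from typing import Mapping, Optional

def is_primary_rank(environ: Optional[Mapping[str, str]] = None) -> bool:
    """Hypothetical helper: True only for the process that should log."""
    env = os.environ if environ is None else environ
    # Assumption: a missing RANK means a single-process (non-distributed) task.
    return int(env.get("RANK", "0")) == 0

# Launchers such as torchrun set RANK for every worker process.
primary = is_primary_rank({"RANK": "0"})  # True
worker = is_primary_rank({"RANK": "3"})   # False
```

Passing the environment as an argument keeps the check easy to test without modifying the real process environment.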

Accessing the Active Run

If you need to access the mlflow.ActiveRun object directly within your task (e.g., to get the run_id), use the get_mlflow_run() utility.

from flyteplugins.mlflow import get_mlflow_run

@mlflow_run
@env.task
async def my_task():
    run = get_mlflow_run()
    if run:
        print(f"Active MLflow Run ID: {run.info.run_id}")