
Distributed Compute Plugins

Flyte-sdk provides first-class support for distributed computing frameworks like Apache Spark, Ray, Dask, and PyTorch. These integrations allow you to scale your Python functions across a cluster of machines without managing the underlying infrastructure manually.

You enable these integrations by passing a configuration object to the plugin_config parameter of a TaskEnvironment. When a task is decorated with an environment that has a plugin_config, flyte-sdk uses a specialized TaskTemplate (registered in flyte.extend.TaskPluginRegistry) to handle the setup, execution, and teardown of the distributed cluster.
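
To make the dispatch idea concrete, here is a minimal, self-contained sketch of how a registry can map a config type to the task class that owns setup, execution, and teardown. Every name in it (PluginRegistry, PlainTask, SparkLikeConfig, SparkLikeTask, make_task) is hypothetical. It is not the flyte.extend.TaskPluginRegistry API, only an illustration of the pattern under the assumption that lookup is keyed on the config object's type.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type


@dataclass
class SparkLikeConfig:
    """Hypothetical plugin config (stands in for a real config object such as Spark)."""
    spark_conf: Dict[str, str] = field(default_factory=dict)


class PlainTask:
    """Hypothetical default template: no cluster setup or teardown."""

    def __init__(self, func: Callable[..., Any], plugin_config: Optional[Any] = None) -> None:
        self.func = func
        self.plugin_config = plugin_config
        self.events: List[str] = []  # records hook calls so the flow is observable

    def pre(self) -> None:
        pass

    def post(self) -> None:
        pass

    def run(self, *args: Any, **kwargs: Any) -> Any:
        self.pre()
        try:
            return self.func(*args, **kwargs)
        finally:
            self.post()  # teardown runs even if the user function raises


class SparkLikeTask(PlainTask):
    """Hypothetical plugin template that would start and stop a cluster session."""

    def pre(self) -> None:
        self.events.append("setup")

    def post(self) -> None:
        self.events.append("teardown")


class PluginRegistry:
    """Maps a config type to the template class that knows how to run it."""

    def __init__(self) -> None:
        self._plugins: Dict[type, Type[PlainTask]] = {}

    def register(self, config_type: type, task_cls: Type[PlainTask]) -> None:
        self._plugins[config_type] = task_cls

    def find(self, plugin_config: Optional[Any]) -> Type[PlainTask]:
        if plugin_config is None:
            return PlainTask
        try:
            return self._plugins[type(plugin_config)]
        except KeyError:
            raise TypeError(f"no plugin registered for {type(plugin_config).__name__}") from None


registry = PluginRegistry()
registry.register(SparkLikeConfig, SparkLikeTask)


def make_task(func: Callable[..., Any], plugin_config: Optional[Any] = None) -> PlainTask:
    """Pick the template from the config's type, as an environment decorator might."""
    return registry.find(plugin_config)(func, plugin_config)
```

For example, make_task(lambda: 42, SparkLikeConfig()) returns a SparkLikeTask whose run() returns 42 and records "setup" then "teardown". Keying on the exact type keeps lookup trivial, at the cost of not matching subclasses of a registered config.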

Apache Spark

When you need to process large datasets using Spark, you can configure a Spark object and pass it to your TaskEnvironment. flyte-sdk automatically initializes a SparkSession for you and makes it available via the Flyte context.

from flyteplugins.spark.task import Spark
import flyte

spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.memory": "1000M",
        "spark.executor.instances": "2",
    },
)

spark_env = flyte.TaskEnvironment(
    name="spark_env",
    plugin_config=spark_conf,
    image="ghcr.io/flyteorg/spark-py:v3.4.0",
)

@spark_env.task
async def process_data(partitions: int = 3) -> int:
    # Access the automatically initialized SparkSession
    spark = flyte.ctx().data["spark_session"]

    # Use Spark as usual; count() returns an int
    df = spark.range(1, 1000, numPartitions=partitions).toDF("number")
    return df.count()

Internal Mechanism

The PysparkFunctionTask plugin in plugins/spark/src/flyteplugins/spark/task.py manages the lifecycle:

  • Pre-execution: In the pre hook, it builds a SparkSession using pyspark.sql.SparkSession.builder. If running in a cluster, it packages your code bundle into a ZIP file and adds it to the Spark context using sess.sparkContext.addPyFile(file_path) so executors can access your code.
  • Post-execution: In debug mode (where the action name is a0), the post hook ensures the SparkSession is stopped to prevent resource leaks in interactive environments.
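
The packaging step can be sketched with the standard library alone. The sketch assumes the code bundle is a directory of .py files, and zip_code_bundle is a hypothetical helper, not the plugin's own function. It relies on Python's zipimport support: once a .zip is on sys.path, modules inside it can be imported. That is roughly the effect addPyFile has on executors.

```python
import os
import zipfile


def zip_code_bundle(src_dir: str, zip_path: str) -> str:
    """Write every .py file under src_dir into zip_path and return zip_path.

    Archive names are relative to src_dir, so src_dir/pkg/mod.py becomes
    pkg/mod.py inside the zip and imports as pkg.mod.
    """
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, files in os.walk(src_dir):
            for name in sorted(files):
                if name.endswith(".py"):
                    full_path = os.path.join(root, name)
                    zf.write(full_path, arcname=os.path.relpath(full_path, src_dir))
    return zip_path
```

On the driver, the returned path is what you would hand to sess.sparkContext.addPyFile(zip_path). Writing the archive outside src_dir keeps it from being walked into its own contents.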

Ray

Ray allows you to scale Python applications by distributing tasks and actors. To use Ray, define a RayJobConfig specifying the head and worker node configurations.

from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig
import ray
import flyte

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    runtime_env={"pip": ["numpy", "pandas"]},
)

ray_env = flyte.TaskEnvironment(
    name="ray_env",
    plugin_config=ray_config,
    image="my-ray-image",
)

@ray_env.task
async def ray_task(n: int = 3) -> list[int]:
    @ray.remote
    def square(x):
        return x * x

    futures = [square.remote(i) for i in range(n)]
    return ray.get(futures)

Internal Mechanism

RayFunctionTask in plugins/ray/src/flyteplugins/ray/task.py handles the integration:

  • Initialization: The pre hook calls ray.init(). If running in a cluster, it automatically sets the working_dir in the runtime_env to the current directory, ensuring your code is available on all nodes.
  • Compatibility: It supports both runtime_env and runtime_env_yaml (required for KubeRay >= 1.1.0) to maintain compatibility across different cluster versions.
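
A minimal sketch of the runtime_env handling described above follows. It assumes a plain dict config, and prepare_runtime_env and to_runtime_env_yaml are hypothetical helpers. Keeping a user-supplied working_dir instead of overwriting it is this sketch's own choice; the text above does not say which behavior the plugin uses. JSON output of this shape is also valid YAML, so json.dumps stands in for a YAML serializer and avoids a PyYAML dependency.

```python
import json
import os
from typing import Any, Dict, Optional


def prepare_runtime_env(
    runtime_env: Optional[Dict[str, Any]],
    in_cluster: bool,
    cwd: Optional[str] = None,
) -> Dict[str, Any]:
    """Return a shallow copy of runtime_env, defaulting working_dir when in a cluster."""
    env = dict(runtime_env or {})
    if in_cluster:
        # Ship the current directory to every node unless the user chose a working_dir.
        env.setdefault("working_dir", cwd if cwd is not None else os.getcwd())
    return env


def to_runtime_env_yaml(runtime_env: Dict[str, Any]) -> str:
    """Serialize runtime_env as text for a runtime_env_yaml-style field.

    JSON output like this is also valid YAML, so no PyYAML dependency is needed.
    """
    return json.dumps(runtime_env, sort_keys=True)
```

In a pre hook, the result of prepare_runtime_env would be what gets passed as ray.init(runtime_env=...). Copying the dict means the user's original config object is never mutated.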

Dask

Dask provides parallel computing for analytics. You can configure a Dask cluster by defining the Scheduler and WorkerGroup resources.

from flyteplugins.dask.task import Dask, Scheduler, WorkerGroup
import flyte
from distributed import Client

dask_config = Dask(
    scheduler=Scheduler(resources=flyte.Resources(cpu="1", memory="2Gi")),
    workers=WorkerGroup(number_of_workers=3, resources=flyte.Resources(cpu="2", memory="4Gi")),
)

dask_env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image="my-dask-image",
)

def inc(x: int) -> int:
    return x + 1

@dask_env.task
async def dask_task(n: int = 3) -> list[int]:
    # Connect to the Dask cluster
    client = Client()
    # Fan the work out to the workers, then collect the results
    futures = client.map(inc, range(n))
    return client.gather(futures)

Internal Mechanism

DaskTask in plugins/dask/src/flyteplugins/dask/task.py uses Dask SchedulerPlugin and WorkerPlugin to distribute code:

  • Code Distribution: The pre hook registers DownloadCodeBundleWorkerPlugin and DownloadCodeBundleSchedulerPlugin. These plugins ensure that the flyte-sdk code bundle is downloaded and added to sys.path on every node in the Dask cluster before execution begins.
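
Per-node code distribution can be sketched with a class shaped like a Dask worker plugin. distributed.WorkerPlugin defines a setup(worker) hook that runs on each worker the plugin is attached to. The class below copies only that shape and does not import distributed. DownloadBundlePluginSketch is hypothetical, and it assumes the bundle is a local tarball, so extracting it stands in for the real download.

```python
import os
import sys
import tarfile
from typing import Any


class DownloadBundlePluginSketch:
    """Hypothetical stand-in for a code-bundle worker plugin (not the flyte-sdk class)."""

    def __init__(self, bundle_path: str, dest_dir: str) -> None:
        self.bundle_path = bundle_path
        self.dest_dir = dest_dir

    def setup(self, worker: Any = None) -> None:
        """Unpack the bundle and make it importable on this node."""
        os.makedirs(self.dest_dir, exist_ok=True)
        with tarfile.open(self.bundle_path, "r:*") as tar:
            if hasattr(tarfile, "data_filter"):
                # Newer Pythons can reject unsafe members such as absolute paths.
                tar.extractall(self.dest_dir, filter="data")
            else:
                tar.extractall(self.dest_dir)
        # Add the directory once, even if setup runs more than once.
        if self.dest_dir not in sys.path:
            sys.path.insert(0, self.dest_dir)
```

Because setup is idempotent, it tolerates being invoked again on the same node, for example when a worker restarts.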

PyTorch Elastic

For distributed deep learning, flyte-sdk integrates with PyTorch Elastic (torch.distributed.run). This is configured using the Elastic class.

from flyteplugins.pytorch.task import Elastic
import torch.distributed as dist
import flyte

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    plugin_config=Elastic(
        nnodes=2,  # Number of nodes
        nproc_per_node=1,  # Processes per node (usually 1 per GPU)
        max_restarts=3,
    ),
    image="pytorch-image",
)

@torch_env.task
def train_model(epochs: int) -> None:
    # Rank, world size, and rendezvous address come from the environment set by the launcher.
    # "gloo" works on CPU; use "nccl" on GPU nodes.
    dist.init_process_group("gloo")
    # ... training logic ...
    dist.destroy_process_group()

Resilience and Performance Features

TorchFunctionTask in plugins/pytorch/src/flyteplugins/pytorch/task.py includes several features to improve distributed training stability:

  • Zombie Watchdog: It starts a background thread (_start_zombie_watchdog) to detect a known PyTorch deadlock where the elastic agent hangs if all workers die simultaneously (e.g., during a CUDA OOM). The watchdog force-exits the process once it detects that all worker child processes have become zombies.
  • NCCL Timeout Management: You can tune failure detection speed using nccl_collective_timeout_sec and nccl_heartbeat_timeout_sec. flyte-sdk propagates these to worker processes via environment variables like FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC.
  • Automatic Threading Limits: If nproc_per_node > 1, flyte-sdk sets OMP_NUM_THREADS=1 by default to prevent CPU over-subscription, matching standard torchrun behavior.
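
Two of these behaviors can be sketched with the standard library. build_worker_env, proc_state, and start_zombie_watchdog are hypothetical names, not the plugin's internals (the real watchdog is _start_zombie_watchdog). The sketch makes three assumptions: the collective timeout is passed as a string of whole seconds, process state is read from Linux /proc, and the force-exit uses exit code 1. Injecting get_state and on_all_zombies keeps the watchdog testable without creating real zombie processes.

```python
import os
import threading
import time
from typing import Callable, Dict, List, Mapping, Optional


def build_worker_env(
    base_env: Mapping[str, str],
    nproc_per_node: int,
    nccl_collective_timeout_sec: Optional[int] = None,
) -> Dict[str, str]:
    """Return a copy of base_env with the defaults described above applied."""
    env = dict(base_env)
    if nproc_per_node > 1:
        # Like torchrun: default to one OpenMP thread per process unless the user set it.
        env.setdefault("OMP_NUM_THREADS", "1")
    if nccl_collective_timeout_sec is not None:
        env["FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC"] = str(nccl_collective_timeout_sec)
    return env


def proc_state(pid: int) -> str:
    """Linux only: return the one-letter state ('R', 'S', 'Z', ...) from /proc/<pid>/stat."""
    with open(f"/proc/{pid}/stat") as f:
        data = f.read()
    # The command name sits in parentheses and may contain spaces; the state follows the last ')'.
    return data.rsplit(")", 1)[1].split()[0]


def start_zombie_watchdog(
    get_child_pids: Callable[[], List[int]],
    get_state: Callable[[int], str] = proc_state,
    on_all_zombies: Callable[[], None] = lambda: os._exit(1),
    interval_sec: float = 5.0,
) -> threading.Thread:
    """Poll worker children in a daemon thread; fire on_all_zombies if every one is a zombie."""

    def loop() -> None:
        while True:
            pids = get_child_pids()
            try:
                states = [get_state(pid) for pid in pids]
            except OSError:
                states = []  # a child vanished mid-check; look again next round
            if states and all(state == "Z" for state in states):
                on_all_zombies()  # default: hard exit, since the agent itself is stuck
                return
            time.sleep(interval_sec)

    thread = threading.Thread(target=loop, name="zombie-watchdog", daemon=True)
    thread.start()
    return thread
```

The thread is a daemon so it never keeps a healthy process alive on its own. os._exit is used because a normal exit could block on the same hung agent the watchdog is trying to escape.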

Dynamic Cluster Overrides

You can dynamically adjust cluster configurations at call time using the .override() method. This is useful when you want to scale the number of executors based on the input size.

import flyte
from flyteplugins.spark.task import Spark

# A plain environment (no plugin_config) for the controller; the name is illustrative
task_env = flyte.TaskEnvironment(name="controller_env")

@spark_env.task
async def spark_task():
    ...

@task_env.task
async def scaling_controller(num_executors: int):
    # Create a new config with the desired number of instances
    new_conf = Spark(
        spark_conf={
            "spark.executor.instances": str(num_executors),
            "spark.driver.memory": "3000M",
        }
    )
    # Override the task's plugin_config for this specific call
    return await spark_task.override(plugin_config=new_conf)()
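
One way to derive the executor count from the input size is a ceiling division clamped to a range, sketched below. scaled_spark_conf, its per-executor byte budget, and its limits are hypothetical; tune them for your workload. Because the override passes a whole new Spark object, any key left out of it (such as spark.executor.memory in the example above) is presumably not carried over. For that reason, the sketch starts from a complete base dictionary.

```python
import math
from typing import Dict, Mapping


def scaled_spark_conf(
    base_conf: Mapping[str, str],
    input_bytes: int,
    bytes_per_executor: int = 2 * 1024**3,
    min_executors: int = 1,
    max_executors: int = 20,
) -> Dict[str, str]:
    """Copy base_conf and set spark.executor.instances from the input size."""
    if bytes_per_executor <= 0:
        raise ValueError("bytes_per_executor must be positive")
    wanted = math.ceil(input_bytes / bytes_per_executor)
    n = max(min_executors, min(max_executors, wanted))
    conf = dict(base_conf)
    conf["spark.executor.instances"] = str(n)  # Spark conf values are strings
    return conf
```

Inside the controller, the result would be wrapped as Spark(spark_conf=scaled_spark_conf(base, input_bytes)) before calling override. With the defaults, a 5 GiB input yields 3 executors.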

Note: plugin_config cannot be used when the TaskEnvironment is configured as reusable. If you need to override resources or other settings for a task with a plugin, make sure reuse is not enabled for its environment.