Distributed Compute Plugins
flyte-sdk provides first-class support for distributed computing frameworks such as Apache Spark, Ray, Dask, and PyTorch. These integrations let you scale Python functions across a cluster of machines without managing the underlying infrastructure yourself.
You enable these integrations by passing a configuration object to the plugin_config parameter of a TaskEnvironment. When a task is decorated with an environment that has a plugin_config, flyte-sdk uses a specialized TaskTemplate (registered in flyte.extend.TaskPluginRegistry) to handle the setup, execution, and teardown of the distributed cluster.
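The dispatch described above can be pictured as a lookup from the type of the configuration object to a specialized task class. The following is a minimal, self-contained sketch of that pattern only; PluginRegistry, PlainTask, FakeClusterConfig, and FakeClusterTask are hypothetical names invented for illustration and are not the flyte-sdk API.

```python
from dataclasses import dataclass
from typing import Any, Dict


class PlainTask:
    """Hypothetical default task: runs the function between two hooks."""

    def __init__(self, func, plugin_config=None):
        self.func = func
        self.plugin_config = plugin_config
        self.events = []

    def pre(self):
        self.events.append("pre")

    def post(self):
        self.events.append("post")

    def run(self, *args, **kwargs):
        self.pre()
        try:
            return self.func(*args, **kwargs)
        finally:
            # Teardown runs even if the user function raises.
            self.post()


class PluginRegistry:
    """Hypothetical registry keyed by the type of the config object."""

    def __init__(self) -> None:
        self._by_config_type: Dict[type, type] = {}

    def register(self, config_type: type, task_type: type) -> None:
        if config_type in self._by_config_type:
            raise ValueError(f"{config_type.__name__} is already registered")
        self._by_config_type[config_type] = task_type

    def find(self, plugin_config: Any) -> type:
        # Without a plugin_config, fall back to the plain task.
        if plugin_config is None:
            return PlainTask
        return self._by_config_type[type(plugin_config)]


@dataclass
class FakeClusterConfig:
    workers: int = 2


class FakeClusterTask(PlainTask):
    def pre(self):
        self.events.append(f"start cluster with {self.plugin_config.workers} workers")

    def post(self):
        self.events.append("stop cluster")


registry = PluginRegistry()
registry.register(FakeClusterConfig, FakeClusterTask)

config = FakeClusterConfig(workers=3)
task = registry.find(config)(lambda x: x + 1, config)
result = task.run(41)
```

Keying the lookup on the config type means the task author only chooses a configuration object; the matching setup and teardown logic follows from that choice.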
Apache Spark
When you need to process large datasets using Spark, you can configure a Spark object and pass it to your TaskEnvironment. flyte-sdk automatically initializes a SparkSession for you and makes it available via the Flyte context.
from flyteplugins.spark.task import Spark
import flyte

spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.memory": "1000M",
        "spark.executor.instances": "2",
    },
)

spark_env = flyte.TaskEnvironment(
    name="spark_env",
    plugin_config=spark_conf,
    image="ghcr.io/flyteorg/spark-py:v3.4.0",
)

@spark_env.task
async def process_data(partitions: int = 3) -> int:
    # Access the automatically initialized SparkSession
    spark = flyte.ctx().data["spark_session"]
    # Use Spark as usual; count() returns an int
    df = spark.range(1, 1000, numPartitions=partitions).toDF("number")
    return df.count()
Internal Mechanism
The PysparkFunctionTask plugin in plugins/spark/src/flyteplugins/spark/task.py manages the lifecycle:
- Pre-execution: In the pre hook, it builds a SparkSession using pyspark.sql.SparkSession.builder. If running in a cluster, it packages your code bundle into a ZIP file and adds it to the Spark context using sess.sparkContext.addPyFile(file_path) so executors can access your code.
- Post-execution: In debug mode (where the action name is a0), the post hook ensures the SparkSession is stopped to prevent resource leaks in interactive environments.
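The packaging step can be sketched with the standard library alone. This is a minimal illustration under stated assumptions, not the plugin's actual code: zip_code_bundle and ship_to_executors are hypothetical helper names, and the choice to include only .py files is an assumption. SparkContext.addPyFile is the PySpark call named in the text.

```python
import os
import zipfile


def zip_code_bundle(src_dir: str, zip_path: str) -> str:
    """Pack every .py file under src_dir into zip_path, keeping relative paths."""
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, files in os.walk(src_dir):
            for name in sorted(files):
                if name.endswith(".py"):
                    full = os.path.join(root, name)
                    # Store paths relative to src_dir so imports resolve from the ZIP root.
                    zf.write(full, arcname=os.path.relpath(full, src_dir))
    return zip_path


def ship_to_executors(spark_session, src_dir: str, zip_path: str) -> None:
    """Zip the code and register the archive with the SparkContext (hypothetical helper)."""
    spark_session.sparkContext.addPyFile(zip_code_bundle(src_dir, zip_path))
```

Keeping paths relative to the source root matters because Spark places the archive on the executors' Python path, so a module stored as pkg/mod.py is imported as pkg.mod.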
Ray
Ray allows you to scale Python applications by distributing tasks and actors. To use Ray, define a RayJobConfig specifying the head and worker node configurations.
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig
import ray
import flyte

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    runtime_env={"pip": ["numpy", "pandas"]},
)

ray_env = flyte.TaskEnvironment(
    name="ray_env",
    plugin_config=ray_config,
    image="my-ray-image",
)

@ray_env.task
async def ray_task(n: int = 3) -> list[int]:
    @ray.remote
    def square(x):
        return x * x

    futures = [square.remote(i) for i in range(n)]
    return ray.get(futures)
Internal Mechanism
RayFunctionTask in plugins/ray/src/flyteplugins/ray/task.py handles the integration:
- Initialization: The pre hook calls ray.init(). If running in a cluster, it automatically sets the working_dir in the runtime_env to the current directory, ensuring your code is available on all nodes.
- Compatibility: It supports both runtime_env and runtime_env_yaml (required for KubeRay >= 1.1.0) to maintain compatibility across different cluster versions.
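The working_dir behavior from the initialization step can be approximated as follows. This is a sketch only: build_runtime_env and its in_cluster flag are hypothetical, and keeping a user-supplied working_dir instead of overwriting it is an assumption about the desired behavior, not something the plugin is confirmed to do.

```python
import os
from typing import Any, Dict, Optional


def build_runtime_env(user_env: Optional[Dict[str, Any]], in_cluster: bool) -> Dict[str, Any]:
    """Return a copy of user_env, adding working_dir when running in a cluster."""
    env = dict(user_env or {})  # copy so the caller's dict is not mutated
    if in_cluster:
        # Assumption: an explicit working_dir from the user takes precedence.
        env.setdefault("working_dir", os.getcwd())
    return env


# The result would typically be handed to Ray, e.g. ray.init(runtime_env=env).
env = build_runtime_env({"pip": ["numpy", "pandas"]}, in_cluster=True)
```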
Dask
Dask provides parallel computing for analytics. You can configure a Dask cluster by defining the Scheduler and WorkerGroup resources.
from flyteplugins.dask.task import Dask, Scheduler, WorkerGroup
import flyte
from distributed import Client

dask_config = Dask(
    scheduler=Scheduler(resources=flyte.Resources(cpu="1", memory="2Gi")),
    workers=WorkerGroup(number_of_workers=3, resources=flyte.Resources(cpu="2", memory="4Gi")),
)

dask_env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image="my-dask-image",
)

@dask_env.task
async def dask_task():
    # Connect to the Dask cluster
    client = Client()
    # ... perform Dask operations ...
Internal Mechanism
DaskTask in plugins/dask/src/flyteplugins/dask/task.py uses Dask SchedulerPlugin and WorkerPlugin to distribute code:
- Code Distribution: The pre hook registers DownloadCodeBundleWorkerPlugin and DownloadCodeBundleSchedulerPlugin. These plugins ensure that the flyte-sdk code bundle is downloaded and added to sys.path on every node in the Dask cluster before execution begins.
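The sys.path part of that mechanism can be illustrated with a duck-typed plugin. This sketch is not the DownloadCodeBundleWorkerPlugin implementation: AddCodeBundleToPath is a hypothetical class, the download step is omitted, and it only mirrors the setup(worker) and teardown(worker) hooks that Dask worker plugins expose.

```python
import sys


class AddCodeBundleToPath:
    """Hypothetical worker plugin: makes an unpacked code bundle importable."""

    def __init__(self, bundle_dir: str) -> None:
        self.bundle_dir = bundle_dir

    def setup(self, worker) -> None:
        # A real plugin would download and unpack the bundle before this point.
        if self.bundle_dir not in sys.path:
            sys.path.insert(0, self.bundle_dir)

    def teardown(self, worker) -> None:
        # Undo the path change when the worker shuts the plugin down.
        if self.bundle_dir in sys.path:
            sys.path.remove(self.bundle_dir)
```

In an actual cluster the class would normally subclass the WorkerPlugin base class from distributed.diagnostics.plugin and be registered through the client, so that Dask runs setup on every current and future worker.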
PyTorch Elastic
For distributed deep learning, flyte-sdk integrates with PyTorch Elastic (torch.distributed.run). This is configured using the Elastic class.
from flyteplugins.pytorch.task import Elastic, RunPolicy
import torch
import flyte

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    plugin_config=Elastic(
        nnodes=2,          # Number of nodes
        nproc_per_node=1,  # Processes per node (usually 1 per GPU)
        max_restarts=3,
    ),
    image="pytorch-image",
)

@torch_env.task
def train_model(epochs: int):
    torch.distributed.init_process_group("gloo")
    # ... training logic ...
Resilience and Performance Features
TorchFunctionTask in plugins/pytorch/src/flyteplugins/pytorch/task.py includes several features to improve distributed training stability:
- Zombie Watchdog: It starts a background thread (_start_zombie_watchdog) to detect a known PyTorch deadlock where the elastic agent hangs if all workers die simultaneously (e.g., during a CUDA OOM). The watchdog force-exits the process if it detects that all worker children have become zombies.
- NCCL Timeout Management: You can tune failure detection speed using nccl_collective_timeout_sec and nccl_heartbeat_timeout_sec. flyte-sdk propagates these to worker processes via environment variables like FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC.
- Automatic Threading Limits: If nproc_per_node > 1, flyte-sdk sets OMP_NUM_THREADS=1 by default to prevent CPU over-subscription, matching standard torchrun behavior.
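The environment handling in the last two items can be sketched as a pure function. This is an illustration, not the plugin's code: worker_env is a hypothetical helper, and FLYTE_NCCL_HEARTBEAT_TIMEOUT_SEC is a variable name assumed by analogy with the one the text mentions. Leaving an existing OMP_NUM_THREADS value untouched mirrors how torchrun applies its default.

```python
from typing import Dict, Mapping, Optional

COLLECTIVE_VAR = "FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC"  # named in the text above
HEARTBEAT_VAR = "FLYTE_NCCL_HEARTBEAT_TIMEOUT_SEC"  # assumed name, not confirmed


def worker_env(
    base_env: Mapping[str, str],
    nproc_per_node: int,
    nccl_collective_timeout_sec: Optional[int] = None,
    nccl_heartbeat_timeout_sec: Optional[int] = None,
) -> Dict[str, str]:
    """Build the environment passed to worker processes (hypothetical helper)."""
    env = dict(base_env)
    # Default to one OpenMP thread per process only when the user has not chosen a value.
    if nproc_per_node > 1 and "OMP_NUM_THREADS" not in env:
        env["OMP_NUM_THREADS"] = "1"
    # Environment values must be strings, so the timeouts are converted explicitly.
    if nccl_collective_timeout_sec is not None:
        env[COLLECTIVE_VAR] = str(nccl_collective_timeout_sec)
    if nccl_heartbeat_timeout_sec is not None:
        env[HEARTBEAT_VAR] = str(nccl_heartbeat_timeout_sec)
    return env
```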
Dynamic Cluster Overrides
You can dynamically adjust cluster configurations at call time using the .override() method. This is useful when you want to scale the number of executors based on the input size.
# Reuses Spark and spark_env from the Apache Spark example above.
# task_env is assumed to be a regular flyte.TaskEnvironment defined elsewhere (not shown).

@spark_env.task
async def spark_task():
    ...

@task_env.task
async def scaling_controller(num_executors: int):
    # Create a new config with the desired number of instances
    new_conf = Spark(
        spark_conf={
            "spark.executor.instances": str(num_executors),
            "spark.driver.memory": "3000M",
        }
    )
    # Override the task's plugin_config for this specific call
    return await spark_task.override(plugin_config=new_conf)()
Note: plugin_config cannot be combined with a reusable TaskEnvironment. If you need to override resources or other settings for a task that uses a plugin, make sure the environment's reusable option is disabled.
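To make the idea of scaling executors with input size concrete, here is a small sketch of how the executor count might be derived before building the override. The helper names, the one-million-rows-per-executor ratio, and the cap of 20 are all hypothetical values chosen for illustration.

```python
import math


def executors_for(num_rows: int, rows_per_executor: int = 1_000_000, max_executors: int = 20) -> int:
    """Pick an executor count proportional to input size, clamped to [1, max_executors]."""
    if num_rows < 0:
        raise ValueError("num_rows must be non-negative")
    wanted = math.ceil(num_rows / rows_per_executor)
    return max(1, min(wanted, max_executors))


def spark_conf_for(num_rows: int) -> dict:
    """Build the spark_conf dict for a Spark config object (values must be strings)."""
    return {
        "spark.executor.instances": str(executors_for(num_rows)),
        "spark.driver.memory": "3000M",
    }
```

The resulting dictionary would be passed as spark_conf when constructing the Spark object used in the override call.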