Tutorial: Your First Ray Task
You can run distributed Ray jobs on Flyte without managing your own cluster infrastructure. The flyte-sdk provides a Ray plugin that handles the lifecycle of a transient Ray cluster, automatically initializing the environment and cleaning up resources after your task completes.
By the end of this tutorial, you will have a Flyte task that spins up a Ray head node and multiple worker nodes to perform parallel computations.
Prerequisites
To follow this tutorial, ensure you have the following packages installed:
- flyte-sdk (published on PyPI as flyte)
- ray
- flyteplugins-ray

```shell
pip install flyte ray flyteplugins-ray
```
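Before continuing, you can confirm that the packages are visible to the interpreter you will run the tutorial with. The sketch below uses only the standard library. The helper name installed_versions is hypothetical, and the distribution names it checks are assumptions (the SDK is assumed to install under the name flyte), so adjust them to match what you actually installed.

```python
from __future__ import annotations

from importlib import metadata


def installed_versions(dist_names: list[str]) -> dict[str, str | None]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions: dict[str, str | None] = {}
    for name in dist_names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed distribution names; edit this list to match your install command
    for name, version in installed_versions(["flyte", "ray", "flyteplugins-ray"]).items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```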
Step 1: Define the Ray Cluster Configuration
The first step is to describe the Ray cluster you want flyte-sdk to provision. You do this using RayJobConfig, which allows you to specify the resources for both the head node and the worker groups.
```python
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

# Configure the Ray cluster
ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(
        ray_start_params={"log-color": "True"}
    ),
    worker_node_config=[
        WorkerNodeConfig(
            group_name="ray-worker-group",
            replicas=2,
            min_replicas=1,
            max_replicas=2,
        )
    ],
    runtime_env={"pip": ["numpy", "pandas"]},
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=300,
)
```
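In this configuration:
- HeadNodeConfig: Defines parameters for the Ray head node. The entries in ray_start_params are passed to ray start when the head node launches.
- WorkerNodeConfig: Defines a group of worker nodes. replicas sets the desired number of workers, and min_replicas and max_replicas bound the group's size when autoscaling is enabled.
- runtime_env: Specifies dependencies that Ray should install on all nodes in the cluster.
- shutdown_after_job_finishes: Ensures the transient cluster is destroyed once the job, and therefore the Flyte task, is done.
- ttl_seconds_after_finished: How long, in seconds, the finished job's resources are kept before they are cleaned up.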
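The replica fields interact, so it can help to see the arithmetic spelled out. The model below is not the plugin's WorkerNodeConfig: WorkerGroupSketch and node_count_bounds are hypothetical names, and the rule min_replicas <= replicas <= max_replicas is an assumption based on how Ray's Kubernetes operator treats worker groups. It counts the head node plus every worker group to give the smallest and largest cluster a configuration allows.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class WorkerGroupSketch:
    """Hypothetical stand-in for one worker group's replica settings."""
    group_name: str
    replicas: int
    min_replicas: int | None = None
    max_replicas: int | None = None

    def lower(self) -> int:
        # Without an explicit minimum, the group never shrinks below replicas
        return self.replicas if self.min_replicas is None else self.min_replicas

    def upper(self) -> int:
        # Without an explicit maximum, the group never grows above replicas
        return self.replicas if self.max_replicas is None else self.max_replicas

    def __post_init__(self) -> None:
        # Assumed invariant: min_replicas <= replicas <= max_replicas
        if not 0 <= self.lower() <= self.replicas <= self.upper():
            raise ValueError(
                f"{self.group_name}: expected min_replicas <= replicas <= max_replicas, "
                f"got {self.lower()} <= {self.replicas} <= {self.upper()}"
            )


def node_count_bounds(groups: list[WorkerGroupSketch]) -> tuple[int, int]:
    """Return (smallest, largest) total node count: one head node plus all workers."""
    return 1 + sum(g.lower() for g in groups), 1 + sum(g.upper() for g in groups)


# Mirrors the Step 1 values: 2 replicas, scaling between 1 and 2 workers
groups = [WorkerGroupSketch("ray-worker-group", replicas=2, min_replicas=1, max_replicas=2)]
print(node_count_bounds(groups))  # (2, 3)
```

With the Step 1 values, the cluster would range from two nodes (head plus one worker) to three (head plus two workers), assuming autoscaling is enabled; otherwise it stays at the replicas count.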
Step 2: Create a Task Environment
In flyte-sdk, plugins are attached to a TaskEnvironment. Passing a RayJobConfig as the environment's plugin_config tells Flyte to execute every task declared with that environment's task decorator using the RayFunctionTask plugin.
```python
import flyte

# Build an image that contains Ray and the Flyte Ray plugin
image = flyte.Image.from_debian_base().with_pip_packages("ray[default]", "flyteplugins-ray")

# Create the environment with the Ray configuration
ray_env = flyte.TaskEnvironment(
    name="ray_env",
    plugin_config=ray_config,
    image=image,
    resources=flyte.Resources(cpu=(2, 4), memory=("2Gi", "4Gi")),
)
```
The plugin_config parameter is where you pass the RayJobConfig created in the previous step. Note that Ray tasks in flyte-sdk cannot be used with "reusable" environments; each task execution typically manages its own transient cluster lifecycle.
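The dispatch described in this step can be pictured with a toy model: an environment that holds a plugin config, stamps it onto every function it decorates, and rejects a Ray config on a reusable environment. Everything here (ToyRayConfig, ToyEnvironment, ToyTask, and the reusable flag) is hypothetical and far simpler than flyte-sdk's real TaskEnvironment; it only mirrors the relationships described above.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyRayConfig:
    """Hypothetical placeholder for a Ray cluster configuration."""
    worker_replicas: int = 2


@dataclass
class ToyTask:
    """Hypothetical task record: the user function plus the plugin config that will run it."""
    func: Callable[..., Any]
    plugin_config: Any


@dataclass
class ToyEnvironment:
    """Hypothetical, heavily simplified stand-in for a task environment."""
    name: str
    plugin_config: Any = None
    reusable: bool = False

    def __post_init__(self) -> None:
        # Mirrors the note above: a Ray task manages its own transient cluster,
        # so this toy refuses to combine a Ray config with a reusable environment.
        if self.reusable and isinstance(self.plugin_config, ToyRayConfig):
            raise ValueError(f"environment {self.name!r}: Ray tasks cannot be reusable")

    def task(self, func: Callable[..., Any]) -> ToyTask:
        # Every function decorated by this environment inherits its plugin config
        return ToyTask(func=func, plugin_config=self.plugin_config)


toy_env = ToyEnvironment(name="ray_env", plugin_config=ToyRayConfig())


@toy_env.task
async def toy_workflow(n: int) -> list[int]:
    return list(range(n))


print(toy_workflow.plugin_config)  # ToyRayConfig(worker_replicas=2)
```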
Step 3: Define the Distributed Task
Now you can define your Ray functions and the Flyte task that orchestrates them. The Flyte task must be async because RayFunctionTask inherits from AsyncFunctionTaskTemplate.
```python
import typing

import ray

# Define a standard Ray remote function
@ray.remote
def square(x: int) -> int:
    return x * x

# Define the Flyte task using the ray_env decorator
@ray_env.task
async def run_ray_workflow(n: int) -> typing.List[int]:
    # flyte-sdk automatically calls ray.init() before this code runs,
    # using the configuration provided in ray_env.
    futures = [square.remote(i) for i in range(n)]
    results = ray.get(futures)
    return results
```
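The submit-then-gather shape of run_ray_workflow is not specific to Ray. As an analogy only, the standard library's concurrent.futures follows the same pattern, which can be handy for checking the logic of a task body on a laptop. It uses threads in a single process, so it has none of Ray's scheduling, object store, or multi-node execution; run_local is a hypothetical name.

```python
from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor


def square(x: int) -> int:
    return x * x


def run_local(n: int) -> list[int]:
    """Local analogy of run_ray_workflow using a thread pool instead of a Ray cluster."""
    with ThreadPoolExecutor(max_workers=2) as pool:
        # Fan out: submit() returns a Future immediately, much like square.remote(i)
        futures = [pool.submit(square, i) for i in range(n)]
        # Gather: result() blocks until each Future is done, much like ray.get(futures);
        # reading the list in submission order keeps outputs aligned with inputs
        return [f.result() for f in futures]


print(run_local(5))  # [0, 1, 4, 9, 16]
```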
When run_ray_workflow is executed, the RayFunctionTask.pre method is triggered. It checks if a Ray cluster is already initialized and, if not, calls ray.init() with the parameters derived from your RayJobConfig. If running inside a Flyte cluster, it also automatically configures the working_dir in the Ray runtime_env.
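The initialization step described above amounts to an init-once guard. The sketch below models it in plain Python so the control flow is visible. FakeRay, build_init_kwargs, pre, and the in_flyte_cluster flag are hypothetical, "/root" is a made-up code directory, and the rule that a user-supplied working_dir is left untouched is an assumption; the real RayFunctionTask.pre may derive its arguments differently. In real code, FakeRay's two methods correspond to ray.is_initialized() and ray.init(**kwargs).

```python
from __future__ import annotations

import copy
from typing import Any


class FakeRay:
    """Hypothetical stand-in for the ray module's is_initialized()/init() pair."""

    def __init__(self) -> None:
        self.init_calls: list[dict[str, Any]] = []

    def is_initialized(self) -> bool:
        return bool(self.init_calls)

    def init(self, **kwargs: Any) -> None:
        self.init_calls.append(kwargs)


def build_init_kwargs(runtime_env: dict[str, Any], in_flyte_cluster: bool, code_dir: str) -> dict[str, Any]:
    # Copy so the caller's runtime_env is never mutated
    env = copy.deepcopy(runtime_env)
    if in_flyte_cluster:
        # Assumption: point Ray at the task's code directory unless the user set one
        env.setdefault("working_dir", code_dir)
    return {"runtime_env": env}


def pre(ray_module: FakeRay, runtime_env: dict[str, Any], in_flyte_cluster: bool, code_dir: str) -> bool:
    """Initialize Ray once; return True only if this call performed the initialization."""
    if ray_module.is_initialized():
        return False
    ray_module.init(**build_init_kwargs(runtime_env, in_flyte_cluster, code_dir))
    return True


if __name__ == "__main__":
    fake = FakeRay()
    pre(fake, {"pip": ["numpy"]}, in_flyte_cluster=True, code_dir="/root")  # initializes
    pre(fake, {"pip": ["numpy"]}, in_flyte_cluster=True, code_dir="/root")  # no-op
    print(fake.init_calls)  # [{'runtime_env': {'pip': ['numpy'], 'working_dir': '/root'}}]
```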
Step 4: Run the Task
You can run this task locally or on a Flyte cluster. When running locally, the plugin will initialize a local Ray instance if one isn't already running.
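```python
if __name__ == "__main__":
    # Initialize flyte-sdk from your configuration (local or remote)
    flyte.init_from_config()
    # Pass the task object and its inputs separately; flyte.run returns a run handle
    run = flyte.run(run_ray_workflow, n=5)
    print(f"Run URL: {run.url}")
    # Block until the run completes
    run.wait()
```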
Complete Example
Here is the full code combining all the steps:
```python
import typing

import flyte
import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

# 1. Configuration
ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    runtime_env={"pip": ["numpy"]},
    shutdown_after_job_finishes=True,
)

# 2. Environment (the image must include Ray and the Ray plugin)
ray_env = flyte.TaskEnvironment(
    name="ray_env",
    plugin_config=ray_config,
    image=flyte.Image.from_debian_base().with_pip_packages("ray[default]", "flyteplugins-ray"),
)

# 3. Ray Functions
@ray.remote
def f(x: int) -> int:
    return x * x

# 4. Flyte Task
@ray_env.task
async def hello_ray(n: int = 3) -> typing.List[int]:
    futures = [f.remote(i) for i in range(n)]
    return ray.get(futures)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(hello_ray, n=5)
    print(run.url)
    run.wait()
```
Next Steps
- Existing Clusters: If you have a long-running Ray cluster, you can connect to it by providing the address parameter in your RayJobConfig.
- Resource Management: Use HeadNodeConfig and WorkerNodeConfig to specify an exact Kubernetes pod_template or resource requests/limits for fine-grained control over your Ray nodes.