Configuring Spark Tasks
To execute PySpark tasks natively on Kubernetes using flyte-sdk, you define a Spark configuration and attach it to a TaskEnvironment. The flyte-sdk Spark plugin manages the SparkSession lifecycle and automatically distributes your local code to the Spark executors.
Defining a Spark Task
To create a Spark task, first define a Spark configuration object and pass it as the plugin_config to a TaskEnvironment.
```python
from flyteplugins.spark.task import Spark
import flyte

# Configure Spark settings
spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "1000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
    },
)

# Create an environment that uses the Spark plugin
spark_env = flyte.TaskEnvironment(
    name="spark_env",
    plugin_config=spark_conf,
    image=flyte.Image.from_base("apache/spark-py:v3.4.0"),
)


@spark_env.task
async def hello_spark(partitions: int = 3) -> float:
    # Access the managed SparkSession from the Flyte context
    spark = flyte.ctx().data["spark_session"]
    # Use the session for RDD or DataFrame operations
    n = 100 * partitions
    count = spark.sparkContext.parallelize(range(1, n + 1), partitions).count()
    return float(count)
```
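Every value in spark_conf is a string, even when it holds a number or a memory size. If you derive these settings from typed parameters, a small helper keeps the formatting in one place. The sketch below is an illustration only. The helper names build_spark_conf and total_executor_memory_mb are hypothetical and not part of flyte-sdk. The memory arithmetic covers only the executor heap, not Spark's per-pod memory overhead.

```python
def build_spark_conf(
    driver_memory_mb: int,
    executor_memory_mb: int,
    executor_cores: int,
    executor_instances: int,
    driver_cores: int,
) -> dict[str, str]:
    """Hypothetical helper: build a spark_conf dict with string-valued properties."""
    return {
        "spark.driver.memory": f"{driver_memory_mb}M",
        "spark.executor.memory": f"{executor_memory_mb}M",
        "spark.executor.cores": str(executor_cores),
        "spark.executor.instances": str(executor_instances),
        "spark.driver.cores": str(driver_cores),
    }


def total_executor_memory_mb(conf: dict[str, str]) -> int:
    """Heap memory requested across all executors (excludes spark.executor.memoryOverhead)."""
    per_executor = int(conf["spark.executor.memory"].rstrip("M"))
    return per_executor * int(conf["spark.executor.instances"])


conf = build_spark_conf(1000, 1000, 1, 2, 1)
print(total_executor_memory_mb(conf))
```

Calling build_spark_conf(1000, 1000, 1, 2, 1) reproduces the dictionary used in hello_spark's environment. With two executors at 1000M each, that configuration requests 2000M of executor heap in total.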
Key Configuration Attributes
The Spark class in flyteplugins.spark.task supports several configuration options:
- spark_conf: A dictionary of Spark configuration properties (e.g., memory, cores).
- hadoop_conf: A dictionary of Hadoop configuration properties.
- executor_path: Path to the Python binary used for PySpark execution (defaults to the current interpreter path).
- applications_path: Path to the main application file (defaults to the Flyte entrypoint).
- driver_pod / executor_pod: PodTemplate objects that customize the Kubernetes pods for the driver and executors.
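This sketch does not describe how the plugin forwards hadoop_conf. It only shows a related Spark convention: Spark copies any property whose name starts with spark.hadoop. into the Hadoop configuration. The helpers hadoop_to_spark_conf and merge_conf are hypothetical and show how the two dictionaries relate under that convention.

```python
def hadoop_to_spark_conf(hadoop_conf: dict[str, str]) -> dict[str, str]:
    """Prefix Hadoop keys so Spark treats them as Hadoop configuration."""
    return {f"spark.hadoop.{key}": value for key, value in hadoop_conf.items()}


def merge_conf(spark_conf: dict[str, str], hadoop_conf: dict[str, str]) -> dict[str, str]:
    """Combine both dicts; explicit spark_conf entries win on key collisions."""
    merged = hadoop_to_spark_conf(hadoop_conf)
    merged.update(spark_conf)
    return merged


print(merge_conf({"spark.executor.cores": "1"}, {"fs.s3a.connection.maximum": "100"}))
```

Keeping Hadoop settings in hadoop_conf, as the plugin allows, is usually clearer than hand-prefixing them. The prefixed form is mainly useful when reading plain Spark documentation or spark-submit examples.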
Working with Spark DataFrames
flyte-sdk supports pyspark.sql.DataFrame as a first-class type for task inputs and outputs. You can use Annotated to provide column schema information.
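```python
import pyspark
from typing import Annotated, Type, OrderedDict
import collections


def kwtypes(**kwargs) -> OrderedDict[str, Type]:
    # Build an ordered column-name -> Python-type mapping for the schema annotation
    d = collections.OrderedDict()
    for k, v in kwargs.items():
        d[k] = v
    return d


columns = kwtypes(name=str, age=int)


@spark_env.task
async def sum_of_all_ages(sd: Annotated[pyspark.sql.DataFrame, columns]) -> int:
    # The DataFrame is automatically loaded into the Spark context
    # groupBy() with no columns aggregates the whole DataFrame into a single row
    total_age = sd.groupBy().sum("age").collect()[0][0]
    # The sum is null (None) when the DataFrame has no rows
    return total_age if total_age is not None else 0
```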
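Keyword arguments keep their call order in Python 3.6 and later (PEP 468). Because of that, the kwtypes helper can be reduced to a single OrderedDict constructor call with the same result. The snippet below is a stand-alone restatement of that helper, not a flyte-sdk API.

```python
import collections
from typing import OrderedDict, Type


def kwtypes(**kwargs) -> OrderedDict[str, Type]:
    # kwargs already preserves call order, so the constructor keeps column order intact
    return collections.OrderedDict(kwargs)


columns = kwtypes(name=str, age=int)
print(list(columns.items()))
```

The column order in the annotation then matches the order in which you list the fields.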
Automatic Code Bundling
When running in a cluster, PysparkFunctionTask automatically bundles your local code into a zip file and adds it to the Spark context using sess.sparkContext.addPyFile(). This ensures that your local modules and dependencies are available on all Spark executors without manual configuration.
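The bundling step can be approximated with the standard library plus SparkContext.addPyFile, which accepts .py, .zip, or .egg files. This is a minimal sketch of the idea, not the plugin's own code. The helper names bundle_code and ship_to_executors and the archive name code_bundle are hypothetical. The spark argument is assumed to be an active SparkSession.

```python
import os
import shutil
import tempfile


def bundle_code(src_dir: str) -> str:
    """Zip the contents of src_dir into a fresh temporary directory and return the archive path."""
    base_dir = tempfile.mkdtemp()
    # make_archive appends the ".zip" suffix and returns the archive's path
    return shutil.make_archive(os.path.join(base_dir, "code_bundle"), "zip", root_dir=src_dir)


def ship_to_executors(spark, src_dir: str) -> str:
    """Register the bundle so Python workers on every executor can import from it."""
    archive = bundle_code(src_dir)
    spark.sparkContext.addPyFile(archive)
    return archive
```

Because the archive is added to the executors' Python path, modules at the top level of src_dir become importable inside functions that Spark runs remotely.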
Overriding Configuration at Runtime
You can override the Spark configuration for a specific task call using the .override() method. This is useful for scaling resources based on input size.
```python
from copy import deepcopy


@spark_env.task
async def dynamic_spark_task(instances: int) -> float:
    # Create a copy of the base config and modify it
    updated_conf = deepcopy(spark_conf)
    updated_conf.spark_conf["spark.executor.instances"] = str(instances)
    # Override the plugin configuration for this specific call
    return await hello_spark.override(plugin_config=updated_conf)(partitions=instances)
```
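The deepcopy call matters because spark_conf is a nested dictionary. A shallow copy would share that dictionary, so editing it would also change the base environment's configuration. The sketch below combines a deep copy with a size-based rule for picking the executor count. SparkConfig is a hypothetical stand-in that holds only a spark_conf field. The thresholds in instances_for (one executor per million rows, at most 20) are illustrative assumptions, not recommendations.

```python
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class SparkConfig:
    # Hypothetical stand-in holding only the spark_conf field of the plugin's Spark object
    spark_conf: dict[str, str] = field(default_factory=dict)


def instances_for(num_rows: int, rows_per_executor: int = 1_000_000, max_instances: int = 20) -> int:
    # Ceiling division: one executor per rows_per_executor rows, at least 1, capped at max_instances
    needed = -(-num_rows // rows_per_executor)
    return max(1, min(needed, max_instances))


def scaled_config(base: SparkConfig, num_rows: int) -> SparkConfig:
    updated = deepcopy(base)  # gets its own spark_conf dict; base is left untouched
    updated.spark_conf["spark.executor.instances"] = str(instances_for(num_rows))
    return updated


base = SparkConfig({"spark.executor.instances": "2"})
scaled = scaled_config(base, 4_500_000)
print(scaled.spark_conf, base.spark_conf)
```

For 4.5 million rows the rule asks for 5 executors, while the base configuration still requests 2.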
Troubleshooting and Session Management
Managed SparkSession
The PysparkFunctionTask plugin creates a SparkSession named "FlyteSpark" during the pre-execution hook.
- Do not create your own SparkSession: Use flyte.ctx().data["spark_session"] to ensure you are using the session managed by the plugin.
- Session Lifecycle: In standard execution, the Spark operator manages the session lifecycle. In debug mode (when the action name is "a0"), the plugin explicitly stops the session in the post-execution hook to clean up resources.
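The lifecycle described above can be modeled as a context manager. Setup plays the role of the pre-execution hook and teardown plays the role of the post-execution hook. This is a conceptual model only. FakeSession and managed_session are hypothetical names, and the real plugin works with an actual SparkSession rather than this stand-in.

```python
from contextlib import contextmanager


class FakeSession:
    """Hypothetical stand-in for a SparkSession; records whether stop() ran."""

    def __init__(self, app_name: str):
        self.app_name = app_name
        self.stopped = False

    def stop(self) -> None:
        self.stopped = True


@contextmanager
def managed_session(action_name: str):
    # Pre-execution step: create the session and expose it under the documented key
    session = FakeSession("FlyteSpark")
    try:
        yield {"spark_session": session}
    finally:
        # Post-execution step: only the debug run ("a0") stops the session explicitly
        if action_name == "a0":
            session.stop()


with managed_session("a0") as data:
    session = data["spark_session"]
print(session.stopped)
```

In every other run the model leaves the session alone, matching the case where the Spark operator owns cleanup.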
Local Execution
When running locally, flyte-sdk initializes a local Spark session. Ensure pyspark is installed in your local environment to test Spark tasks without a Kubernetes cluster.
```python
if __name__ == "__main__":
    flyte.init_from_config()
    # This will run locally using a local Spark session
    result = flyte.run(hello_spark, partitions=2)
    result.wait()
```