
Configuring Spark Tasks

To execute PySpark tasks natively on Kubernetes using flyte-sdk, you define a Spark configuration and attach it to a TaskEnvironment. The flyte-sdk Spark plugin manages the SparkSession lifecycle and automatically distributes your local code to the Spark executors.

Defining a Spark Task

To create a Spark task, define a Spark configuration object, pass it as the plugin_config of a TaskEnvironment, and decorate your function with that environment's task decorator.

from flyteplugins.spark.task import Spark
import flyte

# Configure Spark settings
spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "1000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
    },
)

# Create an environment that uses the Spark plugin
spark_env = flyte.TaskEnvironment(
    name="spark_env",
    plugin_config=spark_conf,
    image=flyte.Image.from_base("apache/spark-py:v3.4.0"),
)

@spark_env.task
async def hello_spark(partitions: int = 3) -> float:
    # Access the managed SparkSession from the Flyte context
    spark = flyte.ctx().data["spark_session"]

    # Use the session for RDD or DataFrame operations
    n = 100 * partitions
    count = spark.sparkContext.parallelize(range(1, n + 1), partitions).count()
    return float(count)

Key Configuration Attributes

The Spark class in flyteplugins.spark.task supports several configuration options:

  • spark_conf: A dictionary of Spark configuration properties (e.g., memory, cores).
  • hadoop_conf: A dictionary of Hadoop configuration properties.
  • executor_path: Path to the Python binary for PySpark execution (defaults to the current interpreter path).
  • applications_path: Path to the main application file (defaults to the Flyte entrypoint).
  • driver_pod / executor_pod: PodTemplate objects to customize the Kubernetes pods for the driver and executors.
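
As a rough illustration of how a spark_conf dictionary might be assembled, the sketch below builds one from numeric resource settings. The helper name build_spark_conf and its parameters are hypothetical and not part of flyte-sdk; the only thing carried over from the example above is that Spark property values are written as strings.

```python
def build_spark_conf(
    executor_instances: int,
    executor_memory_mb: int = 1000,
    executor_cores: int = 1,
    driver_memory_mb: int = 1000,
    driver_cores: int = 1,
) -> dict[str, str]:
    """Hypothetical helper: assemble Spark properties as string values."""
    if executor_instances < 1:
        raise ValueError("executor_instances must be at least 1")
    return {
        "spark.driver.memory": f"{driver_memory_mb}M",
        "spark.driver.cores": str(driver_cores),
        "spark.executor.memory": f"{executor_memory_mb}M",
        "spark.executor.cores": str(executor_cores),
        "spark.executor.instances": str(executor_instances),
    }


# Reproduces the property values used in the configuration shown earlier.
conf = build_spark_conf(executor_instances=2)
print(conf["spark.executor.instances"])
```

A dictionary built this way could then be passed as the spark_conf argument of Spark.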

Working with Spark DataFrames

flyte-sdk supports pyspark.sql.DataFrame as a first-class type for task inputs and outputs. You can use Annotated to provide column schema information.

import pyspark
from typing import Annotated, Type, OrderedDict
import collections

def kwtypes(**kwargs) -> OrderedDict[str, Type]:
    d = collections.OrderedDict()
    for k, v in kwargs.items():
        d[k] = v
    return d

columns = kwtypes(name=str, age=int)

@spark_env.task
async def sum_of_all_ages(sd: Annotated[pyspark.sql.DataFrame, columns]) -> int:
    # The DataFrame is automatically loaded into the Spark context
    total_age = sd.groupBy().sum("age").collect()[0][0]
    return total_age
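
To make the role of Annotated more concrete, the sketch below shows how column metadata attached this way can be read back using only the standard typing module. It is an illustration of standard Python behavior, not flyte-sdk's actual code path; FakeDataFrame is a hypothetical stand-in so the snippet runs without pyspark installed.

```python
import collections
from typing import Annotated, get_args, get_type_hints


class FakeDataFrame:
    """Hypothetical stand-in for pyspark.sql.DataFrame."""


# Ordered mapping of column name to Python type, as in the example above.
columns = collections.OrderedDict(name=str, age=int)


def sum_of_all_ages(sd: Annotated[FakeDataFrame, columns]) -> int:
    ...


# include_extras=True keeps the Annotated metadata instead of stripping it.
hints = get_type_hints(sum_of_all_ages, include_extras=True)
base_type, schema = get_args(hints["sd"])

column_names = list(schema)
print(base_type.__name__, column_names)
```

The metadata object recovered here is the same columns mapping that was attached, so the declared column order is preserved.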

Automatic Code Bundling

When running in a cluster, PysparkFunctionTask automatically bundles your local code into a zip file and adds it to the Spark context using sess.sparkContext.addPyFile(). This ensures that your local modules and dependencies are available on all Spark executors without manual configuration.
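
The bundling step can be pictured with a small standard-library sketch. This is not the plugin's implementation; the function name bundle_python_sources and the throwaway project layout are assumptions made for illustration. It only shows the general idea of zipping local Python sources into a single archive.

```python
import tempfile
import zipfile
from pathlib import Path


def bundle_python_sources(source_dir: Path, archive_path: Path) -> Path:
    """Zip every .py file under source_dir, preserving relative paths."""
    with zipfile.ZipFile(archive_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for py_file in sorted(source_dir.rglob("*.py")):
            zf.write(py_file, arcname=py_file.relative_to(source_dir).as_posix())
    return archive_path


# Build a tiny throwaway project so the sketch runs anywhere.
with tempfile.TemporaryDirectory() as tmp:
    project = Path(tmp) / "project"
    (project / "helpers").mkdir(parents=True)
    (project / "main.py").write_text("print('hello')\n")
    (project / "helpers" / "__init__.py").write_text("")
    (project / "notes.txt").write_text("not bundled\n")

    archive = bundle_python_sources(project, Path(tmp) / "code.zip")
    with zipfile.ZipFile(archive) as zf:
        bundled_names = sorted(zf.namelist())

print(bundled_names)
```

In a real Spark driver, an archive like this is the kind of file that SparkContext.addPyFile() accepts (it takes .py or .zip dependencies).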

Overriding Configuration at Runtime

You can override the Spark configuration for a specific task call using the .override() method. This is useful for scaling resources based on input size.

from copy import deepcopy

@spark_env.task
async def dynamic_spark_task(instances: int) -> float:
    # Create a copy of the base config and modify it
    updated_conf = deepcopy(spark_conf)
    updated_conf.spark_conf["spark.executor.instances"] = str(instances)

    # Override the plugin configuration for this specific call
    return await hello_spark.override(plugin_config=updated_conf)(partitions=instances)
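
The reason for using deepcopy rather than a shallow copy can be seen in a standalone sketch. FakeSparkConfig is a hypothetical stand-in for a configuration object that holds a spark_conf dictionary; it is not the real Spark class.

```python
from copy import copy, deepcopy
from dataclasses import dataclass, field


@dataclass
class FakeSparkConfig:
    """Hypothetical stand-in for a config object holding a spark_conf dict."""
    spark_conf: dict = field(default_factory=dict)


base = FakeSparkConfig(spark_conf={"spark.executor.instances": "2"})

# A shallow copy shares the inner dict, so edits would leak into the base config.
shallow = copy(base)
shares_dict = shallow.spark_conf is base.spark_conf

# A deep copy gets its own dict, so the base config stays untouched.
scaled = deepcopy(base)
scaled.spark_conf["spark.executor.instances"] = "8"

print(shares_dict)
print(base.spark_conf["spark.executor.instances"])
print(scaled.spark_conf["spark.executor.instances"])
```

Because the environment-level configuration is shared by every task in that environment, mutating it in place would affect later calls as well; copying first keeps the override local to one call.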

Troubleshooting and Session Management

Managed SparkSession

The PysparkFunctionTask plugin creates a SparkSession named "FlyteSpark" during the pre-execution hook.

  • Do not create your own SparkSession: Use flyte.ctx().data["spark_session"] to ensure you are using the session managed by the plugin.
  • Session Lifecycle: In standard execution, the Spark operator manages the session lifecycle. In debug mode (when the action name is "a0"), the plugin explicitly stops the session in the post hook to clean up resources.
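
The lifecycle described above can be modeled with a toy sketch. Everything here is hypothetical (FakeSession, run_with_hooks, and the context dictionary are illustrative stand-ins, not flyte-sdk code); it only mirrors the stated behavior that the session is created before the task runs and is stopped explicitly only when the action name is "a0".

```python
class FakeSession:
    """Hypothetical stand-in for a SparkSession; tracks whether stop() ran."""

    def __init__(self, app_name: str) -> None:
        self.app_name = app_name
        self.stopped = False

    def stop(self) -> None:
        self.stopped = True


DEBUG_ACTION_NAME = "a0"  # debug-mode action name mentioned above


def run_with_hooks(action_name: str, task_fn):
    # Pre hook: create the session and expose it through a context dict.
    session = FakeSession("FlyteSpark")
    context_data = {"spark_session": session}
    try:
        result = task_fn(context_data)
    finally:
        # Post hook: stop the session explicitly only in debug mode.
        if action_name == DEBUG_ACTION_NAME:
            session.stop()
    return result, session


def read_app_name(ctx):
    return ctx["spark_session"].app_name


debug_name, debug_session = run_with_hooks("a0", read_app_name)
normal_name, normal_session = run_with_hooks("some-other-action", read_app_name)
print(debug_session.stopped, normal_session.stopped)
```

In this model the task body never constructs a session itself; it only reads the one placed in the context, which is the pattern the first bullet recommends.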

Local Execution

When running locally, flyte-sdk initializes a local Spark session. Ensure pyspark is installed in your local environment to test Spark tasks without a Kubernetes cluster.

if __name__ == "__main__":
    flyte.init_from_config()
    # This will run locally using a local Spark session
    result = flyte.run(hello_spark, partitions=2)
    result.wait()