Custom Task Templates
Custom task templates in flyte-sdk allow you to extend the platform by defining how tasks are configured, serialized, and executed. Whether you are building a connector to an external service like BigQuery or creating a specialized execution environment like Ray or Spark, you use task templates to bridge the gap between Python code and the Flyte backend.
Task Template Fundamentals
The TaskTemplate class in flyte.extend is the base abstraction for all tasks. It encapsulates metadata such as the task name, interface (inputs and outputs), resource requirements, and caching policies.
When you define a custom task template, you typically interact with these key methods:
- custom_config(sctx): Returns a dictionary that is serialized into the task's custom configuration. This is used by backend plugins to understand task-specific settings.
- execute(*args, **kwargs): Defines the logic for running the task. For connector-style tasks, this often involves calling an external API.
- forward(*args, **kwargs): Defines how the task behaves during local execution (when not running on a Flyte cluster).
Implementing Connector Tasks
Connector tasks are used for operations that do not require a user-defined Python function to run on the cluster, such as executing a SQL query or triggering a batch job in another system.
To implement a connector task, inherit from TaskTemplate. You can also use the AsyncConnectorExecutorMixin from flyte.connectors to simplify local execution by routing it through a registered AsyncConnector.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type

from flyte.extend import TaskTemplate
from flyte.models import NativeInterface, SerializationContext
from flyte.connectors import AsyncConnectorExecutorMixin


@dataclass
class BatchJobConfig:
    timeout_seconds: int = 300


class BatchJobTask(AsyncConnectorExecutorMixin, TaskTemplate):
    _TASK_TYPE = "batch_job"

    def __init__(
        self,
        name: str,
        plugin_config: BatchJobConfig,
        inputs: Optional[Dict[str, Type]] = None,
        outputs: Optional[Dict[str, Type]] = None,
        **kwargs,
    ):
        super().__init__(
            name=name,
            interface=NativeInterface(
                {k: (v, None) for k, v in inputs.items()} if inputs else {},
                outputs or {},
            ),
            task_type=self._TASK_TYPE,
            image=None,  # Connectors often don't need a custom image
            **kwargs,
        )
        self.plugin_config = plugin_config

    def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
        # This dictionary is passed to the backend plugin
        return {"timeout_seconds": self.plugin_config.timeout_seconds}
In this example, BatchJobTask defines a custom task type "batch_job". When this task is executed, the Flyte backend uses the custom_config to configure the execution, while AsyncConnectorExecutorMixin handles local testing by looking up a corresponding connector in the ConnectorRegistry.
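The lookup described above can be pictured with a small, library-free sketch. Every name in it (ToyConnector, ToyConnectorRegistry, ToyExecutorMixin, ToyBatchJobTask) is a hypothetical stand-in written for illustration; it is not the flyte.connectors API and only assumes that a registry maps a task type string to the connector that handles it.

```python
import asyncio
from typing import Any, Dict


class ToyConnector:
    """Hypothetical stand-in for a connector; not the flyte.connectors API."""

    task_type = "batch_job"

    async def run(self, custom: Dict[str, Any], **inputs: Any) -> Dict[str, Any]:
        # A real connector would call an external service here.
        return {
            "status": "done",
            "timeout_seconds": custom["timeout_seconds"],
            "inputs": inputs,
        }


class ToyConnectorRegistry:
    """Maps a task type string to the connector registered for it."""

    _connectors: Dict[str, ToyConnector] = {}

    @classmethod
    def register(cls, connector: ToyConnector) -> None:
        cls._connectors[connector.task_type] = connector

    @classmethod
    def get(cls, task_type: str) -> ToyConnector:
        try:
            return cls._connectors[task_type]
        except KeyError:
            raise LookupError(f"no connector registered for {task_type!r}") from None


class ToyExecutorMixin:
    """Routes local execution through the connector matching the task type."""

    async def execute(self, **inputs: Any) -> Dict[str, Any]:
        connector = ToyConnectorRegistry.get(self.task_type)
        return await connector.run(self.custom_config(), **inputs)


class ToyBatchJobTask(ToyExecutorMixin):
    task_type = "batch_job"

    def __init__(self, timeout_seconds: int = 300):
        self.timeout_seconds = timeout_seconds

    def custom_config(self) -> Dict[str, Any]:
        return {"timeout_seconds": self.timeout_seconds}


ToyConnectorRegistry.register(ToyConnector())
result = asyncio.run(ToyBatchJobTask(timeout_seconds=60).execute(job_id="abc"))
print(result)
```

The task class itself carries no execution logic in this sketch: the mixin resolves the connector at call time, so the same task definition works with whichever connector is registered for its type.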
Implementing Function-Based Plugins
If you want to wrap standard Python functions but provide specialized backend configuration (e.g., running a function on a Ray cluster), inherit from AsyncFunctionTaskTemplate.
This class handles the complexities of serializing function arguments, resolving code bundles, and managing execution contexts. You register these templates using the TaskPluginRegistry.
from dataclasses import dataclass
from typing import Any

from flyte.extend import AsyncFunctionTaskTemplate, TaskPluginRegistry
from flyte.models import SerializationContext


@dataclass
class EchoConfig:
    """Configuration for the echo plugin."""

    verbose: bool = False


@dataclass(kw_only=True)
class EchoTask(AsyncFunctionTaskTemplate):
    plugin_config: EchoConfig
    task_type: str = "echo"

    def custom_config(self, sctx: SerializationContext) -> dict[str, Any]:
        # Pass plugin-specific settings to the backend
        return {"verbose": self.plugin_config.verbose}


# Register the config type so the @task decorator knows to use EchoTask
TaskPluginRegistry.register(EchoConfig, EchoTask)
Once registered, any task decorated with @task that receives an EchoConfig object in its configuration will be instantiated as an EchoTask.
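As a rough illustration of that dispatch, the sketch below models a registry keyed by config type using only the standard library. ToyPluginRegistry, ToyFunctionTask, toy_task, ToyEchoConfig, and ToyEchoTask are hypothetical names invented for this example; the real TaskPluginRegistry and @task decorator may resolve plugins differently.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type


class ToyFunctionTask:
    """Default wrapper used when no plugin config type is registered."""

    task_type = "python"

    def __init__(self, func: Callable[..., Any], plugin_config: Any = None):
        self.func = func
        self.plugin_config = plugin_config

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.func(*args, **kwargs)


class ToyPluginRegistry:
    """Maps a config class to the task class that should wrap the function."""

    _plugins: Dict[type, Type[ToyFunctionTask]] = {}

    @classmethod
    def register(cls, config_type: type, task_cls: Type[ToyFunctionTask]) -> None:
        cls._plugins[config_type] = task_cls

    @classmethod
    def find(cls, config_type: type) -> Type[ToyFunctionTask]:
        # Fall back to the default wrapper for unknown config types.
        return cls._plugins.get(config_type, ToyFunctionTask)


def toy_task(plugin_config: Any = None):
    """Hypothetical decorator: picks the task class from the config's type."""

    def decorator(func: Callable[..., Any]) -> ToyFunctionTask:
        task_cls = ToyPluginRegistry.find(type(plugin_config))
        return task_cls(func, plugin_config)

    return decorator


@dataclass
class ToyEchoConfig:
    verbose: bool = False


class ToyEchoTask(ToyFunctionTask):
    task_type = "echo"


ToyPluginRegistry.register(ToyEchoConfig, ToyEchoTask)


@toy_task(plugin_config=ToyEchoConfig(verbose=True))
def greet(name: str) -> str:
    return f"hello {name}"


@toy_task()
def plain(x: int) -> int:
    return x + 1


print(type(greet).__name__, type(plain).__name__)
```

Keying the registry on the config class keeps user code declarative: choosing a backend only requires passing a different config object, not a different decorator.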
Dynamic Task Overrides
The TaskTemplate.override() method allows you to modify task parameters at the call-site. This is useful for adjusting resources or timeouts dynamically based on workflow logic.
@task(requests=Resources(cpu="1", memory="1Gi"))
def my_task(x: int) -> int:
    return x + 1


# Override resources at the call-site
custom_task = my_task.override(resources=Resources(cpu="2", memory="2Gi"))
Override Constraints
The flyte-sdk enforces several rules when using override():
- Immutable Fields: You cannot override name, image, docs, or interface. Attempting to do so raises a ValueError.
- Reusability Conflicts: If a task is marked as reusable, you cannot override resources, env_vars, or secrets. Reusable tasks are designed to run in a shared environment and must inherit these configurations from their parent environment. To override these, you must first disable reusability by passing reusable="off".
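The two rules can be summarized in a minimal, standard-library sketch. ToyTask and its fields are hypothetical, the string "on" is an invented marker for a reusable task, and raising ValueError for the reusability conflict is an assumption made here (the text above only states the exception type for immutable fields).

```python
from dataclasses import dataclass, field, replace
from typing import Any, Dict, Optional, Tuple

IMMUTABLE_FIELDS = {"name", "image", "docs", "interface"}
REUSABLE_LOCKED_FIELDS = {"resources", "env_vars", "secrets"}


@dataclass(frozen=True)
class ToyTask:
    """Hypothetical task model; not the flyte-sdk TaskTemplate."""

    name: str
    resources: Dict[str, str] = field(default_factory=dict)
    env_vars: Dict[str, str] = field(default_factory=dict)
    secrets: Tuple[str, ...] = ()
    reusable: Optional[str] = None  # None means "not reusable" in this sketch

    def override(self, **changes: Any) -> "ToyTask":
        # Rule 1: identity-defining fields can never be overridden.
        blocked = sorted(IMMUTABLE_FIELDS.intersection(changes))
        if blocked:
            raise ValueError(f"cannot override immutable fields: {blocked}")

        # Passing reusable="off" disables reuse for the overridden copy.
        if changes.get("reusable") == "off":
            changes["reusable"] = None

        # Rule 2: a task that stays reusable must keep its environment settings.
        still_reusable = changes.get("reusable", self.reusable) is not None
        locked = sorted(REUSABLE_LOCKED_FIELDS.intersection(changes))
        if still_reusable and locked:
            raise ValueError(
                f"cannot override {locked} on a reusable task; pass reusable='off' first"
            )

        # Return a modified copy; the original task is left untouched.
        return replace(self, **changes)


plain = ToyTask(name="my_task", resources={"cpu": "1"})
bigger = plain.override(resources={"cpu": "2"})

shared = ToyTask(name="shared_task", reusable="on")
detached = shared.override(reusable="off", resources={"cpu": "4"})
print(bigger.resources, detached.reusable, detached.resources)
```

Returning a copy rather than mutating in place means an override at one call-site cannot leak into other uses of the same task in this sketch.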
Internal Execution Flow
When a task is invoked, TaskTemplate.__call__ determines the execution context:
- Task Context: If running inside a Flyte execution (e.g., a workflow), the task is submitted to the internal controller via controller.submit().
- Local Context: If running as a standalone Python script, it calls forward().
For AsyncFunctionTaskTemplate, the execute() method manages the lifecycle:
- It calls pre() to set up the environment.
- It executes the function. If the function is synchronous, it uses run_sync_with_loop to ensure it runs correctly within the asynchronous Flyte runtime.
- It calls post() with the return values before returning the result.
This structure ensures that custom logic in pre and post hooks is consistently applied regardless of whether the underlying user function is sync or async.
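A compact way to picture this lifecycle is the sketch below. ToyFunctionTask is a hypothetical class, and asyncio.to_thread is used purely as a stand-in for run_sync_with_loop, whose actual behavior is not described here; the events list exists only to make the ordering observable.

```python
import asyncio
import inspect
from typing import Any, Callable, List


class ToyFunctionTask:
    """Hypothetical model of the pre / run / post lifecycle."""

    def __init__(self, func: Callable[..., Any]):
        self.func = func
        self.events: List[str] = []

    async def pre(self, *args: Any, **kwargs: Any) -> None:
        # Set up whatever the function needs before it runs.
        self.events.append("pre")

    async def post(self, result: Any) -> None:
        # Observe the return value after the function has finished.
        self.events.append("post")

    async def execute(self, *args: Any, **kwargs: Any) -> Any:
        await self.pre(*args, **kwargs)
        if inspect.iscoroutinefunction(self.func):
            result = await self.func(*args, **kwargs)
        else:
            # Stand-in for run_sync_with_loop: run the sync function off the
            # event loop thread so it does not block other coroutines.
            result = await asyncio.to_thread(self.func, *args, **kwargs)
        self.events.append("run")
        await self.post(result)
        return result


def sync_double(x: int) -> int:
    return x * 2


async def async_double(x: int) -> int:
    await asyncio.sleep(0)
    return x * 2


sync_task = ToyFunctionTask(sync_double)
async_task = ToyFunctionTask(async_double)

sync_result = asyncio.run(sync_task.execute(21))
async_result = asyncio.run(async_task.execute(21))
print(sync_result, async_result, sync_task.events, async_task.events)
```

Both tasks record the same event order in this sketch, which is the property the hooks are meant to guarantee: the caller never needs to know whether the wrapped function was sync or async.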