Extending DataFrame Support
To add support for a new dataframe library in flyte-sdk, you must implement specialized encoders and decoders and register them with the DataFrameTransformerEngine. This allows the Flyte type engine to automatically handle your custom dataframe types in task inputs and outputs.
Implement a Custom Encoder
A DataFrameEncoder is responsible for taking a Python dataframe instance and persisting it to storage (like S3 or GCS) as a Flyte StructuredDataset.
In the following example from the Polars plugin, the encoder handles polars.DataFrame by writing it to a Parquet file:
import os
import typing
from pathlib import Path

import polars as pl

import flyte.storage as storage
from flyte.io._dataframe.dataframe import PARQUET, DataFrame
from flyte.io.extend import DataFrameEncoder
from flyteidl2.core import literals_pb2, types_pb2
class PolarsToParquetEncodingHandler(DataFrameEncoder):
    def __init__(self):
        # Register for pl.DataFrame, all protocols (None), and PARQUET format
        super().__init__(pl.DataFrame, None, PARQUET)

    async def encode(
        self,
        dataframe: DataFrame,
        structured_dataset_type: types_pb2.StructuredDatasetType,
    ) -> literals_pb2.StructuredDataset:
        # 1. Determine the destination URI
        if not dataframe.uri:
            from flyte._context import internal_ctx

            ctx = internal_ctx()
            uri = str(ctx.raw_data.get_random_remote_path())
        else:
            uri = typing.cast(str, dataframe.uri)

        # 2. Prepare the local or remote path
        if not storage.is_remote(uri):
            Path(uri).mkdir(parents=True, exist_ok=True)
        path = os.path.join(uri, f"{0:05}")

        # 3. Persist the dataframe using the library's native methods
        df = typing.cast(pl.DataFrame, dataframe.val)
        df.write_parquet(path)

        # 4. Return the StructuredDataset literal
        structured_dataset_type.format = PARQUET
        return literals_pb2.StructuredDataset(
            uri=uri,
            metadata=literals_pb2.StructuredDatasetMetadata(
                structured_dataset_type=structured_dataset_type
            ),
        )
Implement a Custom Decoder
A DataFrameDecoder performs the inverse operation: it reads data from a URI and returns a native Python dataframe instance.
from flyte.io.extend import DataFrameDecoder


class ParquetToPolarsDecodingHandler(DataFrameDecoder):
    def __init__(self):
        super().__init__(pl.DataFrame, None, PARQUET)

    async def decode(
        self,
        flyte_value: literals_pb2.StructuredDataset,
        current_task_metadata: literals_pb2.StructuredDatasetMetadata,
    ) -> pl.DataFrame:
        uri = flyte_value.uri

        # Handle column subsetting if requested in the task signature
        columns = None
        if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns:
            columns = [c.name for c in current_task_metadata.structured_dataset_type.columns]

        parquet_path = os.path.join(uri, f"{0:05}")

        # Return the native dataframe type
        return pl.read_parquet(parquet_path, columns=columns)
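The encoder and decoder above only work together because both build the same file name inside the dataset directory: f"{0:05}" evaluates to "00000". The sketch below isolates that convention so it can be checked on its own. The helper name part_path and the example URIs are hypothetical and are not part of flyte-sdk; posixpath is used instead of os.path so the result does not depend on the host platform.

```python
import posixpath


def part_path(dataset_uri: str, index: int = 0) -> str:
    """Return the path of one part file inside a dataset directory.

    Hypothetical helper: it mirrors os.path.join(uri, f"{0:05}") from the
    handlers above, zero-padding the part index to five digits.
    """
    return posixpath.join(dataset_uri, f"{index:05}")


# Example URIs are made up for illustration.
local_part = part_path("/tmp/my_dataset")
remote_part = part_path("s3://my-bucket/my_dataset")
print(local_part)
print(remote_part)
```

If the encoder and decoder ever disagree on this name, the decoder will look for a file the encoder never wrote, so keeping the convention in one shared helper is one reasonable design choice.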
Register Handlers with the Engine
Once implemented, you must register your handlers with the DataFrameTransformerEngine. This is typically done at module import time or within a plugin initialization function.
from flyte.io.extend import DataFrameTransformerEngine

# Register the encoder and decoder
DataFrameTransformerEngine.register(
    PolarsToParquetEncodingHandler(),
    default_format_for_type=True,
)
DataFrameTransformerEngine.register(
    ParquetToPolarsDecodingHandler(),
    default_format_for_type=True,
)
The register method accepts several configuration flags:
- default_for_type: If True, this handler becomes the default for the Python type across all protocols and formats.
- default_format_for_type: If True, this handler's format (e.g., "parquet") becomes the default for the Python type.
- override: If True, this registration will replace any existing handler for the same type/protocol/format combination.
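To make the interaction of these flags concrete, here is a toy registry in plain Python. It is a simplified model, not the DataFrameTransformerEngine implementation: the ToyRegistry class, its attributes, and the explicit python_type/protocol/fmt arguments are all hypothetical. In particular, raising ValueError on a duplicate registration without override is an assumption made for the sketch; the text above only states what override=True does.

```python
from typing import Dict, Tuple

Key = Tuple[type, str, str]  # (python type, protocol, format)


class ToyRegistry:
    """Simplified stand-in for a handler registry; not the flyte-sdk engine."""

    def __init__(self) -> None:
        self.handlers: Dict[Key, str] = {}
        self.default_handler: Dict[type, str] = {}
        self.default_format: Dict[type, str] = {}

    def register(self, handler: str, python_type: type, protocol: str, fmt: str,
                 default_for_type: bool = False,
                 default_format_for_type: bool = False,
                 override: bool = False) -> None:
        key = (python_type, protocol, fmt)
        # Assumption: a duplicate registration without override is rejected.
        if key in self.handlers and not override:
            raise ValueError(f"handler already registered for {key}")
        self.handlers[key] = handler
        if default_for_type:
            self.default_handler[python_type] = handler
        if default_format_for_type:
            self.default_format[python_type] = fmt


class Frame:
    """Placeholder for a dataframe type."""


registry = ToyRegistry()
registry.register("first", Frame, "s3", "parquet", default_format_for_type=True)

duplicate_rejected = False
try:
    registry.register("second", Frame, "s3", "parquet")
except ValueError:
    duplicate_rejected = True

# With override=True the second handler replaces the first.
registry.register("second", Frame, "s3", "parquet", override=True)
```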
Protocol and Format Matching
The DataFrameTransformerEngine uses a hierarchical lookup to find the correct handler:
- Exact Match: It first looks for a handler matching the exact (Type, Protocol, Format).
- Protocol Fallback: If no exact match exists, it looks for a handler registered with the fsspec protocol for that format.
- Generic Format Fallback: If still not found, it looks for handlers with an empty string ("") as the format, which signifies a generic handler.
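The lookup order above can be sketched with a plain dictionary keyed by (type, protocol, format). This is an illustration of the described fallback chain, not the engine's actual code. The find_handler function, the string handler names, and the Frame class are hypothetical. Two details are assumptions: the generic format is tried first with the requested protocol and then with fsspec, and None is returned when nothing matches.

```python
from typing import Dict, Optional, Tuple

Key = Tuple[type, str, str]  # (python type, protocol, format)
FSSPEC = "fsspec"
GENERIC_FORMAT = ""


def find_handler(registry: Dict[Key, str], python_type: type,
                 protocol: str, fmt: str) -> Optional[str]:
    """Return the first handler found along the fallback chain, else None."""
    candidates = [
        (python_type, protocol, fmt),             # 1. exact match
        (python_type, FSSPEC, fmt),               # 2. protocol fallback
        (python_type, protocol, GENERIC_FORMAT),  # 3. generic format (assumed order)
        (python_type, FSSPEC, GENERIC_FORMAT),
    ]
    for key in candidates:
        if key in registry:
            return registry[key]
    return None


class Frame:
    """Placeholder for a dataframe type."""


registry = {
    (Frame, "s3", "parquet"): "s3-parquet-handler",
    (Frame, FSSPEC, "parquet"): "fsspec-parquet-handler",
    (Frame, FSSPEC, GENERIC_FORMAT): "generic-handler",
}

exact = find_handler(registry, Frame, "s3", "parquet")
protocol_fallback = find_handler(registry, Frame, "gs", "parquet")
generic_fallback = find_handler(registry, Frame, "gs", "csv")
```

Here a request for ("gs", "parquet") skips the missing exact entry and lands on the fsspec handler, while ("gs", "csv") falls all the way through to the generic one.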
Troubleshooting URI Generation
When implementing encode, if the incoming dataframe.uri is unset (None or empty), you must generate a unique path. flyte-sdk provides a helper via the internal context:
from flyte._context import internal_ctx
ctx = internal_ctx()
# Generates a path like s3://my-bucket/flyte/metadata/raw/...
uri = str(ctx.raw_data.get_random_remote_path())
Failure to provide a unique URI can result in data being overwritten if multiple tasks or multiple outputs from the same task use the same default location.
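The overwrite risk is easy to see outside Flyte. The following standalone sketch does not use or reproduce get_random_remote_path; it only illustrates the idea of deriving a fresh path per output from a shared prefix. The unique_output_uri helper and the bucket prefix are hypothetical, and uuid4 is simply one common way to obtain a collision-resistant name.

```python
import posixpath
import uuid


def unique_output_uri(base_prefix: str) -> str:
    """Append a random 32-character hex segment to a shared prefix."""
    return posixpath.join(base_prefix, uuid.uuid4().hex)


# Two outputs written under the same prefix get distinct directories,
# so neither can overwrite the other's "00000" part file.
first = unique_output_uri("s3://my-bucket/raw")
second = unique_output_uri("s3://my-bucket/raw")
print(first)
print(second)
```

Had both outputs used the bare prefix instead, the second write would have replaced the first.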