Checkpointing & State Recovery
Checkpointing allows a Flyte task to save its progress and resume from that state if the task is interrupted or retried. This is particularly useful for long-running training jobs or iterative data processing where you don't want to restart from scratch after a transient failure.
In flyte-sdk, the Checkpoint helper manages a local temporary workspace, handles the download and extraction of previous state, and automates the upload of new state to remote storage.
Accessing the Checkpoint Helper
You access the checkpoint helper through the task context using flyte.ctx().checkpoint. This object is an instance of flyte.Checkpoint and is automatically configured with the remote paths provided by the Flyte platform.
```python
import flyte


@flyte.task(retries=3)
def my_resilient_task():
    checkpoint = flyte.ctx().checkpoint
    # ...
```
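The task body only needs an object with load_sync() and save_sync(). That means one way to unit-test resume logic outside Flyte is to pass the checkpoint in as a parameter instead of reading flyte.ctx() inside the logic. The sketch below assumes that narrow interface. CheckpointLike, InMemoryCheckpoint and count_up are hypothetical names, not part of flyte-sdk, and the fake only models bytes checkpoints restored to a payload file.

```python
import pathlib
import tempfile
from typing import List, Optional, Protocol


class CheckpointLike(Protocol):
    """Hypothetical: the subset of the checkpoint surface used on this page (sync, bytes only)."""

    def load_sync(self) -> Optional[pathlib.Path]: ...

    def save_sync(self, state: bytes) -> None: ...


class InMemoryCheckpoint:
    """Hypothetical test double: keeps the last saved bytes in memory and
    restores them to <tempdir>/payload, mirroring the layout described on this page."""

    def __init__(self) -> None:
        self._saved: Optional[bytes] = None
        self._workspace = pathlib.Path(tempfile.mkdtemp())

    def load_sync(self) -> Optional[pathlib.Path]:
        if self._saved is None:
            return None  # nothing saved yet: behaves like a first attempt
        (self._workspace / "payload").write_bytes(self._saved)
        return self._workspace

    def save_sync(self, state: bytes) -> None:
        self._saved = state


def count_up(checkpoint: CheckpointLike, n: int) -> List[int]:
    """Task body that depends only on the checkpoint object, not on flyte.ctx()."""
    prev = checkpoint.load_sync()
    start = int((prev / "payload").read_text()) if prev is not None else 0
    done: List[int] = []
    for i in range(start, n):
        done.append(i)                       # stand-in for real work
        checkpoint.save_sync(str(i + 1).encode())
    return done
```

Inside a real task you would call count_up(flyte.ctx().checkpoint, n), assuming the real object accepts bytes in save_sync() as shown on this page. The fake never deletes its temporary directory. That is fine for a short test, but not for long-lived code.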
Restoring Previous State
At the beginning of your task, you should attempt to load any state saved by a previous attempt. The load_sync() method (or load() for async tasks) returns a pathlib.Path to the restored data if it exists, or None if this is the first attempt or no checkpoint was found.
```python
import pathlib

import flyte


@flyte.task(retries=3)
def process_data():
    checkpoint = flyte.ctx().checkpoint

    # load_sync() returns a Path to the restored state, or None
    prev_path: pathlib.Path | None = checkpoint.load_sync()

    start_index = 0
    if prev_path:
        # If we saved bytes or a single file, it appears at prev_path / "payload".
        # If we saved a directory, prev_path is the root of the extracted tree.
        state_file = prev_path / "payload"
        if state_file.exists():
            start_index = int(state_file.read_text())

    print(f"Resuming from index {start_index}")
```
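The restore step above assumes the payload always holds a valid integer. The sketch below shows a slightly more defensive variant using a hypothetical read_start_index helper, which treats a missing or unreadable payload as "start from the beginning". Whether restarting is better than failing loudly is a judgment call for your workload.

```python
import pathlib
from typing import Optional


def read_start_index(prev_path: Optional[pathlib.Path], default: int = 0) -> int:
    """Return the saved loop index, or `default` when there is nothing usable.

    Covers three restore outcomes: no checkpoint at all (None), a restored
    tree without a 'payload' file (e.g. a directory checkpoint), and a
    payload whose contents are not an integer.
    """
    if prev_path is None:
        return default
    payload = prev_path / "payload"
    if not payload.is_file():
        return default
    try:
        return int(payload.read_text().strip())
    except ValueError:
        # Unexpected content: start over rather than crash the retry.
        return default
```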
Saving Progress
You can save state as bytes, a single file, or an entire directory. Use save_sync() (or save() for async) to upload the state to the remote destination.
- Bytes: Uploaded directly as a blob.
- File: The file is uploaded directly.
- Directory: The directory is automatically compressed into a .tar.gz archive before upload.
```python
# ... inside the task, after restoring start_index ...
for i in range(start_index, 100):
    # Perform work...

    # Save progress as bytes
    checkpoint.save_sync(str(i + 1).encode())
```
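The following standard-library sketch makes the three formats concrete. It packs each kind of state into one artifact and restores it to the layout this page describes: a file named payload for bytes or a single file, and an extracted tree for a directory. It only imitates that documented behavior. pack_state, unpack_state and the state.tar.gz file name are hypothetical, and this is not flyte-sdk's implementation.

```python
import pathlib
import shutil
import tarfile
from typing import Union

State = Union[bytes, pathlib.Path]


def pack_state(state: State, staging: pathlib.Path) -> pathlib.Path:
    """Turn a checkpoint into a single uploadable artifact inside `staging`."""
    if isinstance(state, bytes):
        artifact = staging / "payload"
        artifact.write_bytes(state)          # bytes: stored as one blob
    elif state.is_file():
        artifact = staging / "payload"
        shutil.copyfile(state, artifact)     # single file: stored as-is
    elif state.is_dir():
        artifact = staging / "state.tar.gz"  # directory: compressed tarball
        with tarfile.open(artifact, "w:gz") as tar:
            for child in state.iterdir():
                # Store paths relative to the directory root
                tar.add(child, arcname=child.name)
    else:
        raise FileNotFoundError(state)
    return artifact


def unpack_state(artifact: pathlib.Path, workspace: pathlib.Path) -> pathlib.Path:
    """Restore an artifact into `workspace` and return the root to read from."""
    if artifact.name.endswith(".tar.gz"):
        with tarfile.open(artifact, "r:gz") as tar:
            # Use the safer "data" extraction filter where this Python provides it
            if hasattr(tarfile, "data_filter"):
                tar.extractall(workspace, filter="data")
            else:
                tar.extractall(workspace)
    else:
        shutil.copyfile(artifact, workspace / "payload")
    return workspace
```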
Working with Directories and Frameworks
When working with machine learning frameworks like PyTorch Lightning, you often save a directory of checkpoints. flyte-sdk provides a flyte.latest_checkpoint utility that finds the most recent file matching a glob pattern in a restored directory tree.
```python
import pathlib

import flyte


@flyte.task(retries=3)
def train_model():
    checkpoint = flyte.ctx().checkpoint
    ckpt_dir = pathlib.Path("my_checkpoints")
    ckpt_dir.mkdir(exist_ok=True)

    # Restore the previous directory, if one exists
    prev_root = checkpoint.load_sync()
    resume_from = None
    if prev_root:
        # Find the newest 'last.ckpt' in the restored tree
        last = flyte.latest_checkpoint(prev_root, glob_pattern="**/last.ckpt")
        if last:
            resume_from = str(last)

    # ... training logic using resume_from, writing checkpoints into ckpt_dir ...

    # Save the entire directory for the next attempt
    checkpoint.save_sync(ckpt_dir)
```
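This page does not spell out how flyte.latest_checkpoint ranks candidates. If you need similar behavior without flyte-sdk, one plausible stand-in is to glob the restored tree and pick the file with the newest modification time. newest_match is a hypothetical helper, and using modification time as the ordering is an assumption.

```python
import pathlib
from typing import Optional


def newest_match(root: pathlib.Path, glob_pattern: str = "**/*.ckpt") -> Optional[pathlib.Path]:
    """Return the most recently modified file under `root` matching `glob_pattern`."""
    candidates = [p for p in root.glob(glob_pattern) if p.is_file()]
    if not candidates:
        return None
    # Assumption: "newest" means the largest modification time.
    return max(candidates, key=lambda p: p.stat().st_mtime)
```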
Complete Example: Resilient Counter
This complete example demonstrates a task that increments a counter and fails intentionally on its first attempt, to show how the retry resumes from the last saved index.
```python
import logging

import flyte

# Configure a task environment; retries are set on the task itself
env = flyte.TaskEnvironment(name="checkpoint-tutorial")


@env.task(retries=3)
def resilient_counter(n_iterations: int) -> int:
    checkpoint = flyte.ctx().checkpoint

    # 1. Restore state
    restored_path = checkpoint.load_sync()
    start = 0
    if restored_path:
        # Single-file/bytes checkpoints are restored to a file named 'payload'
        payload = restored_path / "payload"
        if payload.exists():
            start = int(payload.read_bytes().decode())
            print(f"Restoring from iteration {start}")

    # 2. Process and save
    for index in range(start, n_iterations):
        # Simulate a failure halfway through, on the first attempt only
        if index == n_iterations // 2 and flyte.ctx().attempt_number() == 0:
            raise RuntimeError("Simulated failure on first attempt!")
        print(f"Processing iteration {index}")
        # Save the next starting point
        checkpoint.save_sync(str(index + 1).encode())

    return n_iterations


if __name__ == "__main__":
    # Local execution will simulate the checkpointing behavior
    flyte.init_from_config(log_level=logging.DEBUG)
    result = resilient_counter(n_iterations=10)
    print(f"Final result: {result}")
```
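You can reproduce the control flow of resilient_counter in plain Python to see why it resumes where it left off. In this simulation, which does not use flyte, a dictionary that survives across attempts stands in for remote checkpoint storage. run_with_retries, make_counter and SimulatedFailure are hypothetical names.

```python
from typing import Callable, Dict, List

Task = Callable[[int, Dict[str, bytes]], int]


class SimulatedFailure(RuntimeError):
    pass


def run_with_retries(task: Task, retries: int) -> int:
    """Call task(attempt, store) until it succeeds; retries=3 allows 4 attempts."""
    store: Dict[str, bytes] = {}  # outlives attempts, like remote checkpoint storage
    attempt = 0
    while True:
        try:
            return task(attempt, store)
        except SimulatedFailure:
            if attempt >= retries:
                raise  # out of retries: surface the last failure
            attempt += 1


def make_counter(n_iterations: int, log: List[int]) -> Task:
    def counter(attempt: int, store: Dict[str, bytes]) -> int:
        saved = store.get("payload")
        start = int(saved.decode()) if saved is not None else 0  # restore
        for index in range(start, n_iterations):
            if index == n_iterations // 2 and attempt == 0:
                raise SimulatedFailure("Simulated failure on first attempt!")
            log.append(index)                           # the "work"
            store["payload"] = str(index + 1).encode()  # save next start
        return n_iterations

    return counter
```

Progress is saved after each index. With n_iterations=10, the first attempt therefore persists 5 before failing at index 5, and the retry starts there, so each index runs exactly once. If a failure landed between the work and the save, that index would run again on the retry, so work inside the loop should be safe to repeat.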
Important Considerations
- Temporary Workspace: The Checkpoint object manages its own temporary directory (accessible via checkpoint.path). This directory is cleaned up when the Checkpoint object is destroyed.
- Payload File: When you save bytes or a single file, flyte-sdk restores it to a file named payload inside the directory returned by load_sync().
- Tarball Extraction: If you save a directory, flyte-sdk packs it into a .tar.gz archive. When you call load_sync(), the archive is extracted directly into the workspace, and load_sync() returns the path to that workspace root.
- Async Support: If your task is defined with async def, use the awaitable await checkpoint.load() and await checkpoint.save() methods instead of their sync counterparts.
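In async tasks the resume pattern stays the same, and only the calls change to await checkpoint.load() and await checkpoint.save(). The sketch below shows that shape. It uses a hypothetical AsyncDirCheckpoint whose "remote" storage is a local directory and which handles only bytes. This class is not part of flyte-sdk.

```python
import asyncio
import pathlib
import shutil
import tempfile
from typing import Optional


class AsyncDirCheckpoint:
    """Hypothetical async stand-in; 'remote' storage is a local directory."""

    def __init__(self, remote: pathlib.Path) -> None:
        self._remote = remote
        self._workspace = pathlib.Path(tempfile.mkdtemp())

    async def load(self) -> Optional[pathlib.Path]:
        src = self._remote / "payload"
        if not src.exists():
            return None
        # Run the blocking copy in a worker thread so the event loop stays free
        await asyncio.to_thread(shutil.copyfile, src, self._workspace / "payload")
        return self._workspace

    async def save(self, state: bytes) -> None:
        await asyncio.to_thread((self._remote / "payload").write_bytes, state)


async def async_counter(checkpoint: AsyncDirCheckpoint, n: int) -> int:
    """Resume-and-save loop; returns the index it resumed from (handy in tests)."""
    prev = await checkpoint.load()  # instead of checkpoint.load_sync()
    start = int((prev / "payload").read_text()) if prev is not None else 0
    for i in range(start, n):
        await checkpoint.save(str(i + 1).encode())  # instead of save_sync()
    return start
```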