Checkpointing & State Recovery

Checkpointing allows a Flyte task to save its progress and resume from that state if the task is interrupted or retried. This is particularly useful for long-running training jobs or iterative data processing where you don't want to restart from scratch after a transient failure.

In flyte-sdk, the Checkpoint helper manages a local temporary workspace, handles the download and extraction of previous state, and automates the upload of new state to remote storage.

Accessing the Checkpoint Helper

You access the checkpoint helper through the task context using flyte.ctx().checkpoint. This object is an instance of flyte.Checkpoint and is automatically configured with the remote paths provided by the Flyte platform.

import flyte

@flyte.task(retries=3)
def my_resilient_task():
    checkpoint = flyte.ctx().checkpoint
    # ...

Restoring Previous State

At the beginning of your task, you should attempt to load any state saved by a previous attempt. The load_sync() method (or load() for async tasks) returns a pathlib.Path to the restored data if it exists, or None if this is the first attempt or no checkpoint was found.

import flyte
import pathlib

@flyte.task(retries=3)
def process_data():
    checkpoint = flyte.ctx().checkpoint

    # load_sync() returns a Path to the restored state, or None
    prev_path: pathlib.Path | None = checkpoint.load_sync()

    start_index = 0
    if prev_path is not None:
        # If we saved bytes or a single file, it appears at prev_path / "payload"
        # If we saved a directory, prev_path is the root of the extracted tree
        state_file = prev_path / "payload"
        if state_file.exists():
            start_index = int(state_file.read_text())
            print(f"Resuming from index {start_index}")
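The resume logic above can be pulled into a small pure function so it can be tested without a Flyte cluster. This is a minimal sketch that assumes only the payload-file layout described on this page; read_start_index is a hypothetical helper name, not part of flyte-sdk.

```python
from __future__ import annotations

import pathlib


def read_start_index(prev_path: pathlib.Path | None) -> int:
    """Return the index to resume from, or 0 when there is no usable state.

    Hypothetical helper: assumes bytes/single-file checkpoints are restored
    to a file named "payload" inside the directory returned by load_sync().
    """
    if prev_path is None:
        return 0  # first attempt: nothing was restored
    state_file = prev_path / "payload"
    if not state_file.exists():
        return 0  # a tree was restored, but it holds no payload file
    text = state_file.read_text().strip()
    return int(text) if text else 0
```

Inside the task, start_index = read_start_index(checkpoint.load_sync()) would replace the inline checks.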

Saving Progress

You can save state as bytes, a single file, or an entire directory. Use save_sync() (or save() for async) to upload the state to the remote destination.

  • Bytes: Uploaded directly as a blob.
  • File: The file is uploaded directly.
  • Directory: The directory is automatically compressed into a .tar.gz archive before upload.

# ... inside the task, after restoring start_index ...
for i in range(start_index, 100):
    # Perform work for item i ...

    # Save progress as bytes: the next index to process
    checkpoint.save_sync(str(i + 1).encode())
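To make the three save cases concrete, here is a minimal local sketch of how state might be staged before upload. It is not flyte-sdk's implementation: pack_state and the staging names payload and checkpoint.tar.gz are assumptions, chosen to match the restore layout this page describes.

```python
from __future__ import annotations

import pathlib
import shutil
import tarfile


def pack_state(state: bytes | pathlib.Path, staging: pathlib.Path) -> pathlib.Path:
    """Stage checkpoint state and return the path that would be uploaded.

    Hypothetical sketch of the three save cases; the staging names "payload"
    and "checkpoint.tar.gz" are assumptions, not flyte-sdk internals.
    """
    staging.mkdir(parents=True, exist_ok=True)
    if isinstance(state, bytes):
        target = staging / "payload"
        target.write_bytes(state)  # raw bytes become a single blob
        return target
    if state.is_file():
        target = staging / "payload"
        shutil.copyfile(state, target)  # a single file is uploaded as-is
        return target
    if state.is_dir():
        target = staging / "checkpoint.tar.gz"
        with tarfile.open(target, "w:gz") as tar:
            # arcname="." stores members relative to the directory root, so the
            # archive can later be extracted straight into a workspace directory
            tar.add(state, arcname=".")
        return target
    raise FileNotFoundError(f"No checkpoint state at {state}")
```

Keep the staging directory outside the directory being saved; otherwise the archive would try to include itself.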

Working with Directories and Frameworks

When working with machine learning frameworks like PyTorch Lightning, you often save a directory of checkpoints. flyte-sdk provides a latest_checkpoint utility (flyte.latest_checkpoint) to find the most recent file matching a glob pattern in a restored directory tree.

import flyte
import pathlib

@flyte.task(retries=3)
def train_model():
    checkpoint = flyte.ctx().checkpoint
    ckpt_dir = pathlib.Path("my_checkpoints")
    ckpt_dir.mkdir(exist_ok=True)

    # Restore the previous directory, if one exists
    prev_root = checkpoint.load_sync()
    resume_from = None
    if prev_root is not None:
        # Find the newest 'last.ckpt' in the restored tree
        last = flyte.latest_checkpoint(prev_root, glob_pattern="**/last.ckpt")
        if last:
            resume_from = str(last)

    # ... training logic using resume_from, writing checkpoints into ckpt_dir ...

    # Save the entire directory for the next attempt
    checkpoint.save_sync(ckpt_dir)
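A latest-checkpoint lookup can be approximated with pathlib globbing. The sketch below is a hypothetical stand-in (find_latest is not a flyte-sdk function) and assumes that "most recent" means the newest modification time; flyte.latest_checkpoint may rank candidates differently.

```python
from __future__ import annotations

import pathlib


def find_latest(root: pathlib.Path, glob_pattern: str = "**/*") -> pathlib.Path | None:
    """Return the most recently modified file under root matching glob_pattern.

    Hypothetical stand-in for a latest-checkpoint lookup; "most recent" is
    assumed here to mean the newest modification time.
    """
    candidates = [p for p in root.glob(glob_pattern) if p.is_file()]
    if not candidates:
        return None  # nothing to resume from
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

In the training example above, find_latest(prev_root, "**/last.ckpt") would play the role of the latest_checkpoint call.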

Complete Example: Resilient Counter

This complete example defines a task that iterates over a range, saves the next index after each iteration, and fails intentionally on its first attempt to show how the retry resumes from the last saved index.

import flyte
import logging

# Configure a task environment; retries are set on the task decorator below
env = flyte.TaskEnvironment(name="checkpoint-tutorial")

@env.task(retries=3)
def resilient_counter(n_iterations: int) -> int:
    checkpoint = flyte.ctx().checkpoint

    # 1. Restore state
    restored_path = checkpoint.load_sync()
    start = 0
    if restored_path is not None:
        # Single-file/bytes checkpoints are restored to a file named 'payload'
        payload = restored_path / "payload"
        if payload.exists():
            start = int(payload.read_bytes().decode())
            print(f"Restoring from iteration {start}")

    # 2. Process and save
    for index in range(start, n_iterations):
        # Simulate a failure halfway through, on the first attempt only
        if index == n_iterations // 2 and flyte.ctx().attempt_number() == 0:
            raise RuntimeError("Simulated failure on first attempt!")

        print(f"Processing iteration {index}")

        # Save the next starting point
        checkpoint.save_sync(str(index + 1).encode())

    return n_iterations

if __name__ == "__main__":
    # Local execution will simulate the checkpointing behavior
    flyte.init_from_config(log_level=logging.DEBUG)
    result = resilient_counter(n_iterations=10)
    print(f"Final result: {result}")
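To see the resume behavior without a Flyte backend, the sketch below replays the counter example with a hypothetical in-memory stand-in for the checkpoint helper. FakeCheckpoint, counter_body, and run_with_retries are illustrative names, not flyte-sdk APIs, and the retry loop only approximates what the platform does between attempts.

```python
from __future__ import annotations

import pathlib
import tempfile


class FakeCheckpoint:
    """Hypothetical in-memory stand-in for flyte.ctx().checkpoint.

    Mimics only what this page describes: saved bytes come back as a file
    named "payload" inside the directory returned by load_sync().
    """

    def __init__(self) -> None:
        self._blob: bytes | None = None  # plays the role of remote storage
        self._workspace = tempfile.TemporaryDirectory()  # local restore area

    def save_sync(self, data: bytes) -> None:
        self._blob = data

    def load_sync(self) -> pathlib.Path | None:
        if self._blob is None:
            return None  # nothing saved yet: behaves like a first attempt
        root = pathlib.Path(self._workspace.name)
        (root / "payload").write_bytes(self._blob)
        return root


def counter_body(
    checkpoint: FakeCheckpoint, attempt: int, n_iterations: int, processed: list[int]
) -> int:
    """Same logic as resilient_counter, with the attempt number passed in."""
    restored = checkpoint.load_sync()
    start = 0
    if restored is not None and (restored / "payload").exists():
        start = int((restored / "payload").read_bytes().decode())
    for index in range(start, n_iterations):
        if index == n_iterations // 2 and attempt == 0:
            raise RuntimeError("Simulated failure on first attempt!")
        processed.append(index)  # record each unit of work actually performed
        checkpoint.save_sync(str(index + 1).encode())
    return n_iterations


def run_with_retries(n_iterations: int, retries: int = 3) -> tuple[int, list[int]]:
    """Retry counter_body up to retries + 1 times (an assumption of this sketch)."""
    checkpoint = FakeCheckpoint()  # shared across attempts, like remote storage
    processed: list[int] = []
    for attempt in range(retries + 1):
        try:
            return counter_body(checkpoint, attempt, n_iterations, processed), processed
        except RuntimeError:
            continue  # the next attempt resumes from the last saved index
    raise RuntimeError("All attempts failed")
```

With n_iterations=10, the first attempt processes indices 0 through 4 and fails at index 5; the second attempt restores "5" from the payload and finishes indices 5 through 9, so no index is processed twice.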

Important Considerations

  1. Temporary Workspace: The Checkpoint object manages its own temporary directory (accessible via checkpoint.path). This directory is cleaned up when the Checkpoint object is destroyed.
  2. Payload File: When you save bytes or a single file, flyte-sdk restores it to a file named payload inside the directory returned by load_sync().
  3. Tarball Extraction: If you save a directory, flyte-sdk packs it into a .tar.gz archive. On load_sync(), the archive is extracted directly into the workspace, and load_sync() returns the path to that workspace root.
  4. Async Support: If your task is defined as async def, use the awaitable await checkpoint.load() and await checkpoint.save() methods instead of their sync counterparts.
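The payload and tarball rules above can be sketched as a small restore function. This is an illustration under stated assumptions, not flyte-sdk's code: restore_state is a hypothetical name, and it assumes directory checkpoints arrive as a file whose name ends in .tar.gz while anything else is a bytes or single-file checkpoint.

```python
import pathlib
import shutil
import tarfile


def restore_state(downloaded: pathlib.Path, workspace: pathlib.Path) -> pathlib.Path:
    """Lay out downloaded checkpoint state in workspace and return workspace.

    Hypothetical sketch: a ".tar.gz" name marks a directory checkpoint;
    anything else is treated as a bytes or single-file checkpoint.
    """
    workspace.mkdir(parents=True, exist_ok=True)
    if downloaded.name.endswith(".tar.gz"):
        with tarfile.open(downloaded, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # Newer Python versions can reject unsafe members such as
                # absolute paths or "../" escapes during extraction
                tar.extractall(workspace, filter="data")
            else:
                tar.extractall(workspace)
    else:
        # Bytes or a single file: expose it under the fixed name "payload"
        shutil.copyfile(downloaded, workspace / "payload")
    return workspace
```

Returning the workspace root in both cases mirrors the single return type of load_sync(): callers look for payload inside it, or walk the extracted tree.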