Elastic PyTorch Training

Distributed training in flyte-sdk is powered by the PyTorch plugin, which integrates with torch.distributed.run to manage multi-node execution. By using the Elastic configuration class, you can define cluster topology, enable elastic scaling, and tune failure recovery parameters to prevent common distributed training issues like hangs during CUDA Out-of-Memory (OOM) events.

Configuring the Distributed Cluster

To run a distributed PyTorch task, you define a TaskEnvironment with an Elastic plugin configuration. This configuration tells flyte-sdk how many nodes to provision and how many processes to launch on each node.

from flyteplugins.pytorch.task import Elastic
import flyte
import torch

# Define the environment for distributed training
torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi"), gpu="T4:1"),
    plugin_config=Elastic(
        nnodes=2,  # Total number of nodes
        nproc_per_node=1,  # Processes (GPUs) per node
    ),
)

@torch_env.task
def train_model(epochs: int):
    # flyte-sdk handles the setup; you just initialize the process group
    torch.distributed.init_process_group("gloo")
    # ... training logic ...

When you call this task, flyte-sdk uses TorchFunctionTask to wrap your function. Internally, it invokes torch.distributed.launcher.api.elastic_launch to manage the worker processes across the allocated nodes.
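Inside each worker, the elastic launcher exposes the process's position in the job through the standard torchrun environment variables RANK, LOCAL_RANK, and WORLD_SIZE, which init_process_group reads by default. The following is a minimal sketch of inspecting them from task code. The WorkerInfo class and read_worker_info helper are hypothetical names used for illustration and are not part of flyte-sdk.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class WorkerInfo:
    """Position of one worker process within the distributed job."""

    rank: int        # global rank across all nodes
    local_rank: int  # rank within the current node
    world_size: int  # total number of worker processes

    @property
    def is_main(self) -> bool:
        # Rank 0 is conventionally used for logging and checkpointing.
        return self.rank == 0


def read_worker_info(env: Optional[Mapping[str, str]] = None) -> WorkerInfo:
    """Read the torchrun-style variables, falling back to single-process values."""
    env = os.environ if env is None else env
    return WorkerInfo(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )
```

With nnodes=2 and nproc_per_node=1 as configured above, each worker would be expected to see a WORLD_SIZE of 2 and a LOCAL_RANK of 0.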

Elastic Scaling

The nnodes parameter in Elastic supports range strings to enable elastic training. If you provide a range like "2:4", the job can start with as few as 2 nodes and scale up to 4 if resources are available.

plugin_config=Elastic(
    nnodes="2:4",  # Minimum 2 nodes, maximum 4 nodes
    nproc_per_node=1,
    rdzv_backend="c10d",
)

The TorchFunctionTask.__post_init__ method parses this string into min_nodes and max_nodes, which are then passed to the underlying elastic launcher.
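The parsing step can be sketched as follows. This is an illustrative stand-in, not the code in TorchFunctionTask.__post_init__: the parse_nnodes name and its validation rules are assumptions made for the example.

```python
from typing import Tuple, Union


def parse_nnodes(nnodes: Union[int, str]) -> Tuple[int, int]:
    """Return (min_nodes, max_nodes) for an int or a "MIN:MAX" string."""
    if isinstance(nnodes, int):
        # A plain integer means a fixed-size job.
        min_nodes = max_nodes = nnodes
    else:
        parts = str(nnodes).split(":")
        if len(parts) == 1:
            min_nodes = max_nodes = int(parts[0])
        elif len(parts) == 2:
            min_nodes, max_nodes = int(parts[0]), int(parts[1])
        else:
            raise ValueError(f"nnodes must be N or MIN:MAX, got {nnodes!r}")
    # Assumed sanity checks: at least one node, and a non-empty range.
    if min_nodes < 1 or max_nodes < min_nodes:
        raise ValueError(f"invalid node range {min_nodes}:{max_nodes}")
    return min_nodes, max_nodes
```

Under these assumptions, parse_nnodes("2:4") yields (2, 4), while parse_nnodes(2) yields (2, 2), which corresponds to a non-elastic job.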

Failure Handling and NCCL Tuning

Distributed training is sensitive to worker failures. If one worker crashes (e.g., due to a CUDA OOM), other workers may hang indefinitely while waiting for a collective operation (like all_reduce) that will never complete.

flyte-sdk provides several knobs in the Elastic class to detect these failures quickly and restart the worker group.

Preventing Hangs with NCCL Timeouts

By default, PyTorch's NCCL collective timeout is 10 minutes, and the heartbeat timeout is 30 minutes. In a cloud environment, this often leads to long periods of idle resource consumption when a job is actually stuck. You can configure more aggressive timeouts:

plugin_config=Elastic(
    nnodes=2,
    nproc_per_node=1,
    max_restarts=3,
    nccl_collective_timeout_sec=60,  # Fail if a collective takes > 60s
    nccl_heartbeat_timeout_sec=60,  # Kill worker if heartbeat stalls > 60s
    nccl_async_error_handling=True,  # Abort stuck collectives asynchronously
)

Internally, TorchFunctionTask.pre propagates these settings via environment variables like TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC and FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC. The launcher_entrypoint in plugins/pytorch/src/flyteplugins/pytorch/task.py then patches torch.distributed.constants.default_pg_nccl_timeout before your code runs, ensuring that init_process_group() respects your configured limits.
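The flow described above can be sketched in two halves: one side writes the settings into the environment, and the worker side reads them back and overrides the default timeout. The environment variable names come from the text; the function names are hypothetical, and the constants argument is a stand-in for torch.distributed.constants so that the sketch runs without PyTorch installed. It is not the plugin's actual implementation.

```python
from datetime import timedelta
from typing import Dict, Mapping, Optional


def nccl_env(
    collective_timeout_sec: Optional[int],
    heartbeat_timeout_sec: Optional[int],
) -> Dict[str, str]:
    """Build the environment variables that carry the timeouts to the workers."""
    env: Dict[str, str] = {}
    if heartbeat_timeout_sec is not None:
        env["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = str(heartbeat_timeout_sec)
    if collective_timeout_sec is not None:
        env["FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC"] = str(collective_timeout_sec)
    return env


def collective_timeout_from_env(env: Mapping[str, str]) -> Optional[timedelta]:
    """Read the collective timeout back on the worker side, if it was set."""
    raw = env.get("FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC")
    return timedelta(seconds=int(raw)) if raw else None


def patch_default_timeout(constants, env: Mapping[str, str]) -> None:
    """Override the default NCCL timeout on a constants-like object.

    `constants` is assumed to expose a `default_pg_nccl_timeout` attribute,
    as torch.distributed.constants does.
    """
    timeout = collective_timeout_from_env(env)
    if timeout is not None:
        constants.default_pg_nccl_timeout = timeout
```

If no collective timeout is configured, the sketch leaves the existing default untouched, so an unset option falls back to PyTorch's own value.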

The Zombie Watchdog

A known issue in PyTorch's elastic agent occurs when all workers die simultaneously (e.g., from a SIGABRT triggered by an NCCL timeout). The agent can deadlock while trying to acquire a shared semaphore that the dead workers never released.

To protect against this, TorchFunctionTask starts a "zombie watchdog" thread via _start_zombie_watchdog. This thread monitors the /proc filesystem:

  • It counts the number of child processes in a "zombie" state.
  • If the number of zombies matches nproc_per_node, it assumes the agent is deadlocked.
  • It then calls os._exit() to force-terminate the process, allowing the Flyte platform to detect the failure and potentially retry the task.
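The three steps above can be sketched as follows. This is a Linux-only illustration that assumes the standard /proc/<pid>/stat layout; the function names, the 5-second polling interval, and the exit code are hypothetical choices for the example and are not taken from _start_zombie_watchdog.

```python
import os
import threading
import time
from pathlib import Path
from typing import Tuple


def parse_stat(stat_text: str) -> Tuple[str, int]:
    """Return (state, ppid) from the contents of /proc/<pid>/stat."""
    # The command name is wrapped in parentheses and may itself contain
    # spaces or ")", so split on the last ")" before reading the fields.
    after_comm = stat_text.rsplit(")", 1)[1].split()
    return after_comm[0], int(after_comm[1])


def count_zombie_children(parent_pid: int, proc_root: str = "/proc") -> int:
    """Count direct children of parent_pid that are in the zombie ("Z") state."""
    zombies = 0
    for entry in Path(proc_root).iterdir():
        if not entry.name.isdigit():
            continue  # skip non-process entries such as /proc/self
        try:
            state, ppid = parse_stat((entry / "stat").read_text())
        except (OSError, IndexError, ValueError):
            continue  # process exited mid-scan, or the file was unreadable
        if ppid == parent_pid and state == "Z":
            zombies += 1
    return zombies


def start_zombie_watchdog(
    nproc_per_node: int, interval_sec: float = 5.0
) -> threading.Thread:
    """Start a daemon thread that force-exits when every worker is a zombie."""

    def _watch() -> None:
        while True:
            time.sleep(interval_sec)
            if count_zombie_children(os.getpid()) == nproc_per_node:
                # os._exit skips cleanup handlers, so it works even if the
                # main thread is blocked on a semaphore.
                os._exit(1)

    thread = threading.Thread(target=_watch, name="zombie-watchdog", daemon=True)
    thread.start()
    return thread
```

The sketch uses os._exit rather than sys.exit because sys.exit only raises an exception in the calling thread, which would not terminate a process whose main thread is deadlocked.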

Environment and Performance Defaults

flyte-sdk applies several default environment settings to ensure predictable behavior in distributed environments:

  1. PYTHONUNBUFFERED=1: Set automatically in TorchFunctionTask.pre to ensure logs from worker processes are visible immediately, even if a crash occurs before a buffer flush.
  2. OMP_NUM_THREADS=1: If nproc_per_node > 1, flyte-sdk sets this to 1, matching torchrun's default. This prevents CPU oversubscription, where multiple worker processes on the same node compete for the same CPU cores.
  3. max_restarts: Defaults to 3. If a failure is deterministic (like a model being too large for memory), you should set this to 0 to fail immediately rather than wasting time on restarts.

# Aggressive failure detection for debugging OOM hangs
plugin_config=Elastic(
    nproc_per_node=2,
    nnodes=1,
    max_restarts=0,
    nccl_async_error_handling=True,
    nccl_collective_timeout_sec=60,
)
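The two environment defaults from the list above can be sketched as a small helper. The worker_env_defaults name is hypothetical, and the choice to leave a user-provided OMP_NUM_THREADS untouched is an assumption modeled on torchrun's behavior rather than something confirmed for TorchFunctionTask.pre.

```python
from typing import Dict, Mapping


def worker_env_defaults(nproc_per_node: int, env: Mapping[str, str]) -> Dict[str, str]:
    """Return the environment variables to add before launching workers."""
    # Unbuffered output keeps worker logs visible even if a crash
    # happens before a buffer flush.
    defaults = {"PYTHONUNBUFFERED": "1"}
    # Assumption: only set OMP_NUM_THREADS when several workers share a
    # node and the user has not already chosen a value.
    if nproc_per_node > 1 and "OMP_NUM_THREADS" not in env:
        defaults["OMP_NUM_THREADS"] = "1"
    return defaults
```

For CPU-heavy data loading or preprocessing, a single OpenMP thread per worker may be too conservative, in which case setting OMP_NUM_THREADS explicitly to roughly the node's core count divided by nproc_per_node is a common starting point.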

Warning: Do not catch torch.cuda.OutOfMemoryError and skip batches in distributed training. This causes the rank that skipped the batch to desync from others during the next collective operation, leading to a hang. Instead, let the process fail and rely on the Elastic configuration to restart the group.