NCCL and Failure Detection
Distributed training in PyTorch often suffers from "silent hangs" where a single worker failure (like a CUDA Out-of-Memory error) causes all other workers to block indefinitely on a collective operation (e.g., all_reduce). By default, these hangs can last up to 30 minutes. flyte-sdk provides the Elastic configuration class to implement a multi-layered failure detection and recovery system that significantly reduces this downtime.
The Elastic Configuration Interface
The Elastic class in flyteplugins.pytorch.task is the primary interface for tuning how flyte-sdk manages PyTorch worker lifecycles. It allows you to configure node counts, process counts, and critical NCCL (NVIDIA Collective Communications Library) parameters.
from flyteplugins.pytorch.task import Elastic

elastic_config = Elastic(
    nnodes=1,
    nproc_per_node=2,
    max_restarts=3,
    nccl_collective_timeout_sec=60,
    nccl_heartbeat_timeout_sec=60,
    nccl_async_error_handling=True,
)
Two-Phase Failure Detection
flyte-sdk implements a two-phase detection mechanism to convert stuck network operations into hard process failures that the elastic agent can recover from.
Phase 1: Collective Timeout
When a worker desyncs (e.g., it skips a collective call because it crashed or caught an exception), the surviving workers will wait at the collective operation. The nccl_collective_timeout_sec parameter controls this wait time.
In flyteplugins/pytorch/task.py, flyte-sdk propagates this value via the FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC environment variable. Because PyTorch binds its default timeout constant at import time, changing the environment variable alone has no effect once the modules are loaded. flyte-sdk's launcher_entrypoint therefore patches both torch.distributed.constants and torch.distributed.distributed_c10d before the user's task code executes:
# From flyteplugins/pytorch/task.py: launcher_entrypoint
# Imports this excerpt relies on
import os
from datetime import timedelta

nccl_timeout = os.environ.get("FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC")
if nccl_timeout is not None:
    import torch.distributed.constants
    import torch.distributed.distributed_c10d

    # Patch both modules: each holds its own binding of the default timeout.
    td = timedelta(seconds=int(nccl_timeout))
    torch.distributed.constants.default_pg_nccl_timeout = td
    torch.distributed.distributed_c10d.default_pg_nccl_timeout = td
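The need to patch two modules follows from how Python binds names at import time. The sketch below reproduces the effect with two stand-in modules built from the standard library, so it runs without PyTorch. The module names and the 30-minute starting value are placeholders chosen for illustration, not values read from PyTorch.

```python
import types
from datetime import timedelta

# Hypothetical stand-ins for torch.distributed.constants and
# torch.distributed.distributed_c10d; no PyTorch install is needed.
constants = types.ModuleType("fake_constants")
constants.default_pg_nccl_timeout = timedelta(minutes=30)  # placeholder default

c10d = types.ModuleType("fake_distributed_c10d")
# This mimics `from constants import default_pg_nccl_timeout` at import time:
# the importing module receives its own reference to the current object.
c10d.default_pg_nccl_timeout = constants.default_pg_nccl_timeout

new_timeout = timedelta(seconds=60)

# Patching only the defining module leaves the importer's binding untouched.
constants.default_pg_nccl_timeout = new_timeout
stale_value = c10d.default_pg_nccl_timeout

# Patching the importing module as well makes both agree.
c10d.default_pg_nccl_timeout = new_timeout

print("after patching one module:", stale_value)
print("after patching both:", c10d.default_pg_nccl_timeout)
```

Rebinding a module attribute never updates names that other modules copied earlier, which is why the launcher assigns the new timedelta in both places.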
Phase 2: Heartbeat Timeout
If a collective operation times out, the NCCL watchdog aborts the communicator. However, the worker process might still remain alive but stuck. The nccl_heartbeat_timeout_sec (mapped to TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) defines how long the NCCL monitoring thread waits after a stall before sending a SIGABRT to kill the worker process.
flyte-sdk defaults this to 300 seconds (5 minutes), which is much more aggressive than the PyTorch default of 1800 seconds (30 minutes). The heartbeat timeout only takes effect when nccl_enable_monitoring=True (which sets TORCH_NCCL_ENABLE_MONITORING=1).
Accelerating Detection with Async Error Handling
Setting nccl_async_error_handling=True enables TORCH_NCCL_ASYNC_ERROR_HANDLING=1. This causes NCCL to abort stuck collectives asynchronously rather than blocking. When a failure is detected, the worker process crash-exits immediately. The elastic agent then detects this exit within the monitor_interval (default 3 seconds), bypassing the need to wait for the full heartbeat timeout.
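The parameter-to-environment-variable mapping described in the sections above can be summarized in a small sketch. NcclSettings and nccl_env are hypothetical names introduced here for illustration; they are not part of flyteplugins, and the real Elastic class may build its environment differently. Only the 300-second heartbeat default comes from the text; the other defaults are assumptions.

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class NcclSettings:
    """Hypothetical stand-in for the NCCL-related fields of Elastic."""
    collective_timeout_sec: Optional[int] = None  # assumed default: unset
    heartbeat_timeout_sec: Optional[int] = 300    # default stated in the text
    enable_monitoring: bool = True                # assumed default
    async_error_handling: bool = False            # assumed default


def nccl_env(settings: NcclSettings) -> Dict[str, str]:
    """Translate settings into the environment variables named in the text."""
    env: Dict[str, str] = {}
    if settings.collective_timeout_sec is not None:
        env["FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC"] = str(settings.collective_timeout_sec)
    if settings.heartbeat_timeout_sec is not None:
        env["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = str(settings.heartbeat_timeout_sec)
    if settings.enable_monitoring:
        env["TORCH_NCCL_ENABLE_MONITORING"] = "1"
    if settings.async_error_handling:
        env["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
    return env


if __name__ == "__main__":
    aggressive = NcclSettings(
        collective_timeout_sec=60,
        heartbeat_timeout_sec=60,
        async_error_handling=True,
    )
    for key, value in nccl_env(aggressive).items():
        print(f"{key}={value}")
```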
The Zombie Watchdog
A critical challenge in distributed PyTorch is a known deadlock in the elastic agent: when all workers die from SIGABRT simultaneously (common during NCCL timeouts), the agent can deadlock while trying to acquire a shared semaphore that the dead workers will never release.
To protect against this, flyteplugins.pytorch.task.TorchFunctionTask launches a _start_zombie_watchdog thread. This watchdog periodically inspects /proc to count zombie child processes. If it detects that all worker processes have become zombies, it concludes the agent is deadlocked and force-exits the entire task using os._exit(1).
# Logic from _start_zombie_watchdog in flyteplugins/pytorch/task.py
if len(zombie_pids) >= nproc:
    logger.error("Zombie watchdog: %d worker processes are zombies... Force-exiting.", len(zombie_pids))
    os._exit(1)
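The excerpt above does not show how zombie_pids is collected. One plausible approach on Linux, sketched here under that assumption, is to read each /proc/<pid>/stat file and keep the children whose state field is Z. The helper names below are hypothetical and are not taken from flyteplugins.

```python
import os
from typing import Iterable, List, Tuple


def parse_stat(stat_text: str) -> Tuple[int, str, int]:
    """Return (pid, state, ppid) from the contents of /proc/<pid>/stat."""
    # The command name sits in parentheses and may itself contain spaces or
    # parentheses, so split around the LAST closing parenthesis.
    head, _, tail = stat_text.rpartition(")")
    pid = int(head.split("(", 1)[0])
    fields = tail.split()
    return pid, fields[0], int(fields[1])


def zombie_children(stat_texts: Iterable[str], parent_pid: int) -> List[int]:
    """PIDs whose parent is parent_pid and whose state is 'Z' (zombie)."""
    zombies = []
    for text in stat_texts:
        pid, state, ppid = parse_stat(text)
        if ppid == parent_pid and state == "Z":
            zombies.append(pid)
    return zombies


def read_all_stats(proc_root: str = "/proc") -> List[str]:
    """Read every readable <proc_root>/<pid>/stat file (Linux only)."""
    texts = []
    for entry in os.listdir(proc_root):
        if not entry.isdigit():
            continue
        try:
            with open(os.path.join(proc_root, entry, "stat")) as f:
                texts.append(f.read())
        except OSError:
            continue  # the process exited between listdir and open
    return texts


def all_workers_are_zombies(stat_texts: Iterable[str], parent_pid: int, nproc: int) -> bool:
    """Mirror the `len(zombie_pids) >= nproc` check from the excerpt."""
    return len(zombie_children(stat_texts, parent_pid)) >= nproc
```

A watchdog thread could call all_workers_are_zombies(read_all_stats(), os.getpid(), nproc) on a timer and force-exit when it returns True. Keeping the parsing separate from the /proc read makes the logic testable with plain strings.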
Example: Aggressive Failure Recovery
For jobs where rapid recovery matters more than riding out transient network issues, you can use aggressive settings. With a 60-second collective timeout followed by a 60-second heartbeat timeout, a CUDA OOM or rank desync results in a restart or failure within approximately 2 minutes, rather than 30+ minutes.
from flyteplugins.pytorch.task import Elastic
import flyte

# Example of aggressive settings to minimize hang time
torch_env = flyte.TaskEnvironment(
    name="fast_failure_env",
    plugin_config=Elastic(
        nproc_per_node=2,
        nnodes=1,
        max_restarts=0,  # Fail immediately on deterministic errors like OOM
        nccl_heartbeat_timeout_sec=60,
        nccl_async_error_handling=True,
        nccl_collective_timeout_sec=60,
    ),
)
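As a rough sanity check on the detection time for this configuration, the two paths described earlier can be added up. This is a simplified model based only on the timeouts named in this section. It ignores scheduling and teardown overhead, and detection_window_sec is a hypothetical helper, not a flyte-sdk function.

```python
from typing import Tuple


def detection_window_sec(
    collective_timeout_sec: int,
    heartbeat_timeout_sec: int,
    monitor_interval_sec: int = 3,
) -> Tuple[int, int]:
    """Return (fast_path, slow_path) detection estimates in seconds."""
    # Fast path: async error handling crashes the worker once the collective
    # times out, and the elastic agent notices within one monitor interval.
    fast_path = collective_timeout_sec + monitor_interval_sec
    # Slow path: the worker stays stuck after the collective timeout and is
    # only killed once the heartbeat timeout has also expired.
    slow_path = collective_timeout_sec + heartbeat_timeout_sec
    return fast_path, slow_path


if __name__ == "__main__":
    fast, slow = detection_window_sec(60, 60)
    print(f"fast path: ~{fast}s, slow path: ~{slow}s")
```

With both timeouts at 60 seconds the slow path comes to 120 seconds, which matches the roughly two-minute figure, while the fast path is closer to one minute.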
When a failure occurs, flyte-sdk provides specific hints in the error logs based on the exit signal:
- SIGABRT (-6): Usually indicates an NCCL collective timeout or rank desync (often caused by CUDA OOM).
- SIGKILL (-9): Usually indicates the process was killed by the system OOM killer for exceeding resource limits.
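A negative exit code from a worker conventionally means the process was terminated by the signal with that number. The sketch below shows one way such hints could be derived from an exit code. The hint strings paraphrase the list above rather than quoting flyte-sdk's actual log output, and failure_hint is a hypothetical helper.

```python
import signal

# Hint text paraphrased from the list above; not flyte-sdk's real messages.
HINTS = {
    signal.SIGABRT: "NCCL collective timeout or rank desync (often caused by CUDA OOM)",
    signal.SIGKILL: "killed by the system OOM killer for exceeding resource limits",
}


def failure_hint(exit_code: int) -> str:
    """Map a worker exit code to a human-readable hint."""
    if exit_code >= 0:
        return f"exited with code {exit_code}; no signal-specific hint"
    try:
        sig = signal.Signals(-exit_code)
    except ValueError:
        return f"terminated by unknown signal {-exit_code}"
    hint = HINTS.get(sig, "no specific hint for this signal")
    return f"{sig.name} ({exit_code}): {hint}"


if __name__ == "__main__":
    for code in (-6, -9, 1):
        print(failure_hint(code))
```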