
NCCL and Failure Detection

Distributed training in PyTorch often suffers from "silent hangs" where a single worker failure (like a CUDA Out-of-Memory error) causes all other workers to block indefinitely on a collective operation (e.g., all_reduce). By default, these hangs can last up to 30 minutes. flyte-sdk provides the Elastic configuration class to implement a multi-layered failure detection and recovery system that significantly reduces this downtime.

The Elastic Configuration Interface

The Elastic class in flyteplugins.pytorch.task is the primary interface for tuning how flyte-sdk manages PyTorch worker lifecycles. It allows you to configure node counts, process counts, and critical NCCL (NVIDIA Collective Communications Library) parameters.

from flyteplugins.pytorch.task import Elastic

elastic_config = Elastic(
    nnodes=1,
    nproc_per_node=2,
    max_restarts=3,
    nccl_collective_timeout_sec=60,
    nccl_heartbeat_timeout_sec=60,
    nccl_async_error_handling=True,
)

Two-Phase Failure Detection

flyte-sdk implements a two-phase detection mechanism to convert stuck network operations into hard process failures that the elastic agent can recover from.

Phase 1: Collective Timeout

When a worker desyncs (e.g., it skips a collective call because it crashed or caught an exception), the surviving workers will wait at the collective operation. The nccl_collective_timeout_sec parameter controls this wait time.

In flyteplugins/pytorch/task.py, flyte-sdk propagates this value to the worker processes via the FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC environment variable. Because PyTorch binds its default timeout constant at import time, flyte-sdk's launcher_entrypoint patches both torch.distributed.constants and torch.distributed.distributed_c10d before the user's task code executes:

# From flyteplugins/pytorch/task.py: launcher_entrypoint
# (os and timedelta imports shown here so the excerpt stands alone)
import os
from datetime import timedelta

nccl_timeout = os.environ.get("FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC")
if nccl_timeout is not None:
    import torch.distributed.constants
    import torch.distributed.distributed_c10d

    # Patch both modules: distributed_c10d holds its own binding of the default
    td = timedelta(seconds=int(nccl_timeout))
    torch.distributed.constants.default_pg_nccl_timeout = td
    torch.distributed.distributed_c10d.default_pg_nccl_timeout = td
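
The reason both modules must be patched can be illustrated without PyTorch. The following is a minimal sketch that uses two stand-in module objects; the names constants, c10d, and default_pg_timeout are hypothetical and only mimic the import-time binding described above, not PyTorch's actual internals.

```python
import types
from datetime import timedelta

# Stand-in for a module that defines the default timeout.
constants = types.ModuleType("constants")
constants.default_pg_timeout = timedelta(minutes=30)

# Stand-in for a module that did `from constants import default_pg_timeout`.
# The name is bound once, at import time, to the object that existed then.
c10d = types.ModuleType("c10d")
c10d.default_pg_timeout = constants.default_pg_timeout

new_timeout = timedelta(seconds=60)

# Patching only the defining module leaves the importer's binding untouched.
constants.default_pg_timeout = new_timeout
stale = c10d.default_pg_timeout  # still the original 30-minute value

# Patching the importing module as well makes both agree.
c10d.default_pg_timeout = new_timeout
```

After the first assignment, code that reads the value through c10d would still see 30 minutes, which is why a single patch is not enough.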

Phase 2: Heartbeat Timeout

If a collective operation times out, the NCCL watchdog aborts the communicator. However, the worker process might still remain alive but stuck. The nccl_heartbeat_timeout_sec (mapped to TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC) defines how long the NCCL monitoring thread waits after a stall before sending a SIGABRT to kill the worker process.

flyte-sdk defaults this to 300 seconds (5 minutes), which is much more aggressive than the PyTorch default of 1800 seconds (30 minutes). The heartbeat timeout only takes effect when nccl_enable_monitoring=True (which sets TORCH_NCCL_ENABLE_MONITORING=1).

Accelerating Detection with Async Error Handling

Setting nccl_async_error_handling=True enables TORCH_NCCL_ASYNC_ERROR_HANDLING=1. This causes NCCL to abort stuck collectives asynchronously rather than blocking. When a failure is detected, the worker process crash-exits immediately. The elastic agent then detects this exit within the monitor_interval (default 3 seconds), bypassing the need to wait for the full heartbeat timeout.
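
The parameter-to-environment-variable mapping described in this section can be summarized in a small helper. This is a sketch only: build_nccl_env is a hypothetical function, not part of flyte-sdk, and both the "0" encoding for disabled flags and the False default for async_error_handling are assumptions. The variable names and the 300-second heartbeat default are the ones given in the text.

```python
from typing import Dict, Optional


def build_nccl_env(
    collective_timeout_sec: Optional[int] = None,
    heartbeat_timeout_sec: int = 300,
    enable_monitoring: bool = True,
    async_error_handling: bool = False,  # assumed default, for illustration
) -> Dict[str, str]:
    """Hypothetical helper: map Elastic-style settings to environment variables."""
    env = {
        "TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC": str(heartbeat_timeout_sec),
        # Assumption: a disabled flag is written as "0" rather than left unset.
        "TORCH_NCCL_ENABLE_MONITORING": "1" if enable_monitoring else "0",
        "TORCH_NCCL_ASYNC_ERROR_HANDLING": "1" if async_error_handling else "0",
    }
    # The collective timeout is only exported when explicitly configured.
    if collective_timeout_sec is not None:
        env["FLYTE_NCCL_COLLECTIVE_TIMEOUT_SEC"] = str(collective_timeout_sec)
    return env


env = build_nccl_env(
    collective_timeout_sec=60,
    heartbeat_timeout_sec=60,
    async_error_handling=True,
)
```

Printing env is a quick way to check which variables a given combination of settings would export to the worker processes.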

The Zombie Watchdog

A critical challenge in distributed PyTorch is a known deadlock in the elastic agent: when all workers die from SIGABRT simultaneously (common during NCCL timeouts), the agent can deadlock while trying to acquire a shared semaphore that the dead workers will never release.

To protect against this, flyteplugins.pytorch.task.TorchFunctionTask launches a _start_zombie_watchdog thread. This watchdog periodically inspects /proc to count zombie child processes. If it detects that all worker processes have become zombies, it concludes the agent is deadlocked and force-exits the entire task using os._exit(1).

# Logic from _start_zombie_watchdog in flyteplugins/pytorch/task.py
if len(zombie_pids) >= nproc:
    logger.error("Zombie watchdog: %d worker processes are zombies... Force-exiting.", len(zombie_pids))
    os._exit(1)
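
How zombie_pids is collected is not shown in the excerpt. The following is a minimal sketch of one way to count zombie children by reading /proc/<pid>/stat on Linux; parse_stat and zombie_children are hypothetical names, and the actual flyte-sdk implementation may differ.

```python
import os
from typing import List, Tuple


def parse_stat(stat_text: str) -> Tuple[int, str, int]:
    """Return (pid, state, ppid) from the contents of /proc/<pid>/stat."""
    # The command name is wrapped in parentheses and may itself contain
    # spaces or ')', so split on the LAST ')' to find the remaining fields.
    left, right = stat_text.rsplit(")", 1)
    pid = int(left.split("(", 1)[0])
    fields = right.split()
    return pid, fields[0], int(fields[1])


def zombie_children(parent_pid: int, proc_root: str = "/proc") -> List[int]:
    """List PIDs of direct children of parent_pid that are in state 'Z'."""
    zombies = []
    for entry in os.listdir(proc_root):
        if not entry.isdigit():
            continue  # skip non-process entries such as /proc/meminfo
        try:
            with open(os.path.join(proc_root, entry, "stat")) as f:
                pid, state, ppid = parse_stat(f.read())
        except (OSError, ValueError, IndexError):
            continue  # process exited mid-scan or the line was unparsable
        if ppid == parent_pid and state == "Z":
            zombies.append(pid)
    return zombies
```

A watchdog thread could call zombie_children(os.getpid()) on a fixed interval and compare the length of the result against nproc, as in the excerpt above.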

Example: Aggressive Failure Recovery

For jobs where rapid recovery is more important than riding out transient network issues, you can use aggressive settings. With the configuration below, a CUDA OOM or rank desync results in a restart or failure within approximately 2 minutes (roughly the 60-second collective timeout plus the 60-second heartbeat timeout), rather than 30+ minutes.

from flyteplugins.pytorch.task import Elastic
import flyte

# Example of aggressive settings to minimize hang time
torch_env = flyte.TaskEnvironment(
    name="fast_failure_env",
    plugin_config=Elastic(
        nproc_per_node=2,
        nnodes=1,
        max_restarts=0,  # Fail immediately on deterministic errors like OOM
        nccl_heartbeat_timeout_sec=60,
        nccl_async_error_handling=True,
        nccl_collective_timeout_sec=60,
    ),
)

When a failure occurs, flyte-sdk provides specific hints in the error logs based on the exit signal:

  • SIGABRT (-6): Usually indicates an NCCL collective timeout or rank desync (often caused by CUDA OOM).
  • SIGKILL (-9): Usually indicates the process was killed by the system OOM killer for exceeding resource limits.
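
The signal-to-hint mapping above can be sketched as a small lookup. This is an illustration only: hint_for_exit_code and HINTS are hypothetical names, the hint strings paraphrase the list above rather than reproduce flyte-sdk's log output, and the code assumes a Unix platform where a negative exit code means the process was terminated by that signal.

```python
import signal

# Hints paraphrased from the list above; the real log messages may differ.
HINTS = {
    signal.SIGABRT: "likely an NCCL collective timeout or rank desync (often caused by CUDA OOM)",
    signal.SIGKILL: "likely killed by the system OOM killer for exceeding resource limits",
}


def hint_for_exit_code(exit_code: int) -> str:
    """Translate a worker exit code into a human-readable hint."""
    if exit_code >= 0:
        return f"exit code {exit_code}: process exited without being signalled"
    try:
        sig = signal.Signals(-exit_code)
    except ValueError:
        return f"exit code {exit_code}: terminated by an unrecognized signal"
    hint = HINTS.get(sig, "no specific hint available")
    return f"{sig.name} ({exit_code}): {hint}"
```

For example, hint_for_exit_code(-6) returns the NCCL timeout hint, while a code with no entry in HINTS falls back to the generic message.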