
High-Performance Model Loading

When deploying large language models (LLMs) with serving engines such as vLLM or SGLang, the time spent downloading multi-gigabyte weight files from object storage to local disk can significantly delay container startup. The flyte-sdk provides a high-performance model loading utility centered on the SafeTensorsStreamer class. It streams weights directly from S3 or GCS into memory instead of first downloading complete files to local disk.

Streaming Weights with SafeTensorsStreamer

The SafeTensorsStreamer is designed to iterate over model weights and load them into PyTorch tensors as they are downloaded. This is particularly useful in custom model loaders where you want to populate a model's state dict without waiting for the entire model to be present on disk.

To use the streamer, initialize it with the remote and local paths and iterate over the get_tensors() generator:

from flyte.app.extras._model_loader.loader import SafeTensorsStreamer

# Initialize the streamer
streamer = SafeTensorsStreamer(
    remote_path="s3://my-bucket/models/llama-7b/",
    local_path="/srv/model",
)

# Stream tensors directly into memory
for name, tensor in streamer.get_tensors():
    # 'name' is the tensor key (e.g., 'model.layers.0.self_attn.q_proj.weight')
    # 'tensor' is a torch.Tensor ready for use
    print(f"Loaded {name} with shape {tensor.shape}")

Internally, SafeTensorsStreamer uses asyncio.Runner to wrap its asynchronous implementation, providing a clean synchronous generator interface for standard Python code.

Implementation Details

The efficiency of the streamer comes from its ability to perform parallel range requests and parse the SafeTensors format on the fly.

Header Parsing and Metadata

Before downloading weights, the streamer must know where each tensor starts and ends within the file. It performs two targeted range requests using obstore.get_range_async:

  1. It reads the first 8 bytes, a little-endian unsigned 64-bit integer, to determine the header_size.
  2. It reads the next header_size bytes to parse the JSON metadata.
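The two-request header read can be illustrated against an in-memory file. This sketch assumes only the published SafeTensors layout: an 8-byte little-endian header length, a JSON header, then the raw data section. `read_range`, `parse_safetensors_header`, and `build_example_file` are hypothetical; `read_range` stands in for a remote range request such as `obstore.get_range_async`, whose real signature is not reproduced here.

```python
import json
import struct
from typing import Any, Dict, Tuple


def read_range(blob: bytes, start: int, length: int) -> bytes:
    # Hypothetical stand-in for one remote range request.
    return blob[start:start + length]


def parse_safetensors_header(blob: bytes) -> Tuple[Dict[str, Any], int]:
    """Return (tensor entries, absolute offset where the data section begins)."""
    # Request 1: the first 8 bytes hold the header length as a little-endian u64.
    (header_size,) = struct.unpack("<Q", read_range(blob, 0, 8))
    # Request 2: the next header_size bytes are a UTF-8 JSON object.
    header = json.loads(read_range(blob, 8, header_size))
    # The optional "__metadata__" entry holds free-form strings, not a tensor.
    header.pop("__metadata__", None)
    return header, 8 + header_size


def build_example_file() -> bytes:
    # A tiny two-tensor file laid out per the SafeTensors format.
    data = struct.pack("<4f", 1.0, 2.0, 3.0, 4.0) + struct.pack("<2h", 7, -7)
    header = {
        "__metadata__": {"format": "pt"},
        "weight": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, 16]},
        "bias": {"dtype": "I16", "shape": [2], "data_offsets": [16, 20]},
    }
    header_bytes = json.dumps(header).encode("utf-8")
    return struct.pack("<Q", len(header_bytes)) + header_bytes + data


blob = build_example_file()
tensors, data_start = parse_safetensors_header(blob)
begin, end = tensors["weight"]["data_offsets"]
# data_offsets are relative to the start of the data section, not the file.
print(struct.unpack("<4f", blob[data_start + begin:data_start + end]))  # (1.0, 2.0, 3.0, 4.0)
```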

This metadata is encapsulated in SafeTensorsMetadata and TensorMetadata classes (defined in flyte.app.extras._model_loader.loader). These classes store the shape, dtype, and data_offsets for every tensor in the file.
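A simplified stand-in for the per-tensor metadata might look like the following. `TensorMeta` and `DTYPE_SIZES` are hypothetical; the real TensorMetadata maps dtypes to torch types and may be structured differently. Here `__len__` returns the element count, which is the kind of value a `count=` argument to `torch.frombuffer` expects, and `validate` checks that the byte span in the header matches the shape and dtype.

```python
from dataclasses import dataclass
from math import prod
from typing import Dict, Tuple

# Byte width of a few SafeTensors dtype tags (an illustrative subset).
DTYPE_SIZES: Dict[str, int] = {
    "F32": 4, "F16": 2, "BF16": 2, "I64": 8, "I32": 4, "I16": 2, "U8": 1,
}


@dataclass(frozen=True)
class TensorMeta:
    """Hypothetical stand-in for the loader's TensorMetadata."""

    dtype: str
    shape: Tuple[int, ...]
    data_offsets: Tuple[int, int]  # [begin, end) relative to the data section

    @property
    def nbytes(self) -> int:
        begin, end = self.data_offsets
        return end - begin

    def __len__(self) -> int:
        # Number of elements: the product of the shape (1 for a scalar).
        return prod(self.shape)

    def validate(self) -> None:
        # The byte span must hold exactly len(self) elements of this dtype.
        expected = len(self) * DTYPE_SIZES[self.dtype]
        if expected != self.nbytes:
            raise ValueError(f"expected {expected} bytes, header says {self.nbytes}")

    @classmethod
    def from_header_entry(cls, entry: dict) -> "TensorMeta":
        return cls(entry["dtype"], tuple(entry["shape"]), tuple(entry["data_offsets"]))


meta = TensorMeta.from_header_entry({"dtype": "F32", "shape": [2, 2], "data_offsets": [0, 16]})
meta.validate()
print(len(meta), meta.nbytes)  # 4 16
```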

Parallel Chunked Downloads

The streamer uses ObstoreParallelReader (from flyte.storage._parallel_reader) to manage the download process. It splits each tensor's byte range into chunks of at most CHUNK_SIZE bytes and fetches them concurrently using a pool of workers.
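The chunking and bounded concurrency can be sketched with asyncio. This is an illustration under stated assumptions, not ObstoreParallelReader: `split_range`, `fetch_range`, and `download_parallel` are hypothetical, `fetch_range` stands in for a remote range request, and concurrency is capped with an `asyncio.Semaphore` rather than a fixed worker pool. Both approaches bound the number of in-flight requests.

```python
import asyncio
from typing import List, Tuple


def split_range(start: int, end: int, chunk_size: int) -> List[Tuple[int, int]]:
    # Break [start, end) into consecutive pieces of at most chunk_size bytes.
    return [(lo, min(lo + chunk_size, end)) for lo in range(start, end, chunk_size)]


async def fetch_range(blob: bytes, lo: int, hi: int) -> bytes:
    # Hypothetical stand-in for one remote range request.
    await asyncio.sleep(0)
    return blob[lo:hi]


async def download_parallel(
    blob: bytes, start: int, end: int, chunk_size: int, max_concurrency: int
) -> bytearray:
    buf = bytearray(end - start)              # preallocated destination buffer
    sem = asyncio.Semaphore(max_concurrency)  # caps in-flight requests

    async def fetch_into_buffer(lo: int, hi: int) -> None:
        async with sem:
            data = await fetch_range(blob, lo, hi)
        # Chunks can finish in any order; each writes only its own slice.
        buf[lo - start:hi - start] = data

    await asyncio.gather(
        *(fetch_into_buffer(lo, hi) for lo, hi in split_range(start, end, chunk_size))
    )
    return buf


source = bytes(range(256)) * 4
result = asyncio.run(download_parallel(source, 100, 900, chunk_size=64, max_concurrency=4))
print(len(result))  # 800
```

Preallocating the destination and writing each chunk at its own offset means no reassembly or sorting step is needed once the last chunk arrives.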

Zero-Copy Loading

Once all chunks for a specific tensor are downloaded into a memory buffer, the streamer uses torch.frombuffer to create a tensor without copying the underlying data:

# Internal logic in SafeTensorsStreamer._get_tensors_async
return torch.frombuffer(
    await buf.read(),
    dtype=source.metadata.dtype,
    count=len(source.metadata),
    offset=0,
).view(source.metadata.shape)
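The zero-copy idea can be demonstrated with only the standard library, using `memoryview.cast` as an analogy for `torch.frombuffer(...).view(shape)`. This illustrates the concept and is not the streamer's code. Like `torch.frombuffer`, it interprets bytes in the host's native byte order, so the example assumes a little-endian host, which is the common case and matches the byte order SafeTensors uses.

```python
import struct

# Four float32 values packed little-endian, as they would sit in a download buffer.
buf = bytearray(struct.pack("<4f", 1.0, 2.0, 3.0, 4.0))

# Reinterpret the bytes as float32 and reshape to 2x2 without copying,
# analogous to torch.frombuffer(...).view(shape).
view = memoryview(buf).cast("f", (2, 2))
print(view.tolist())  # [[1.0, 2.0], [3.0, 4.0]]

# Overwriting the first 4 bytes of the buffer is visible through the view,
# because both refer to the same memory.
buf[0:4] = struct.pack("<f", 9.0)
print(view[0, 0])  # 9.0
```

Because the view aliases the buffer, later writes to the buffer change the view's contents; `torch.frombuffer` documents the same shared-memory behavior between the tensor and its source buffer.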

Handling Sharded Models and Tensor Parallelism

Large models are often sharded across multiple files. SafeTensorsStreamer supports two methods for locating these shards:

  1. Index-based: If a model.safetensors.index.json file exists in the remote path, the streamer parses it to map tensor names to specific files.
  2. Pattern-based: If no index is found, or if tensor_parallel_size > 1 is specified, it falls back to pattern matching.
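Index-based resolution amounts to inverting a name-to-file map. The sketch below assumes the Hugging Face-style index layout, where a top-level `weight_map` object maps each tensor name to a shard filename; the source does not spell out the exact schema the streamer expects, and `group_by_shard` is a hypothetical helper. Grouping by file lets a loader open each shard once and read only the tensors it needs.

```python
import json
from collections import defaultdict
from typing import Dict, List


def group_by_shard(index_json: str) -> Dict[str, List[str]]:
    """Map each shard file to the tensor names it contains."""
    weight_map: Dict[str, str] = json.loads(index_json)["weight_map"]
    shards: Dict[str, List[str]] = defaultdict(list)
    for tensor_name, filename in weight_map.items():
        shards[filename].append(tensor_name)
    return dict(shards)


example_index = json.dumps({
    "metadata": {"total_size": 1024},
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    },
})
for shard, names in group_by_shard(example_index).items():
    print(shard, names)
```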

When using tensor parallelism, you can specify the rank and tensor_parallel_size to load only the shards relevant to a specific GPU:

streamer = SafeTensorsStreamer(
    remote_path=REMOTE_PATH,
    local_path=LOCAL_PATH,
    rank=0,
    tensor_parallel_size=2,
)

In this mode, the streamer looks for files matching the pattern model-rank-{rank}-part-*.safetensors.
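The rank filter can be expressed with shell-style matching. `shards_for_rank` is a hypothetical helper, and the source does not say whether the streamer uses glob, regex, or another mechanism; the sketch simply applies the documented `model-rank-{rank}-part-*.safetensors` pattern to a file listing. The literal `-part-` after the rank keeps rank 1 from matching rank 10's files.

```python
from fnmatch import fnmatchcase
from typing import Iterable, List


def shards_for_rank(filenames: Iterable[str], rank: int) -> List[str]:
    # Keep only this rank's shard files; sort lexicographically for a stable order.
    pattern = f"model-rank-{rank}-part-*.safetensors"
    return sorted(name for name in filenames if fnmatchcase(name, pattern))


files = [
    "config.json",
    "model-rank-0-part-0.safetensors",
    "model-rank-0-part-1.safetensors",
    "model-rank-1-part-0.safetensors",
    "model-rank-10-part-0.safetensors",
]
print(shards_for_rank(files, 1))  # ['model-rank-1-part-0.safetensors']
```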

Prefetching and Configuration

While weights are streamed, other artifacts like config.json or tokenizer.json are still needed on disk by most model loaders. The prefetch function in flyte.app.extras._model_loader.loader handles downloading these non-weight files while optionally excluding the large .safetensors files.
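The filtering behind prefetching can be sketched with a local directory standing in for object storage. `prefetch_local` and its `exclude_safetensors` flag are hypothetical and do not reproduce the real `prefetch` signature; the point is only that small artifacts are copied to disk while `.safetensors` weights are skipped so they can be streamed later.

```python
import shutil
import tempfile
from pathlib import Path
from typing import List


def prefetch_local(src: Path, dst: Path, exclude_safetensors: bool = True) -> List[str]:
    """Copy every file under src into dst, optionally skipping *.safetensors weights."""
    copied: List[str] = []
    for path in sorted(src.rglob("*")):
        if not path.is_file():
            continue
        if exclude_safetensors and path.suffix == ".safetensors":
            continue  # weights are left for the streamer
        rel = path.relative_to(src)
        target = dst / rel
        target.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(path, target)
        copied.append(rel.as_posix())
    return copied


with tempfile.TemporaryDirectory() as tmp:
    src, dst = Path(tmp, "remote"), Path(tmp, "local")
    src.mkdir()
    for name in ("config.json", "tokenizer.json", "model-00001-of-00001.safetensors"):
        (src / name).write_text("{}")
    copied = prefetch_local(src, dst)
    present = sorted(p.name for p in dst.iterdir())

print(copied)  # ['config.json', 'tokenizer.json']
```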

The behavior of the model loader is controlled by several environment variables:

| Variable | Default | Description |
|---|---|---|
| FLYTE_MODEL_LOADER_REMOTE_MODEL_PATH | None | The S3/GCS path to the model. |
| FLYTE_MODEL_LOADER_LOCAL_MODEL_PATH | /srv/model | Local directory for cached artifacts. |
| FLYTE_MODEL_LOADER_CHUNK_SIZE | 16777216 (16 MiB) | Size of each parallel download chunk, in bytes. |
| FLYTE_MODEL_LOADER_MAX_CONCURRENCY | 32 | Maximum number of concurrent range requests. |
| FLYTE_MODEL_LOADER_STREAM_SAFETENSORS | false | Whether to enable streaming for weights. |
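A loader could read these settings once at startup. This sketch is hypothetical: `LoaderSettings`, `load_settings`, and the set of truthy strings accepted by `_as_bool` are assumptions rather than the SDK's own parsing rules. Only the variable names and defaults come from the table.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class LoaderSettings:
    remote_model_path: Optional[str]
    local_model_path: str
    chunk_size: int
    max_concurrency: int
    stream_safetensors: bool


def _as_bool(value: str) -> bool:
    # Assumed parsing rule: common truthy spellings enable the flag.
    return value.strip().lower() in {"1", "true", "yes", "on"}


def load_settings(env: Mapping[str, str] = os.environ) -> LoaderSettings:
    # Defaults mirror the documented values for each variable.
    return LoaderSettings(
        remote_model_path=env.get("FLYTE_MODEL_LOADER_REMOTE_MODEL_PATH"),
        local_model_path=env.get("FLYTE_MODEL_LOADER_LOCAL_MODEL_PATH", "/srv/model"),
        chunk_size=int(env.get("FLYTE_MODEL_LOADER_CHUNK_SIZE", "16777216")),
        max_concurrency=int(env.get("FLYTE_MODEL_LOADER_MAX_CONCURRENCY", "32")),
        stream_safetensors=_as_bool(
            env.get("FLYTE_MODEL_LOADER_STREAM_SAFETENSORS", "false")
        ),
    )


print(load_settings({"FLYTE_MODEL_LOADER_MAX_CONCURRENCY": "64"}).max_concurrency)  # 64
```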

Requirements and Integration

The high-performance loader requires torch and obstore to be installed. It is the primary mechanism used by the flyteplugins-vllm and flyteplugins-sglang packages to accelerate model loading in Flyte-managed inference services.