High-Performance Model Loading
When deploying large language models (LLMs) in serving environments such as vLLM or SGLang, downloading multi-gigabyte weight files from object storage to local disk can significantly delay container startup. The flyte-sdk provides a high-performance model loading utility centered on the SafeTensorsStreamer class, which streams weights directly from S3 or GCS into memory and avoids a full local download of the weight files.
Streaming Weights with SafeTensorsStreamer
The SafeTensorsStreamer is designed to iterate over model weights and load them into PyTorch tensors as they are downloaded. This is particularly useful in custom model loaders where you want to populate a model's state dict without waiting for the entire model to be present on disk.
To use the streamer, initialize it with the remote and local paths and iterate over the get_tensors() generator:
from flyte.app.extras._model_loader.loader import SafeTensorsStreamer

# Initialize the streamer
streamer = SafeTensorsStreamer(
    remote_path="s3://my-bucket/models/llama-7b/",
    local_path="/srv/model",
)

# Stream tensors directly into memory
for name, tensor in streamer.get_tensors():
    # 'name' is the tensor key (e.g., 'model.layers.0.self_attn.q_proj.weight')
    # 'tensor' is a torch.Tensor ready for use
    print(f"Loaded {name} with shape {tensor.shape}")
Internally, SafeTensorsStreamer uses asyncio.Runner to wrap its asynchronous implementation, providing a clean synchronous generator interface for standard Python code.
Implementation Details
The efficiency of the streamer comes from its ability to perform parallel range requests and parse the SafeTensors format on the fly.
Header Parsing and Metadata
Before downloading weights, the streamer must know where each tensor starts and ends within the file. It performs two targeted range requests using obstore.get_range_async:
- It reads the first 8 bytes to determine the header_size.
- It reads the next header_size bytes to parse the JSON metadata.
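The two reads can be sketched against an in-memory file. This is an illustrative sketch only: `build_file`, `read_range`, and `parse_header` are hypothetical helpers, and `read_range` stands in for the remote range request. It relies on the SafeTensors layout, where the 8-byte prefix is a little-endian unsigned 64-bit integer and each tensor's `data_offsets` are relative to the end of the header.

```python
import json
import struct


def build_file(header: dict, payload: bytes) -> bytes:
    """Assemble a tiny SafeTensors-style file in memory (for the demo only)."""
    header_bytes = json.dumps(header).encode("utf-8")
    return struct.pack("<Q", len(header_bytes)) + header_bytes + payload


def read_range(blob: bytes, start: int, length: int) -> bytes:
    """Hypothetical stand-in for a remote range request."""
    return blob[start:start + length]


def parse_header(blob: bytes) -> tuple[dict, int]:
    # Request 1: the first 8 bytes hold the header size (little-endian u64).
    (header_size,) = struct.unpack("<Q", read_range(blob, 0, 8))
    # Request 2: the next header_size bytes hold the JSON metadata.
    header = json.loads(read_range(blob, 8, header_size))
    header.pop("__metadata__", None)  # optional free-form entry, not a tensor
    data_start = 8 + header_size  # data_offsets are relative to this position
    return header, data_start


# Demo: one float32 tensor "w" with two elements.
demo_header = {"w": {"dtype": "F32", "shape": [2], "data_offsets": [0, 8]}}
demo_blob = build_file(demo_header, struct.pack("<2f", 1.0, 2.0))
tensors, data_start = parse_header(demo_blob)
begin, end = tensors["w"]["data_offsets"]
values = struct.unpack("<2f", demo_blob[data_start + begin:data_start + end])
```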
This metadata is encapsulated in SafeTensorsMetadata and TensorMetadata classes (defined in flyte.app.extras._model_loader.loader). These classes store the shape, dtype, and data_offsets for every tensor in the file.
Parallel Chunked Downloads
The streamer utilizes ObstoreParallelReader (from flyte.storage._parallel_reader) to manage the download process. It breaks down each tensor into smaller chunks (defined by CHUNK_SIZE) and fetches them concurrently using a pool of workers.
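The chunk-and-fetch idea can be illustrated with plain asyncio. The sketch below does not use ObstoreParallelReader and makes no claim about its internals: `BLOB`, `fetch_range`, `split_range`, and `read_tensor_bytes` are hypothetical, and a semaphore stands in for the worker pool.

```python
import asyncio

BLOB = bytes(range(256)) * 4  # hypothetical stand-in for a remote object (1024 bytes)


async def fetch_range(start: int, end: int) -> bytes:
    """Hypothetical range request; a real reader would call the object store."""
    await asyncio.sleep(0)
    return BLOB[start:end]


def split_range(start: int, end: int, chunk_size: int) -> list[tuple[int, int]]:
    """Split [start, end) into consecutive chunks of at most chunk_size bytes."""
    return [(s, min(s + chunk_size, end)) for s in range(start, end, chunk_size)]


async def read_tensor_bytes(
    start: int, end: int, chunk_size: int = 100, max_concurrency: int = 4
) -> bytes:
    limit = asyncio.Semaphore(max_concurrency)  # caps in-flight requests
    buf = bytearray(end - start)  # preallocated destination buffer

    async def worker(s: int, e: int) -> None:
        async with limit:
            data = await fetch_range(s, e)
        # Each chunk lands at its own offset, so completion order does not matter.
        buf[s - start:e - start] = data

    await asyncio.gather(*(worker(s, e) for s, e in split_range(start, end, chunk_size)))
    return bytes(buf)


if __name__ == "__main__":
    data = asyncio.run(read_tensor_bytes(10, 700))
    print(len(data))
```

Writing each chunk into a preallocated buffer at its own offset means the chunks never need to be reordered or concatenated afterwards.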
Zero-Copy Loading
Once all chunks for a specific tensor are downloaded into a memory buffer, the streamer uses torch.frombuffer to create a tensor without copying the underlying data:
# Internal logic in SafeTensorsStreamer._get_tensors_async
return torch.frombuffer(
    await buf.read(),
    dtype=source.metadata.dtype,
    count=len(source.metadata),
    offset=0,
).view(source.metadata.shape)
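What "without copying" means can be shown with a standard-library analogy. The sketch below uses `memoryview.cast` rather than torch, so it is not the streamer's code. It only demonstrates that a typed, shaped view can share memory with the raw byte buffer, which is the same sharing behavior `torch.frombuffer` provides.

```python
import struct

# 16 raw bytes holding four native-endian float32 values, as if just downloaded.
buf = bytearray(struct.pack("4f", 1.0, 2.0, 3.0, 4.0))

# Reinterpret the bytes as a 2x2 grid of float32 values. No bytes are copied.
grid = memoryview(buf).cast("f", shape=[2, 2])
before = grid.tolist()

# Overwrite the first float in the underlying buffer.
struct.pack_into("f", buf, 0, 9.0)
after = grid.tolist()  # the view reflects the change because memory is shared
```

One consequence of sharing memory is that the buffer must stay alive and unmodified for as long as the view (or tensor) is in use.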
Handling Sharded Models and Tensor Parallelism
Large models are often sharded across multiple files. SafeTensorsStreamer supports two methods for locating these shards:
- Index-based: If a model.safetensors.index.json file exists in the remote path, the streamer parses it to map tensor names to specific files.
- Pattern-based: If no index is found, or if tensor_parallel_size > 1 is specified, it falls back to pattern matching.
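For the index-based case, the lookup amounts to reading the `weight_map` entry of the index file and grouping tensor names by shard. The sketch below assumes the index layout used by Hugging Face checkpoints. `INDEX_JSON` is made-up sample data and `tensors_by_file` is a hypothetical helper, not part of the streamer's API.

```python
import json
from collections import defaultdict

# Made-up sample of a model.safetensors.index.json file.
INDEX_JSON = """
{
  "metadata": {"total_size": 1024},
  "weight_map": {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors"
  }
}
"""


def tensors_by_file(index_text: str) -> dict[str, list[str]]:
    """Group tensor names by the shard file that contains them."""
    weight_map = json.loads(index_text)["weight_map"]
    grouped: dict[str, list[str]] = defaultdict(list)
    for tensor_name, filename in weight_map.items():
        grouped[filename].append(tensor_name)
    return dict(grouped)


shards = tensors_by_file(INDEX_JSON)
```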
When using tensor parallelism, you can specify the rank and tensor_parallel_size to load only the shards relevant to a specific GPU:
streamer = SafeTensorsStreamer(
    remote_path=REMOTE_PATH,
    local_path=LOCAL_PATH,
    rank=0,
    tensor_parallel_size=2,
)
In this mode, the streamer looks for files matching the pattern model-rank-{rank}-part-*.safetensors.
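How such a pattern selects files for one rank can be sketched with glob-style matching. `shards_for_rank` and the file listing are hypothetical; how the streamer itself lists and matches remote objects is not shown here.

```python
from fnmatch import fnmatchcase


def shards_for_rank(filenames: list[str], rank: int) -> list[str]:
    """Return the shard files whose names match the given rank's pattern."""
    pattern = f"model-rank-{rank}-part-*.safetensors"
    return sorted(name for name in filenames if fnmatchcase(name, pattern))


# Made-up listing of a remote model directory.
listing = [
    "config.json",
    "model-rank-0-part-0.safetensors",
    "model-rank-0-part-1.safetensors",
    "model-rank-1-part-0.safetensors",
    "model-rank-10-part-0.safetensors",
]
rank0 = shards_for_rank(listing, 0)
rank1 = shards_for_rank(listing, 1)
```

Because the pattern requires `-part-` immediately after the rank, rank 1 does not accidentally pick up the files for rank 10.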
Prefetching and Configuration
While weights are streamed, other artifacts like config.json or tokenizer.json are still needed on disk by most model loaders. The prefetch function in flyte.app.extras._model_loader.loader handles downloading these non-weight files while optionally excluding the large .safetensors files.
The behavior of the model loader is controlled by several environment variables:
| Variable | Default | Description |
|---|---|---|
| FLYTE_MODEL_LOADER_REMOTE_MODEL_PATH | None | The S3/GCS path to the model. |
| FLYTE_MODEL_LOADER_LOCAL_MODEL_PATH | /srv/model | Local directory for cached artifacts. |
| FLYTE_MODEL_LOADER_CHUNK_SIZE | 16777216 (16 MiB) | Size of each parallel download chunk. |
| FLYTE_MODEL_LOADER_MAX_CONCURRENCY | 32 | Maximum concurrent range requests. |
| FLYTE_MODEL_LOADER_STREAM_SAFETENSORS | false | Whether to enable streaming for weights. |
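A sketch of how these variables and defaults could be read into a typed configuration follows. `LoaderConfig` and `load_config` are hypothetical names. The set of strings accepted as "true" is an assumption, since the loader's actual boolean parsing is not documented here.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class LoaderConfig:
    remote_model_path: Optional[str]
    local_model_path: str
    chunk_size: int
    max_concurrency: int
    stream_safetensors: bool


def load_config(env: Mapping[str, str] = os.environ) -> LoaderConfig:
    """Read loader settings from the environment, using the documented defaults."""
    prefix = "FLYTE_MODEL_LOADER_"
    # Assumption: these spellings count as true; the real loader may differ.
    stream = env.get(prefix + "STREAM_SAFETENSORS", "false").strip().lower()
    return LoaderConfig(
        remote_model_path=env.get(prefix + "REMOTE_MODEL_PATH"),
        local_model_path=env.get(prefix + "LOCAL_MODEL_PATH", "/srv/model"),
        chunk_size=int(env.get(prefix + "CHUNK_SIZE", 16 * 1024 * 1024)),
        max_concurrency=int(env.get(prefix + "MAX_CONCURRENCY", 32)),
        stream_safetensors=stream in ("1", "true", "yes"),
    )


defaults = load_config({})
custom = load_config({
    "FLYTE_MODEL_LOADER_REMOTE_MODEL_PATH": "s3://my-bucket/models/llama-7b/",
    "FLYTE_MODEL_LOADER_CHUNK_SIZE": "1048576",
    "FLYTE_MODEL_LOADER_STREAM_SAFETENSORS": "true",
})
```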
Requirements and Integration
The high-performance loader requires torch and obstore to be installed. It is the primary mechanism used by the flyteplugins-vllm and flyteplugins-sglang packages to accelerate model loading in Flyte-managed inference services.