Managing Large Language Models with Prefetch
The flyte-sdk prefetch system allows you to download, shard, and store Large Language Models (LLMs) from the HuggingFace Hub directly into your remote storage. This process reduces application startup time by avoiding repeated downloads and enables high-performance inference by pre-sharding models for multi-GPU deployments.
Prefetching a Model
To prefetch a model, use the hf_model function. This function triggers a Flyte task that handles the download and storage of the model artifacts.
```python
import flyte.prefetch

# Initialize the Flyte context
flyte.init()

# Prefetch a model to remote storage
run = flyte.prefetch.hf_model(
    repo="meta-llama/Llama-2-7b-hf",
    hf_token_key="HF_TOKEN",  # Name of the Flyte secret containing your HF token
)

# Wait for the prefetch task to complete
run.wait()
print(f"Model stored at: {run.outputs()[0].path}")
```
The hf_model function returns a flyte.remote.Run object, which you can use to track the progress of the prefetch task.
Configuring Model Sharding with vLLM
For large models that require multiple GPUs, you can configure the prefetch system to shard the model using the vLLM engine. This is done by providing a ShardConfig with VLLMShardArgs.
```python
import flyte
from flyte.prefetch import ShardConfig, VLLMShardArgs

# Configure sharding for a 70B model across 8 GPUs
shard_config = ShardConfig(
    engine="vllm",
    args=VLLMShardArgs(
        tensor_parallel_size=8,
        dtype="auto",
        trust_remote_code=True,
    ),
)

# Run prefetch with sharding and appropriate GPU resources
run = flyte.prefetch.hf_model(
    repo="meta-llama/Llama-2-70b-hf",
    shard_config=shard_config,
    resources=flyte.Resources(
        cpu="32",
        memory="256Gi",
        gpu="A100:8",
        disk="500Gi",
    ),
)
run.wait()
```
Key Sharding Parameters
The VLLMShardArgs class (defined in src/flyte/prefetch/_hf_model.py) supports several parameters to control the sharding process:
- tensor_parallel_size: The number of GPUs to shard the model across.
- dtype: Data type for model weights (e.g., "auto", "float16", "bfloat16").
- max_model_len: Maximum model context length to use during sharding.
- max_file_size: Maximum size of each sharded file (default: 5GB).
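When choosing tensor_parallel_size, it can help to estimate how many bytes of weights each GPU will hold and how max_file_size will split them. The arithmetic below is a back-of-the-envelope sketch, not part of flyte-sdk or vLLM: the helpers weight_bytes_per_gpu and min_shard_files are hypothetical, KV cache and activation memory are ignored, and it assumes the 5GB default means 5 GiB applied to each rank's output.

```python
import math

# Hypothetical sizing helpers: not part of flyte-sdk or vLLM.
# Assumption: weights dominate; KV cache, activations and runtime
# overhead are deliberately ignored.
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def weight_bytes_per_gpu(num_params: float, dtype: str, tensor_parallel_size: int) -> float:
    """Approximate weight bytes held by each GPU after tensor-parallel sharding."""
    if tensor_parallel_size < 1:
        raise ValueError("tensor_parallel_size must be at least 1")
    # dtype="auto" is resolved from the model config at load time,
    # so this sketch only handles explicit dtypes.
    total_bytes = num_params * BYTES_PER_PARAM[dtype]
    return total_bytes / tensor_parallel_size


def min_shard_files(bytes_per_rank: float, max_file_size: int) -> int:
    """Lower bound on files per rank if no file may exceed max_file_size bytes."""
    return math.ceil(bytes_per_rank / max_file_size)


# A nominal 70B-parameter model in bfloat16 across 8 GPUs
per_gpu = weight_bytes_per_gpu(70e9, "bfloat16", tensor_parallel_size=8)
print(per_gpu / 1e9)  # 17.5 (GB of weights per GPU)
# Assumption: "5GB" means 5 GiB (5 * 1024**3 bytes)
print(min_shard_files(per_gpu, 5 * 1024**3))  # 4
```

The count is only a lower bound: real shard files break on tensor boundaries, so the engine may write more files than this estimate.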
Using Prefetched Models in Applications
Once a model is prefetched, you can use its output in a flyte.serve application by referencing the run name. This ensures your application uses the exact artifacts generated by the prefetch task.
```python
import flyte
import flyte.app
from flyteplugins.vllm import VLLMAppEnvironment

# Define your application environment
vllm_app = VLLMAppEnvironment(
    name="llama-service",
    model_hf_path="meta-llama/Llama-2-7b-hf",
    resources=flyte.Resources(gpu="L4:1"),
)

# Deploy the app using the output of the prefetch run.
# `run` is the flyte.remote.Run returned earlier by flyte.prefetch.hf_model(...).
app = flyte.serve(
    vllm_app.clone_with(
        # Link model_path to the output of the prefetch run
        model_path=flyte.app.RunOutput(type="directory", run_name=run.name),
        model_hf_path=None,
    )
)
```
Configuration and Secrets
HuggingFace Token
The prefetch task requires a HuggingFace API token to access private or gated repositories. This token must be stored as a Flyte Secret. By default, hf_model looks for a secret named HF_TOKEN. You can customize this using the hf_token_key parameter.
Resource Allocation
Prefetching large models is resource-intensive. Ensure the resources parameter in hf_model provides enough disk space and memory:
- Disk: Should be at least 2x the model size to accommodate the download and final storage.
- Memory: For sharding, memory should be sufficient to load the model weights.
- GPU: Required only if shard_config is provided. The gpu string should match the format {type}:{quantity} (e.g., A100:8).
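These rules of thumb can be turned into a quick pre-flight check before you submit a prefetch. The sketch below is hypothetical and not a flyte-sdk API: parse_gpu_spec and check_prefetch_resources are illustrative names, sizes are plain numbers in GiB, and the requirement that the GPU count be at least tensor_parallel_size is an assumption about how tensor parallelism maps shards onto GPUs.

```python
from typing import List, Optional, Tuple


def parse_gpu_spec(spec: str) -> Tuple[str, int]:
    """Split a '{type}:{quantity}' string such as 'A100:8' into its parts."""
    gpu_type, sep, quantity = spec.rpartition(":")
    if not sep or not gpu_type or not quantity.isdecimal() or int(quantity) < 1:
        raise ValueError(f"expected '{{type}}:{{quantity}}', got {spec!r}")
    return gpu_type, int(quantity)


def check_prefetch_resources(
    model_size_gib: float,
    disk_gib: float,
    gpu: Optional[str] = None,
    tensor_parallel_size: Optional[int] = None,
) -> List[str]:
    """Return human-readable problems; an empty list means the plan looks sane."""
    problems = []
    # Rule of thumb from the guidance above: disk >= 2x model size.
    if disk_gib < 2 * model_size_gib:
        problems.append(f"disk {disk_gib}Gi is below 2x the model size ({2 * model_size_gib}Gi)")
    # GPUs only matter when sharding (i.e. when a shard_config is provided).
    if tensor_parallel_size is not None:
        if gpu is None:
            problems.append("sharding requested but no GPUs allocated")
        else:
            _, count = parse_gpu_spec(gpu)
            # Assumption: tensor parallelism needs at least one GPU per shard.
            if count < tensor_parallel_size:
                problems.append(f"{count} GPUs < tensor_parallel_size={tensor_parallel_size}")
    return problems


# Hypothetical 130 GiB checkpoint with the 70B example's resources
print(check_prefetch_resources(130, 500, gpu="A100:8", tensor_parallel_size=8))  # []
print(check_prefetch_resources(130, 200))
```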
Troubleshooting
Artifact Name Restrictions
The artifact_name (if provided) must contain only alphanumeric characters, underscores, and hyphens. When you omit it, the default name is derived from the repo ID, and any dots and slashes are replaced automatically. A name you supply yourself is not sanitized: if it contains dots, slashes, or other disallowed characters, it causes a ValueError.
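To catch an invalid name before launching a run, you can apply the same character rule locally. This is a minimal sketch, not flyte-sdk's code: default_artifact_name and validate_artifact_name are hypothetical helpers, the character class is restricted to ASCII, and the replacement character "-" used for dots and slashes is an assumption (the real default may use a different one).

```python
import re

# Allowed characters, as described above: letters, digits, "_" and "-".
_VALID_ARTIFACT_NAME = re.compile(r"[A-Za-z0-9_-]+")


def default_artifact_name(repo: str) -> str:
    """Hypothetical default: replace dots and slashes in the repo ID.

    Assumption: the replacement character is "-"; flyte-sdk may use another.
    """
    return repo.replace("/", "-").replace(".", "-")


def validate_artifact_name(name: str) -> str:
    """Raise ValueError for names with characters outside [A-Za-z0-9_-]."""
    if not _VALID_ARTIFACT_NAME.fullmatch(name):
        raise ValueError(
            f"invalid artifact_name {name!r}: use only letters, digits, '_' and '-'"
        )
    return name


print(default_artifact_name("meta-llama/Llama-2-7b-hf"))  # meta-llama-Llama-2-7b-hf
```

Using fullmatch rather than a pattern anchored with "$" avoids accepting a name with a trailing newline.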
Streaming vs. Snapshot
The flyte-sdk prefetch system attempts to stream files directly from HuggingFace to remote storage to save local disk space. If streaming fails, it falls back to a full snapshot download. If sharding is enabled, it always performs a local download first to allow the vLLM engine to process the weights.
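The order of attempts described above can be summarized as control flow. The sketch below is illustrative only: prefetch and every callable passed to it (stream_to_remote, snapshot_download, upload, shard) are hypothetical stand-ins for the real steps, not flyte-sdk functions, and the broad exception handler is a simplification.

```python
import logging
from typing import Callable, Optional

log = logging.getLogger("prefetch_sketch")


def prefetch(
    stream_to_remote: Callable[[], str],
    snapshot_download: Callable[[], str],
    upload: Callable[[str], str],
    shard: Optional[Callable[[str], str]] = None,
) -> str:
    """Return the remote location of the model, following the order described above.

    Every callable is a hypothetical stand-in; none is a flyte-sdk function.
    """
    if shard is not None:
        # Sharding: always download locally so the engine can read the weights.
        return upload(shard(snapshot_download()))
    try:
        # Preferred path: stream files straight to remote storage.
        return stream_to_remote()
    except Exception as exc:  # a real implementation would likely catch narrower errors
        log.warning("streaming failed (%s); falling back to snapshot download", exc)
        return upload(snapshot_download())


def failing_stream() -> str:
    raise ConnectionError("stream interrupted")


# Streaming fails, so the snapshot path runs and its result is uploaded
print(prefetch(failing_stream, lambda: "/tmp/model", lambda p: f"s3://bucket{p}"))  # s3://bucket/tmp/model
```

Note that with sharding enabled the streaming path is never attempted, which is why sharded prefetches need local disk space for the full download.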
Missing HF_TOKEN
If the HF_TOKEN environment variable is not found during the task execution, the prefetch will fail with an assertion error. Ensure the secret is correctly configured in your Flyte deployment and the key name matches what is passed to hf_token_key.
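To confirm that the secret actually reaches a task, you could run a similar guard at the start of a small debugging task of your own. This is a hypothetical sketch: require_hf_token is not a flyte-sdk function, it raises RuntimeError rather than the AssertionError the prefetch task produces, and it assumes the secret is exposed to the task as an environment variable.

```python
import os
from typing import Mapping, Optional


def require_hf_token(env_var: str = "HF_TOKEN", environ: Optional[Mapping[str, str]] = None) -> str:
    """Return the token from the environment or fail with an actionable message."""
    # Read from os.environ unless a mapping is supplied (handy for testing).
    env = os.environ if environ is None else environ
    token = env.get(env_var, "").strip()
    if not token:
        raise RuntimeError(
            f"{env_var} is not set in this task; check that the Flyte secret exists "
            "and that its key matches hf_token_key"
        )
    return token
```

Passing environ explicitly keeps the check testable without touching the real process environment.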