Dynamic Batching for Inference

When running high-throughput inference workloads, such as LLM generation or image classification, processing requests one at a time often leaves GPUs underutilized and, under load, increases latency because every request pays the full per-call overhead. flyte-sdk provides DynamicBatcher and TokenBatcher to aggregate concurrent requests from multiple producers into larger batches, keeping the hardware saturated while respecting per-batch cost limits and time bounds.

Basic Batching with TokenBatcher

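For LLM workloads, TokenBatcher is the recommended entry point. It is a specialized version of DynamicBatcher that uses token-aware terminology (for example, target_batch_tokens and estimated_tokens in place of target_batch_cost and estimated_cost) and supports token estimation protocols.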

You can use the batcher as an asynchronous context manager to automatically handle the lifecycle of its internal processing loops.

```python
import asyncio
from flyte.extras import TokenBatcher, Prompt

async def my_inference_fn(batch: list[Prompt]) -> list[str]:
    # Simulate an LLM call that processes a whole batch at once
    return [f"Generated text for: {p.text}" for p in batch]

async def main():
    async with TokenBatcher(inference_fn=my_inference_fn) as batcher:
        # Submit a record and get a future back
        future = await batcher.submit(Prompt(text="Hello, Flyte!"))

        # Wait for the result
        result = await future
        print(result)

asyncio.run(main())
```

Persistent Workers and GPU Saturation

In a production Flyte environment, you typically want to share a single batcher instance across many concurrent task executions on the same worker. This is achieved by combining TokenBatcher with alru_cache and Flyte's ReusePolicy.

This pattern allows multiple Flyte tasks running on the same pod to "funnel" their individual requests into a single high-throughput stream.

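```python
import asyncio

import flyte
from async_lru import alru_cache
from flyte.extras import TokenBatcher, Prompt

# A reusable environment keeps the worker process (and its cached batcher) alive
# between task invocations. The replicas and concurrency values are illustrative.
env = flyte.TaskEnvironment(
    name="batched-inference",
    reusable=flyte.ReusePolicy(replicas=1, concurrency=16),
)

async def my_inference_fn(batch: list[Prompt]) -> list[str]:
    # Replace with a real batched model call
    return [f"Generated text for: {p.text}" for p in batch]

@alru_cache(maxsize=1)
async def get_batcher() -> TokenBatcher[Prompt, str]:
    # The batcher is initialized once per process and shared across task invocations
    batcher = TokenBatcher[Prompt, str](
        inference_fn=my_inference_fn,
        target_batch_tokens=32_000,
        max_batch_size=256,
        batch_timeout_s=0.05,
    )
    await batcher.start()
    return batcher

@env.task
async def infer_task(prompts: list[str]) -> list[str]:
    batcher = await get_batcher()
    futures = []
    for text in prompts:
        # submit() applies backpressure by awaiting while the queue is full
        future = await batcher.submit(Prompt(text=text))
        futures.append(future)

    # Gather results for this specific task invocation
    return await asyncio.gather(*futures)
```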
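The backpressure mentioned in the `submit()` comment can be illustrated with a bounded `asyncio.Queue`. This is a standalone toy, not flyte-sdk's code: `ToySubmitter` is a hypothetical class, and it assumes `submit()` enqueues a `(record, future)` pair and suspends the caller while the queue already holds `max_queue_size` items.

```python
import asyncio

class ToySubmitter:
    """Hypothetical stand-in showing bounded-queue backpressure."""

    def __init__(self, max_queue_size: int) -> None:
        self.queue: asyncio.Queue = asyncio.Queue(maxsize=max_queue_size)

    async def submit(self, record):
        future = asyncio.get_running_loop().create_future()
        # put() suspends the caller while the queue is full: this is the backpressure.
        await self.queue.put((record, future))
        return future

async def demo() -> tuple[bool, int]:
    sub = ToySubmitter(max_queue_size=2)
    await sub.submit("a")
    await sub.submit("b")
    # A third submit cannot finish until a consumer frees a slot.
    third = asyncio.create_task(sub.submit("c"))
    await asyncio.sleep(0.01)
    blocked = not third.done()
    sub.queue.get_nowait()  # a consumer drains one item...
    await third             # ...so the third submit can now complete
    return blocked, sub.queue.qsize()

print(asyncio.run(demo()))  # (True, 2)
```

The producer is slowed to the consumer's pace instead of growing an unbounded backlog in memory.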

How it Works Internally

The DynamicBatcher (and by extension TokenBatcher) manages two internal loops to decouple request submission from execution:

  1. Aggregation Loop: This loop drains the submission queue and assembles batches. It waits until target_batch_cost (or target_batch_tokens) is reached, or until batch_timeout_s expires. If the timeout expires and the batch is smaller than min_batch_size, it re-enqueues the records to wait for more data.
  2. Processing Loop: This loop pulls assembled batches from a prefetch buffer and executes the process_fn. It ensures that each record's asyncio.Future is resolved with the corresponding result from the returned list.
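The two loops can be sketched with plain `asyncio`. `ToyBatcher` below is hypothetical and deliberately simplified: it batches by record count and timeout only, and leaves out cost budgeting, `min_batch_size` re-enqueueing and error handling. A bounded queue stands in for the prefetch buffer between the two loops.

```python
import asyncio
from typing import Any, Awaitable, Callable

class ToyBatcher:
    """Illustrative two-loop batcher; not flyte-sdk's implementation."""

    def __init__(
        self,
        process_fn: Callable[[list[Any]], Awaitable[list[Any]]],
        max_batch_size: int = 4,
        batch_timeout_s: float = 0.05,
        prefetch_batches: int = 2,
    ) -> None:
        self.process_fn = process_fn
        self.max_batch_size = max_batch_size
        self.batch_timeout_s = batch_timeout_s
        self.queue: asyncio.Queue = asyncio.Queue()  # submission queue
        self.batches: asyncio.Queue = asyncio.Queue(maxsize=prefetch_batches)  # prefetch buffer
        self.batch_sizes: list[int] = []  # recorded for inspection

    async def submit(self, record: Any) -> asyncio.Future:
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((record, fut))
        return fut

    async def _aggregate(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self.queue.get()]  # wait for the first record
            deadline = loop.time() + self.batch_timeout_s
            while len(batch) < self.max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self.queue.get(), remaining))
                except asyncio.TimeoutError:
                    break  # timeout expired: dispatch whatever we have
            await self.batches.put(batch)  # blocks if the prefetch buffer is full

    async def _process(self) -> None:
        while True:
            batch = await self.batches.get()
            self.batch_sizes.append(len(batch))
            results = await self.process_fn([record for record, _ in batch])
            # Resolve each record's future with the result at the same position
            for (_, fut), result in zip(batch, results):
                fut.set_result(result)

    async def __aenter__(self) -> "ToyBatcher":
        self._tasks = [
            asyncio.create_task(self._aggregate()),
            asyncio.create_task(self._process()),
        ]
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        for task in self._tasks:
            task.cancel()
        await asyncio.gather(*self._tasks, return_exceptions=True)

async def double(batch: list[int]) -> list[int]:
    return [x * 2 for x in batch]

async def main() -> tuple[list[int], list[int]]:
    async with ToyBatcher(double, max_batch_size=4) as batcher:
        futures = [await batcher.submit(i) for i in range(10)]
        results = await asyncio.gather(*futures)
    return list(results), batcher.batch_sizes

results, sizes = asyncio.run(main())
print(results)  # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
print(sizes)    # [4, 4, 2]
```

Because all ten records are queued before the aggregator first runs, the first two batches fill to `max_batch_size` immediately and the last one is dispatched when the timeout expires.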

The Processing Contract

The process_fn (or inference_fn) you provide must return a list of results that matches the exact length and order of the input batch. If the lengths do not match, DynamicBatcher raises a ValueError and fails the entire batch.

If the process_fn raises an exception, that exception is propagated to every Future in that specific batch.
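The contract can be expressed as a small helper. `run_batch` is a hypothetical function, not part of flyte-sdk, and it assumes the batcher holds one `asyncio.Future` per record. It applies the two rules above: a length mismatch becomes a `ValueError`, and any exception is set on every future in the batch.

```python
import asyncio

async def run_batch(process_fn, records, futures):
    """Resolve one future per record according to the processing contract (illustrative)."""
    try:
        results = await process_fn(records)
        if len(results) != len(records):
            raise ValueError(
                f"process_fn returned {len(results)} results for {len(records)} records"
            )
    except Exception as exc:
        # One failure fails the whole batch: every caller sees the same exception.
        for fut in futures:
            if not fut.done():
                fut.set_exception(exc)
        return
    # Results are matched to records by position, so order must be preserved.
    for fut, result in zip(futures, results):
        fut.set_result(result)

async def good_fn(batch):
    return [s.upper() for s in batch]

async def bad_fn(batch):
    return batch[:-1]  # drops a result, violating the length contract

async def demo(fn):
    loop = asyncio.get_running_loop()
    futs = [loop.create_future() for _ in range(3)]
    await run_batch(fn, ["a", "b", "c"], futs)
    return [f.result() if f.exception() is None else type(f.exception()).__name__ for f in futs]

print(asyncio.run(demo(good_fn)))  # ['A', 'B', 'C']
print(asyncio.run(demo(bad_fn)))   # ['ValueError', 'ValueError', 'ValueError']
```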

Cost and Token Estimation

To prevent batches from exceeding hardware memory limits (e.g., GPU VRAM), the batcher uses a cost-budgeting system. You can define how "expensive" a record is using the CostEstimator or TokenEstimator protocols.
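To make the budgeting idea concrete, here is a greedy packing function under stated assumptions: records are taken in arrival order, a record that would push the running total past `target_batch_cost` starts a new batch, a single oversized record still gets a batch of its own, and `max_batch_size` is enforced independently. `pack_by_cost` is a hypothetical name, and the real aggregator's edge-case handling may differ.

```python
def pack_by_cost(costs: list[int], target_batch_cost: int, max_batch_size: int) -> list[list[int]]:
    """Greedily group record indices so each batch respects both limits (illustrative)."""
    batches: list[list[int]] = []
    current: list[int] = []
    current_cost = 0
    for i, cost in enumerate(costs):
        over_budget = bool(current) and current_cost + cost > target_batch_cost
        if over_budget or len(current) >= max_batch_size:
            # Close the current batch and start a new one with this record
            batches.append(current)
            current, current_cost = [], 0
        current.append(i)
        current_cost += cost
    if current:
        batches.append(current)
    return batches

print(pack_by_cost([300, 500, 400, 900, 100], target_batch_cost=1000, max_batch_size=3))
# [[0, 1], [2], [3, 4]]
print(pack_by_cost([1, 1, 1, 1], target_batch_cost=1000, max_batch_size=3))
# [[0, 1, 2], [3]]
```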

Implementing Protocols

If your record type implements estimate_cost() or estimate_tokens(), the batcher will call it automatically.

```python
from dataclasses import dataclass

@dataclass
class CustomRequest:
    data: bytes

    def estimate_cost(self) -> int:
        # Use the size of the payload in bytes as the cost
        return len(self.data)

# TokenBatcher falls back to estimate_cost() when estimate_tokens() is not defined
```

The flyte-sdk includes a Prompt class in flyte.extras that provides a default heuristic for LLMs: len(text) // 4 + 1.
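The heuristic is easy to check by hand. This standalone function restates it (it is not the `Prompt` source code) and estimates how many 1,000-character prompts fit in the 32,000-token budget used earlier.

```python
def estimate_tokens(text: str) -> int:
    # Roughly four characters per token; the +1 keeps empty strings from costing zero.
    return len(text) // 4 + 1

print(estimate_tokens("Hello, Flyte!"))       # 13 chars -> 13 // 4 + 1 = 4
print(estimate_tokens(""))                    # 0 // 4 + 1 = 1
print(32_000 // estimate_tokens("x" * 1000))  # 32000 // 251 = 127 prompts per batch
```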

Explicit Overrides

You can also provide an explicit estimate at submission time:

```python
# Overriding the cost for a specific record
future = await batcher.submit(record, estimated_cost=512)

# Or in TokenBatcher
future = await batcher.submit(record, estimated_tokens=128)
```
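The exact precedence between an explicit estimate and the protocol methods is not spelled out here, so the following sketch assumes: explicit override first, then `estimate_tokens()`, then `estimate_cost()`, then a default of 1. `resolve_cost`, `TokenEstimatorLike` and `CostEstimatorLike` are hypothetical names; they use `typing.Protocol` with `runtime_checkable`, which detects the methods structurally without inheritance.

```python
from dataclasses import dataclass
from typing import Optional, Protocol, runtime_checkable

@runtime_checkable
class TokenEstimatorLike(Protocol):
    def estimate_tokens(self) -> int: ...

@runtime_checkable
class CostEstimatorLike(Protocol):
    def estimate_cost(self) -> int: ...

def resolve_cost(record: object, explicit: Optional[int] = None, default: int = 1) -> int:
    # Assumed precedence: explicit override > estimate_tokens() > estimate_cost() > default
    if explicit is not None:
        return explicit
    if isinstance(record, TokenEstimatorLike):
        return record.estimate_tokens()
    if isinstance(record, CostEstimatorLike):
        return record.estimate_cost()
    return default

@dataclass
class CustomRequest:
    data: bytes

    def estimate_cost(self) -> int:
        return len(self.data)

req = CustomRequest(data=b"abcdef")
print(resolve_cost(req))                # 6, from estimate_cost()
print(resolve_cost(req, explicit=512))  # 512, the explicit override wins
print(resolve_cost("plain string"))     # 1, the fallback default
```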

Monitoring with BatchStats

You can monitor the efficiency of your batching strategy via the stats property on the batcher. This returns a BatchStats object containing real-time metrics.

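```python
stats = batcher.stats
print(f"Total Batches: {stats.total_batches}")
print(f"Avg Batch Size: {stats.avg_batch_size}")
# utilization measures time spent in process_fn, not hardware-level GPU utilization
print(f"Processing Utilization: {stats.utilization * 100:.1f}%")
```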

Key metrics in BatchStats include:

  • utilization: The fraction of wall-clock time spent inside the process_fn, as opposed to waiting for batches to assemble.
  • busy_time_s and idle_time_s: Cumulative seconds spent processing vs. idling.
  • total_batch_cost: The sum of all estimated costs processed so far.
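A minimal accumulator shows how these metrics relate to each other. `ToyBatchStats` is hypothetical, and it assumes utilization is computed as busy_time_s / (busy_time_s + idle_time_s), which matches the description of utilization as a fraction of wall-clock time.

```python
from dataclasses import dataclass

@dataclass
class ToyBatchStats:
    """Hypothetical accumulator mirroring the metrics described above."""
    total_batches: int = 0
    total_records: int = 0
    total_batch_cost: int = 0
    busy_time_s: float = 0.0
    idle_time_s: float = 0.0

    def record_batch(self, size: int, cost: int, waited_s: float, processed_s: float) -> None:
        self.total_batches += 1
        self.total_records += size
        self.total_batch_cost += cost
        self.idle_time_s += waited_s      # time spent waiting for the batch to assemble
        self.busy_time_s += processed_s   # time spent inside process_fn

    @property
    def avg_batch_size(self) -> float:
        return self.total_records / self.total_batches if self.total_batches else 0.0

    @property
    def utilization(self) -> float:
        total = self.busy_time_s + self.idle_time_s
        return self.busy_time_s / total if total else 0.0

stats = ToyBatchStats()
stats.record_batch(size=64, cost=20_000, waited_s=0.05, processed_s=0.45)
stats.record_batch(size=32, cost=9_000, waited_s=0.25, processed_s=0.25)
print(stats.avg_batch_size)        # 48.0
print(f"{stats.utilization:.2f}")  # 0.70
```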

Configuration Tuning

The behavior of the batcher is controlled by several parameters in the DynamicBatcher constructor:

| Parameter | Default | Description |
|---|---|---|
| target_batch_cost | 32,000 | The aggregator attempts to fill batches up to this limit. |
| max_batch_size | 256 | Hard cap on the number of records per batch. |
| batch_timeout_s | 0.05 | Maximum time, in seconds, to wait for a batch to fill before dispatching. |
| max_queue_size | 5,000 | Bounded size for the submission queue. Triggers backpressure when full. |
| prefetch_batches | 2 | Number of batches to buffer between the aggregator and processor. |

Lowering batch_timeout_s reduces latency for individual requests but may result in smaller, less efficient batches if the arrival rate is low. Increasing prefetch_batches can help hide the latency of the process_fn by ensuring the next batch is ready as soon as the current one finishes.
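A back-of-envelope estimate helps when choosing `batch_timeout_s`. Assuming the timeout window opens when the first record of a batch arrives (a common aggregator design, not confirmed for flyte-sdk) and ignoring the cost budget, a batch holds about one record plus whatever arrives during the window, capped by `max_batch_size`. `expected_batch_size` is a hypothetical helper.

```python
def expected_batch_size(arrival_rate_per_s: float, batch_timeout_s: float, max_batch_size: int) -> float:
    # First record opens the window; on average rate * timeout more arrive before it closes.
    # Ignores the cost budget and any time process_fn spends running.
    return min(1 + arrival_rate_per_s * batch_timeout_s, float(max_batch_size))

print(expected_batch_size(200, 0.05, 256))     # 11.0
print(expected_batch_size(200, 0.01, 256))     # 3.0
print(expected_batch_size(20_000, 0.05, 256))  # 256.0
```

At 200 requests per second, cutting the timeout from 50 ms to 10 ms shrinks the expected batch from about 11 to about 3 records; at high arrival rates the timeout matters little because `max_batch_size` is reached first.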